From d108bb81b42fc637b8a6ab91dd57ab2404e4d961 Mon Sep 17 00:00:00 2001 From: fumiama Date: Fri, 31 Dec 2021 12:34:54 +0800 Subject: [PATCH] =?UTF-8?q?add=20=E5=88=86=E7=89=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gold/head/packet.go | 85 +++++++++++++-------- gold/link/link.go | 21 +++--- gold/link/listen.go | 15 ++-- gold/link/me.go | 24 +++++- gold/link/recv.go | 77 +++++++++++++++++++ lower/nic.go | 10 +-- main.go | 91 ++++------------------- upper/data.go | 11 ++- upper/services/tunnel/tunnel.go | 29 +++++--- upper/services/tunnel/tunnel_test.go | 13 +++- upper/services/wg/wg.go | 107 +++++++++++++++++++++++++++ 11 files changed, 336 insertions(+), 147 deletions(-) create mode 100644 gold/link/recv.go create mode 100644 upper/services/wg/wg.go diff --git a/gold/head/packet.go b/gold/head/packet.go index 69b1efc..420b66f 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -12,9 +12,11 @@ import ( // Packet 是发送和接收的最小单位 type Packet struct { + // Ver 协议版本 + Ver uint16 // DataSZ len(Data) // 不得超过 65507-head 字节 - DataSZ uint32 + DataSZ uint16 // Proto 详见 head Proto uint8 // TTL is time to live @@ -23,9 +25,11 @@ type Packet struct { SrcPort uint16 // DstPort 目的端口 DstPort uint16 - // Src 源 ip + // Flags 高3位为标志(xDM),低13位为分片偏移 + Flags uint16 + // Src 源 ip (ipv4) Src net.IP - // Dst 目的 ip + // Dst 目的 ip (ipv4) Dst net.IP // Hash 使用 BLAKE2 生成加密前 Packet 的摘要 // 生成时 Hash 全 0 @@ -33,12 +37,15 @@ type Packet struct { Hash [32]byte // Data 承载的数据 Data []byte + // 记录还有多少字节未到达 + rembytes uint16 } // NewPacket 生成一个新包 func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet { logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data) return &Packet{ + Ver: 1, Proto: proto, TTL: 16, SrcPort: srcPort, @@ -49,53 +56,69 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b } // Unmarshal 将 data 的数据解码到自身 -func (p *Packet) Unmarshal(data []byte) error { +func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { if len(data) < 12 { - return errors.New("data len < 12") + err = errors.New("data len < 12") + return } - p.DataSZ = binary.LittleEndian.Uint32(data[:4]) - pt := binary.LittleEndian.Uint16(data[4:6]) - p.Proto = uint8(pt) - p.TTL = uint8(pt >> 8) - p.SrcPort = binary.LittleEndian.Uint16(data[6:8]) - p.DstPort = binary.LittleEndian.Uint16(data[8:10]) - sdl := binary.LittleEndian.Uint16(data[10:12]) - srclen := uint8(sdl) - dstlen := uint8(sdl >> 8) - if len(data) < int(12+srclen+dstlen) { - return errors.New("data src or dst len mismatch") + if p.DataSZ == 0 && len(p.Data) == 0 { + p.Ver = binary.LittleEndian.Uint16(data[:2]) + if p.Ver != 1 { + err = errors.New("unknown protocol version") + return + } + p.DataSZ = binary.LittleEndian.Uint16(data[2:4]) + p.Data = make([]byte, p.DataSZ) + pt := binary.LittleEndian.Uint16(data[4:6]) + p.Proto = uint8(pt) + p.TTL = uint8(pt >> 8) + p.SrcPort = binary.LittleEndian.Uint16(data[6:8]) + p.DstPort = binary.LittleEndian.Uint16(data[8:10]) + p.rembytes = p.DataSZ } - if srclen > 0 { - p.Src = make(net.IP, srclen) - copy(p.Src, data[12:12+srclen]) - } - if dstlen > 0 { - p.Dst = make(net.IP, dstlen) - copy(p.Dst, data[12+srclen:12+srclen+dstlen]) - } - copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32]) - p.Data = data[12+srclen+dstlen+32:] - return nil + + p.Flags = binary.LittleEndian.Uint16(data[10:12]) + + p.Src = make(net.IP, 4) + copy(p.Src, data[12:16]) + p.Dst = make(net.IP, 4) + copy(p.Dst, data[16:20]) + copy(p.Hash[:], data[20:52]) + p.rembytes -= uint16(copy(p.Data[p.Flags<<3:], data[52:])) + + complete = p.rembytes == 0 + + return } // Marshal 将自身数据编码为 []byte -func (p *Packet) Marshal(src net.IP) []byte { +// offset 必须为 8 的倍数,表示偏移的 8 位 +func (p *Packet) Marshal(src net.IP, offset uint16, dontfrag, hasmore bool) []byte { p.TTL-- if p.TTL == 0 { return nil } - p.DataSZ = uint32(len(p.Data)) + p.DataSZ = uint16(len(p.Data)) if src != nil { p.Src = src + offset >>= 3 + if dontfrag { + offset |= 0x4000 + } + if hasmore { + offset |= 0x2000 + } + p.Flags = offset } packet := make([]byte, 52+len(p.Data)) - binary.LittleEndian.PutUint32(packet[:4], p.DataSZ) + binary.LittleEndian.PutUint16(packet[:2], p.Ver) + binary.LittleEndian.PutUint16(packet[2:4], p.DataSZ) binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto)) binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort) binary.LittleEndian.PutUint16(packet[8:10], p.DstPort) - binary.LittleEndian.PutUint16(packet[10:12], 0x0404) + binary.LittleEndian.PutUint16(packet[10:12], p.Flags) copy(packet[12:16], p.Src.To4()) copy(packet[16:20], p.Dst.To4()) copy(packet[20:52], p.Hash[:]) diff --git a/gold/link/link.go b/gold/link/link.go index 63349d0..efc40c7 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -75,20 +75,20 @@ func (l *Link) Read() *head.Packet { // Write 向 peer 发包 func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) { - if len(p.Data) <= (32768 - 64) { - return l.write(p, istransfer) + if len(p.Data) <= int(l.me.mtu) { + return l.write(p, 0, istransfer, false) } data := p.Data offset := 0 - for len(data) > (32768 - 64) { + for len(data) > int(l.me.mtu) { packet := *p - packet.Data = data[offset*(32768-64) : (offset+1)*(32768-64)] - i, err := l.write(&packet, istransfer) + packet.Data = data[offset*int(l.me.mtu) : (offset+1)*int(l.me.mtu)] + i, err := l.write(&packet, uint16(offset), istransfer, true) n += i if err != nil { return n, err } - data = data[(offset+1)*(32768-64):] + data = data[(offset+1)*int(l.me.mtu):] } return n, nil } @@ -107,14 +107,17 @@ func (l *Link) String() (n string) { } // write 向 peer 发一个包 -func (l *Link) write(p *head.Packet, istransfer bool) (n int, err error) { +func (l *Link) write(p *head.Packet, offset uint16, istransfer, hasmore bool) (n int, err error) { var d []byte if istransfer { - d = p.Marshal(nil) + if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) { + return len(p.Data), errors.New("drop dont fragmnet big trans packet") + } + d = p.Marshal(nil, 0, false, false) } else { p.FillHash() p.Data = l.Encode(p.Data) - d = p.Marshal(l.me.me) + d = p.Marshal(l.me.me, offset, false, hasmore) } if d == nil { return 0, errors.New("[link] ttl exceeded") diff --git a/gold/link/listen.go b/gold/link/listen.go index 61e0f90..279b7ed 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -20,9 +20,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { n, addr, err := conn.ReadFromUDP(lbf) if err == nil { lbf = lbf[:n] - packet := head.Packet{} - err = packet.Unmarshal(lbf) - if err == nil { + packet := m.wait(lbf) + if packet != nil { r := int(packet.DataSZ) - len(packet.Data) if r > 0 { remain, err := readAll(conn, r) @@ -60,16 +59,16 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { } case head.ProtoNotify: logrus.Infoln("[link] recv notify") - p.onNotify(&packet) + p.onNotify(packet) case head.ProtoQuery: logrus.Infoln("[link] recv query") - p.onQuery(&packet) + p.onQuery(packet) case head.ProtoData: if p.pipe != nil { - p.pipe <- &packet + p.pipe <- packet logrus.Infoln("[link] deliver to pipe of", p.peerip) } else { - m.pipe <- &packet + m.pipe <- packet logrus.Infoln("[link] deliver to pipe of me") } default: @@ -81,7 +80,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { } else if p.Accept(packet.Dst) { if p.allowtrans { // 转发 - n, err = p.Write(&packet, true) + n, err = p.Write(packet, true) if err == nil { logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) } else { diff --git a/gold/link/me.go b/gold/link/me.go index bba8f60..ae780ed 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -32,10 +32,17 @@ type Me struct { pipe chan *head.Packet // 本机路由表 router *Router + // 本机未接收完全分片池 + recving map[[32]byte]*head.Packet + recvmu sync.Mutex + // 超时定时器 + clock map[*head.Packet]uint8 + // 本机上层配置 + srcport, dstport, mtu uint16 } // NewMe 设置本机参数 -func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool) (m Me) { +func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool, srcport, dstport, mtu uint16) (m Me) { m.privKey = *privateKey var err error m.myend, err = net.ResolveUDPAddr("udp", myEndpoint) @@ -62,5 +69,20 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipei } m.router.SetDefault(nil) m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink) + m.srcport = srcport + m.dstport = dstport + m.mtu = mtu return } + +func (m *Me) SrcPort() uint16 { + return m.srcport +} + +func (m *Me) DstPort() uint16 { + return m.dstport +} + +func (m *Me) MTU() uint16 { + return m.mtu +} diff --git a/gold/link/recv.go b/gold/link/recv.go new file mode 100644 index 0000000..f1ac2e1 --- /dev/null +++ b/gold/link/recv.go @@ -0,0 +1,77 @@ +package link + +import ( + "encoding/binary" + "time" + "unsafe" + + "github.com/fumiama/WireGold/gold/head" + "github.com/sirupsen/logrus" +) + +func (m *Me) initrecvpool() { + if m.recving == nil { + m.recving = make(map[[32]byte]*head.Packet, 128) + } + // 超时定时器 + m.clock = make(map[*head.Packet]uint8, 128) + var delhs []*head.Packet + t := time.NewTicker(time.Second) + for range t.C { + m.recvmu.Lock() + for k, v := range m.clock { + if v > 10 { // 10s + delete(m.recving, k.Hash) + delhs = append(delhs, k) + } else { + m.clock[k]++ + } + } + for _, k := range delhs { + delete(m.clock, k) + logrus.Warnln("[recv] drop timeout packet from", k.Src) + } + delhs = delhs[:0] + m.recvmu.Unlock() + } +} + +func (m *Me) wait(data []byte) *head.Packet { + flags := binary.LittleEndian.Uint16(data[10:12]) + if flags == 0 || flags == 0x4000 { + h := &head.Packet{} + _, err := h.Unmarshal(data) + if err != nil { + logrus.Errorln("[recv] unmarshal err:", err) + return nil + } + return h + } + + m.recvmu.Lock() + defer m.recvmu.Unlock() + hashd := data[20:52] + hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd))) + h, ok := m.recving[hsh] + if ok { + ok, err := h.Unmarshal(data) + if err == nil { + if ok { + return h + } + m.clock[h] = 0 + } else { + logrus.Errorln("[recv] unmarshal err:", err) + } + return nil + } + h = &head.Packet{} + _, err := h.Unmarshal(data) + if err != nil { + logrus.Errorln("[recv] unmarshal err:", err) + return nil + } + m.recving[hsh] = h + m.clock[h] = 0 + return nil +} diff --git a/lower/nic.go b/lower/nic.go index bec191a..adbb9ef 100644 --- a/lower/nic.go +++ b/lower/nic.go @@ -58,8 +58,8 @@ func (nc *NIC) Start(m *link.Me) { logrus.Infoln("[lower] recv write", n, "bytes packet to nic") } }() - buf := make([]byte, 65536) // 永远不可能超界 - for nc.hasstart { // 从 NIC 发送 + buf := make([]byte, m.MTU()+64) // 增加报头长度与 TEA 冗余 + for nc.hasstart { // 从 NIC 发送 packet := buf n, err := nc.ifce.Read(packet) if err != nil { @@ -115,15 +115,13 @@ func send(m *link.Me, packet []byte) (n int, rem []byte) { packet = packet[:totl] n = int(totl) dst := waterutil.IPv4Destination(packet) - srcport := waterutil.IPv4SourcePort(packet) - dstport := waterutil.IPv4DestinationPort(packet) - logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(srcport)), "to", dst.String()+":"+strconv.Itoa(int(dstport))) + logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort()))) lnk, err := m.Connect(dst.String()) if err != nil { logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err) return } - _, err = lnk.Write(head.NewPacket(head.ProtoData, srcport, dst, dstport, packet), false) + _, err = lnk.Write(head.NewPacket(head.ProtoData, m.SrcPort(), dst, m.DstPort(), packet), false) if err != nil { logrus.Warnln("[lower] write to peer", dst.String(), "err:", err) } diff --git a/main.go b/main.go index fda8de5..fc87857 100644 --- a/main.go +++ b/main.go @@ -1,22 +1,20 @@ package main import ( + "bytes" "flag" "fmt" - "net" "os" base14 "github.com/fumiama/go-base16384" curve "github.com/fumiama/go-x25519" "github.com/fumiama/WireGold/config" - "github.com/fumiama/WireGold/gold/link" "github.com/fumiama/WireGold/helper" - "github.com/fumiama/WireGold/lower" + "github.com/fumiama/WireGold/upper" + "github.com/fumiama/WireGold/upper/services/wg" ) -const suffix32 = "㴄" - func main() { help := flag.Bool("h", false, "display this help") gen := flag.Bool("g", false, "generate key pair") @@ -44,16 +42,11 @@ func main() { os.Exit(0) } if helper.IsNotExist(*file) { - f, err := os.Create(*file) - if err != nil { - panic(err) - } + f := new(bytes.Buffer) var r string fmt.Print("IP: ") fmt.Scanln(&r) if r == "" { - f.Close() - os.Remove(*file) fmt.Println("nil ip") return } @@ -63,8 +56,6 @@ func main() { fmt.Print("SubNet: ") fmt.Scanln(&r) if r == "" { - f.Close() - os.Remove(*file) fmt.Println("nil subnet") return } @@ -74,8 +65,6 @@ func main() { fmt.Print("PrivateKey: ") fmt.Scanln(&r) if r == "" { - f.Close() - os.Remove(*file) fmt.Println("nil private key") return } @@ -85,15 +74,18 @@ func main() { fmt.Print("EndPoint: ") fmt.Scanln(&r) if r == "" { - f.Close() - os.Remove(*file) fmt.Println("nil endpoint") return } f.WriteString("EndPoint: " + r + "\n") r = "" - f.Close() + cfgf, err := os.Create(*file) + if err != nil { + panic(err) + } + cfgf.Write(f.Bytes()) + cfgf.Close() } c := config.Parse(*file) if c.IP == "" { @@ -108,73 +100,18 @@ func main() { if c.EndPoint == "" { displayHelp("nil endpoint") } - var key [32]byte - k, err := base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32)) + w, err := wg.NewWireGold(&c) if err != nil { panic(err) } - n := copy(key[:], base14.Decode(k)) - if n != 32 { - displayHelp("private key length is not 32") - } if *showp { - c := curve.Get(key[:]) - pubk, err := base14.UTF16be2utf8(base14.Encode((*c.Public())[:])) - if err != nil { - panic(err) - } - fmt.Println("PublicKey:", helper.BytesToString(pubk[:57])) + fmt.Println("PublicKey:", w.PublicKey) os.Exit(0) } - cidrsmap := make(map[string]bool, 32) - _, mysubnet, err := net.ParseCIDR(c.SubNet) - if err != nil { - panic(err) - } - for _, p := range c.Peers { - for _, ip := range p.AllowedIPs { - ipnet, _, err := net.ParseCIDR(ip) - if err != nil { - panic(err) - } - if !mysubnet.Contains(ipnet) { - cidrsmap[ip] = true - } - } - } - cidrs := make([]string, len(cidrsmap)) - i := 0 - for k := range cidrsmap { - cidrs[i] = k - i++ - } - - nic := lower.NewNIC(c.IP, c.SubNet, cidrs...) - me := link.NewMe(&key, c.IP+"/32", c.EndPoint, true) - - for _, peer := range c.Peers { - var peerkey [32]byte - k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32)) - if err != nil { - panic(err) - } - n := copy(peerkey[:], base14.Decode(k)) - if n != 32 { - panic("peer public key length is not 32") - } - me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true) - } - - nic.Up() - defer func() { - nic.Stop() - nic.Down() - nic.Destroy() - }() - - nic.Start(&me) + defer w.Stop() + w.Run(upper.ServiceWireGold, upper.ServiceWireGold, 32768-64) } func displayHelp(hint string) { diff --git a/upper/data.go b/upper/data.go index 27656cd..6f6f20b 100644 --- a/upper/data.go +++ b/upper/data.go @@ -8,9 +8,16 @@ const ( ServiceNull = iota // ServiceTunnel 管道通信服务 ServiceTunnel + // ServiceWireGold 虚拟组网服务 + ServiceWireGold ) type Service interface { - Create(peer string, srcport, destport, mtu uint16) (Service, error) - io.ReadWriteCloser + // Start 无阻塞运行 + Start(srcport, destport, mtu uint16) + // Run 阻塞运行 + Run(srcport, destport, mtu uint16) + // Stop 停止 + Stop() + io.ReadWriter } diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 1ee3f65..c81b84e 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -21,24 +21,36 @@ type Tunnel struct { mtu uint16 } -func Create(me *link.Me, peer string, srcport, destport, mtu uint16) (s Tunnel, err error) { - logrus.Infoln("[tunnel] create from", srcport, "to", destport) +func Create(me *link.Me, peer string) (s Tunnel, err error) { s.l, err = me.Connect(peer) if err == nil { s.in = make(chan []byte, 4) s.out = make(chan []byte, 4) s.peerip = net.ParseIP(peer) - s.src = srcport - s.dest = destport - s.mtu = mtu - go s.handleWrite() - go s.handleRead() } else { logrus.Errorln("[tunnel] create err:", err) } return } +func (s *Tunnel) Start(srcport, destport, mtu uint16) { + logrus.Infoln("[tunnel] start from", srcport, "to", destport) + s.src = srcport + s.dest = destport + s.mtu = mtu + go s.handleWrite() + go s.handleRead() +} + +func (s *Tunnel) Run(srcport, destport, mtu uint16) { + logrus.Infoln("[tunnel] start from", srcport, "to", destport) + s.src = srcport + s.dest = destport + s.mtu = mtu + go s.handleWrite() + s.handleRead() +} + func (s *Tunnel) Write(p []byte) (int, error) { s.in <- p return len(p), nil @@ -63,10 +75,9 @@ func (s *Tunnel) Read(p []byte) (int, error) { return 0, errors.New("reading reaches nil") } -func (s *Tunnel) Close() error { +func (s *Tunnel) Stop() { s.l.Close() close(s.in) - return nil } func (s *Tunnel) handleWrite() { diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index b7b2536..2fa3e4b 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -27,18 +27,20 @@ func TestTunnel(t *testing.T) { t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:])) t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:])) - m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false) + m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false, 1, 1, 4096) m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", []string{"192.168.1.3/32"}, 0, false, false) - p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false) + p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false, 1, 1, 4096) p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", []string{"192.168.1.2/32"}, 0, false, false) - tunnme, err := Create(&m, "192.168.1.3", 1, 1, 4096) + tunnme, err := Create(&m, "192.168.1.3") if err != nil { t.Fatal(err) } - tunnpeer, err := Create(&p, "192.168.1.2", 1, 1, 4096) + tunnme.Start(1, 1, 4096) + tunnpeer, err := Create(&p, "192.168.1.2") if err != nil { t.Fatal(err) } + tunnpeer.Start(1, 1, 4096) sendb := ([]byte)("1234") tunnme.Write(sendb) @@ -68,4 +70,7 @@ func TestTunnel(t *testing.T) { if string(sendb) != string(buf) { t.Fatal("error: recv 131072 bytes data") } + + tunnme.Stop() + tunnpeer.Stop() } diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go new file mode 100644 index 0000000..4bd0b80 --- /dev/null +++ b/upper/services/wg/wg.go @@ -0,0 +1,107 @@ +package wg + +import ( + "errors" + "net" + + base14 "github.com/fumiama/go-base16384" + curve "github.com/fumiama/go-x25519" + + "github.com/fumiama/WireGold/config" + "github.com/fumiama/WireGold/gold/link" + "github.com/fumiama/WireGold/helper" + "github.com/fumiama/WireGold/lower" +) + +const suffix32 = "㴄" + +type WG struct { + c *config.Config + key [32]byte + PublicKey string + nic *lower.NIC + me link.Me +} + +func NewWireGold(c *config.Config) (wg WG, err error) { + wg.c = c + + var k []byte + k, err = base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32)) + if err != nil { + return + } + n := copy(wg.key[:], base14.Decode(k)) + if n != 32 { + err = errors.New("private key length is not 32") + return + } + + cur := curve.Get(wg.key[:]) + pubk, err := base14.UTF16be2utf8(base14.Encode((*cur.Public())[:])) + if err != nil { + return + } + wg.PublicKey = helper.BytesToString(pubk[:57]) + + return +} + +func (wg *WG) Start(srcport, destport, mtu uint16) { + wg.init(srcport, destport, mtu) + wg.nic.Up() + go wg.nic.Start(&wg.me) +} + +func (wg *WG) Run(srcport, destport, mtu uint16) { + wg.init(srcport, destport, mtu) + wg.nic.Up() + wg.nic.Start(&wg.me) +} + +func (wg *WG) Stop() { + wg.nic.Stop() + wg.nic.Down() + wg.nic.Destroy() +} + +func (wg *WG) init(srcport, destport, mtu uint16) { + cidrsmap := make(map[string]bool, 32) + _, mysubnet, err := net.ParseCIDR(wg.c.SubNet) + if err != nil { + panic(err) + } + for _, p := range wg.c.Peers { + for _, ip := range p.AllowedIPs { + ipnet, _, err := net.ParseCIDR(ip) + if err != nil { + panic(err) + } + if !mysubnet.Contains(ipnet) { + cidrsmap[ip] = true + } + } + } + cidrs := make([]string, len(cidrsmap)) + i := 0 + for k := range cidrsmap { + cidrs[i] = k + i++ + } + + wg.nic = lower.NewNIC(wg.c.IP, wg.c.SubNet, cidrs...) + wg.me = link.NewMe(&wg.key, wg.c.IP+"/32", wg.c.EndPoint, true, srcport, destport, mtu) + + for _, peer := range wg.c.Peers { + var peerkey [32]byte + k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32)) + if err != nil { + panic(err) + } + n := copy(peerkey[:], base14.Decode(k)) + if n != 32 { + panic("peer public key length is not 32") + } + wg.me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true) + } +}