diff --git a/go.sum b/go.sum index d4e2b08..408f1f8 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/fumiama/go-x25519 v1.0.0 h1:hiGg9EhseVmGCc8T1jECVkj8Keu/aJ1ZK05RM8Vua github.com/fumiama/go-x25519 v1.0.0/go.mod h1:8VOhfyGZzw4IUs4nCjQFqW9cA3V/QpSCtP3fo2dLNg4= github.com/fumiama/gofastTEA v0.0.6 h1:Yni3MXDbJVa/c4CecgdZDgCJK+fLdvGph+OBqY2mtiI= github.com/fumiama/gofastTEA v0.0.6/go.mod h1:+sBZ05nCA2skZkursHNvyr8kULlEetrYTM2y5kA4rQc= -github.com/fumiama/water v0.0.0-20211229155341-82646596a427 h1:6T/Y1o2wrNzJKjhjOsCVkKJwIlU8jfUBfRYFYQ9r3Uc= -github.com/fumiama/water v0.0.0-20211229155341-82646596a427/go.mod h1:BBnNY9PwK+UUn4trAU+H0qsMEypm7+3Bj1bVFuJItlo= github.com/fumiama/water v0.0.0-20211230051437-3d685121087a h1:jVm9uuodbGiBqJzPEHSpYjdFZQ+B9SwBrpXCffr90GY= github.com/fumiama/water v0.0.0-20211230051437-3d685121087a/go.mod h1:BBnNY9PwK+UUn4trAU+H0qsMEypm7+3Bj1bVFuJItlo= github.com/fumiama/wintun v0.0.0-20211229152851-8bc97c8034c0 h1:WfrSFlIlCAtg6Rt2IGna0HhJYSDE45YVHiYqO4wwsEw= diff --git a/gold/head/packet.go b/gold/head/packet.go index d67dc2a..5e5fed4 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -1,10 +1,13 @@ package head import ( - "encoding/json" + "encoding/binary" + "errors" + "net" "unsafe" blake2b "github.com/minio/blake2b-simd" + "github.com/sirupsen/logrus" ) // Packet 是发送和接收的最小单位 @@ -21,9 +24,9 @@ type Packet struct { // DstPort 目的端口 DstPort uint16 // Src 源 ip - Src string + Src net.IP // Dst 目的 ip - Dst string + Dst net.IP // Hash 使用 BLAKE2 生成加密前 Packet 的摘要 // 生成时 Hash 全 0 // https://github.com/minio/blake2b-simd @@ -33,27 +36,72 @@ type Packet struct { } // NewPacket 生成一个新包 -func NewPacket(proto uint8, srcPort uint16, dstPort uint16, data []byte) *Packet { +func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet { + logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data) return &Packet{ Proto: proto, - TTL: 255, + TTL: 16, SrcPort: srcPort, DstPort: dstPort, + Dst: dst, Data: data, } } // Unmarshal 将 data 的数据解码到自身 func (p *Packet) Unmarshal(data []byte) error { - return json.Unmarshal(data, p) + if len(data) < 12 { + return errors.New("data len < 12") + } + p.DataSZ = binary.LittleEndian.Uint32(data[:4]) + pt := binary.LittleEndian.Uint16(data[4:6]) + p.Proto = uint8(pt) + p.TTL = uint8(pt >> 8) + p.SrcPort = binary.LittleEndian.Uint16(data[6:8]) + p.DstPort = binary.LittleEndian.Uint16(data[8:10]) + sdl := binary.LittleEndian.Uint16(data[10:12]) + srclen := uint8(sdl) + dstlen := uint8(sdl >> 8) + if len(data) < int(12+srclen+dstlen) { + return errors.New("data src or dst len mismatch") + } + if srclen > 0 { + p.Src = make(net.IP, srclen) + copy(p.Src, data[12:12+srclen]) + } + if dstlen > 0 { + p.Dst = make(net.IP, dstlen) + copy(p.Dst, data[12+srclen:12+srclen+dstlen]) + } + copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32]) + p.Data = data[12+srclen+dstlen+32:] + return nil } // Marshal 将自身数据编码为 []byte -func (p *Packet) Marshal(src string, dst string) ([]byte, error) { +func (p *Packet) Marshal(src net.IP) []byte { + p.TTL-- + if p.TTL == 0 { + return nil + } + p.DataSZ = uint32(len(p.Data)) p.Src = src - p.Dst = dst - return json.Marshal(p) + + packet := make([]byte, 52+len(p.Data)) + binary.LittleEndian.PutUint32(packet[:4], p.DataSZ) + binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto)) + binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort) + binary.LittleEndian.PutUint16(packet[8:10], p.DstPort) + binary.LittleEndian.PutUint16(packet[10:12], 0x0404) + copy(packet[12:16], p.Src.To4()) + copy(packet[16:20], p.Dst.To4()) + copy(packet[20:52], p.Hash[:]) + copy(packet[52:], p.Data) + + // logrus.Debugln("[packet] marshaled packet:", hex.EncodeToString(packet)) + + return packet } // FillHash 生成 p.Data 的 Hash diff --git a/gold/link/link.go b/gold/link/link.go index e380ce9..2a3870d 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -75,9 +75,11 @@ func (l *Link) Read() *head.Packet { func (l *Link) Write(p *head.Packet) (n int, err error) { p.FillHash() p.Data = l.Encode(p.Data) - var d []byte - d, err = p.Marshal(l.me.me.String(), l.peerip.String()) - logrus.Debugln("[link] write data", string(d)) + d := p.Marshal(l.me.me) + if d == nil { + return 0, errors.New("ttl exceeded") + } + logrus.Debugln("[link] write", len(d), "bytes data") if err == nil { peerlink := l.me.router.NextHop(l.peerip.String() + "/32") if peerlink != nil { diff --git a/gold/link/listen.go b/gold/link/listen.go index 1e8b5d8..6e49e36 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -2,6 +2,7 @@ package link import ( "net" + "strconv" "github.com/sirupsen/logrus" @@ -29,23 +30,23 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { packet.Data = append(packet.Data, remain...) } } - p, ok := m.IsInPeer(packet.Src) + p, ok := m.IsInPeer(packet.Src.String()) logrus.Infoln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) - logrus.Debugln("[link] recv:", string(lbf)) + // logrus.Debugln("[link] recv:", hex.EncodeToString(lbf)) if ok { if p.pep == "" || p.pep != addr.String() { logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String()) p.endpoint = addr p.pep = addr.String() } - if p.IsToMe(net.ParseIP(packet.Dst)) { + if p.IsToMe(packet.Dst) { packet.Data = p.Decode(packet.Data) if packet.IsVaildHash() { switch packet.Proto { case head.ProtoHello: switch p.status { case LINK_STATUS_DOWN: - _, _ = p.Write(head.NewPacket(head.ProtoHello, 0, 0, nil)) + _, _ = p.Write(head.NewPacket(head.ProtoHello, 0, p.peerip, 0, nil)) logrus.Infoln("[link] send hello ack packet") p.status = LINK_STATUS_HALFUP case LINK_STATUS_HALFUP: @@ -73,13 +74,13 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { } else { logrus.Infoln("[link] drop invalid packet") } - } else if p.Accept(net.ParseIP(packet.Dst)) && p.allowtrans { + } else if p.Accept(packet.Dst) && p.allowtrans { // 转发 p.Write(&packet) - logrus.Infoln("[link] trans") + logrus.Infoln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) } } else { - logrus.Infoln("[link] packet to", packet.Dst, "is refused") + logrus.Warnln("[link] packet to", packet.Dst, "is refused") } } } diff --git a/gold/link/nat.go b/gold/link/nat.go index 8058d59..29dadaf 100644 --- a/gold/link/nat.go +++ b/gold/link/nat.go @@ -15,7 +15,7 @@ func (l *Link) keepAlive() { go func() { t := time.NewTicker(time.Second * time.Duration(l.keepalive)) for range t.C { - _, _ = l.Write(head.NewPacket(head.ProtoHello, 0, 0, nil)) + _, _ = l.Write(head.NewPacket(head.ProtoHello, 0, l.peerip, 0, nil)) logrus.Infoln("[link.nat] send keep alive packet") } }() diff --git a/gold/link/query.go b/gold/link/query.go index 1d08a6c..31fde36 100644 --- a/gold/link/query.go +++ b/gold/link/query.go @@ -27,6 +27,6 @@ func (l *Link) SendQuery(peers ...string) error { if err != nil { return err } - _, err = l.Write(head.NewPacket(head.ProtoQuery, 0, 0, data)) + _, err = l.Write(head.NewPacket(head.ProtoQuery, 0, l.peerip, 0, data)) return err } diff --git a/lower/nic.go b/lower/nic.go index b9d4896..2cb5111 100644 --- a/lower/nic.go +++ b/lower/nic.go @@ -79,7 +79,7 @@ func (nc *NIC) Start(m *link.Me) { logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err) continue } - lnk.Write(head.NewPacket(head.ProtoData, srcport, dstport, packet)) + lnk.Write(head.NewPacket(head.ProtoData, srcport, dst, dstport, packet)) } } diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 6232e09..71987ed 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -2,6 +2,7 @@ package tunnel import ( "errors" + "net" "github.com/sirupsen/logrus" @@ -14,6 +15,7 @@ type Tunnel struct { in chan []byte out chan []byte outcache []byte + peerip net.IP src uint16 dest uint16 mtu uint16 @@ -25,6 +27,7 @@ func Create(me *link.Me, peer string, srcport, destport, mtu uint16) (s Tunnel, if err == nil { s.in = make(chan []byte, 4) s.out = make(chan []byte, 4) + s.peerip = net.ParseIP(peer) s.src = srcport s.dest = destport s.mtu = mtu @@ -76,7 +79,7 @@ func (s *Tunnel) handleWrite() { logrus.Debugln("[tunnel] writing", len(b), "bytes...") for len(b) > int(s.mtu) { logrus.Infoln("[tunnel] split buffer") - _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.dest, b[:s.mtu])) + _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu])) if err != nil { logrus.Errorln("[tunnel] write err:", err) return @@ -84,7 +87,7 @@ func (s *Tunnel) handleWrite() { logrus.Debugln("[tunnel] write succeeded") b = b[s.mtu:] } - _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.dest, b)) + _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b)) if err != nil { logrus.Errorln("[tunnel] write err:", err) break