diff --git a/go.mod b/go.mod index ccb2c6b..b91f870 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,6 @@ go 1.16 require ( github.com/fumiama/go-x25519 v1.0.0 github.com/fumiama/gofastTEA v0.0.6 + github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect github.com/sirupsen/logrus v1.8.1 ) diff --git a/go.sum b/go.sum index c101dc2..c8ecc99 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/fumiama/gofastTEA v0.0.6/go.mod h1:+sBZ05nCA2skZkursHNvyr8kULlEetrYTM github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 h1:lYpkrQH5ajf0OXOcUbGjvZxxijuBwbbmlSxLiuofa+g= +github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1/go.mod h1:pD8RvIylQ358TN4wwqatJ8rNavkEINozVn9DtGI3dfQ= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/gold/head/packet.go b/gold/head/packet.go index eb729bc..2f932f3 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -1,13 +1,16 @@ package head import ( - "crypto/rand" "encoding/json" + "unsafe" + + blake2b "github.com/minio/blake2b-simd" ) // Packet 是发送和接收的最小单位 type Packet struct { // DataSZ len(Data) + // 不得超过 65507-head 字节 DataSZ uint32 // Proto 详见 head Proto uint8 @@ -47,11 +50,20 @@ func (p *Packet) UnMashal(data []byte) error { } // Mashal 将自身数据编码为 []byte -// 同时生成 Hash func (p *Packet) Mashal(src string, dst string) ([]byte, error) { p.DataSZ = uint32(len(p.Data)) p.Src = src p.Dst = dst - rand.Reader.Read(p.Hash[:]) return json.Marshal(p) } + +// FillHash 生成 p.Data 的 Hash +func (p *Packet) FillHash() { + sum := blake2b.New256().Sum(p.Data) + p.Hash = *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&sum))) +} + +func (p *Packet) IsVaildHash() bool { + sum := blake2b.New256().Sum(p.Data) + return *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&sum))) == p.Hash +} diff --git a/gold/link/crypto.go b/gold/link/crypto.go index ac3f11f..c490fae 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -50,7 +50,7 @@ func NewMe(privateKey *[32]byte, myIP string, myEndpoint string) (m Me) { } // Encode 使用 TEA 加密 -func (l *Link) Encode(b []byte) (eb []byte, err error) { +func (l *Link) Encode(b []byte) (eb []byte) { if b == nil { return } @@ -65,7 +65,7 @@ func (l *Link) Encode(b []byte) (eb []byte, err error) { } // Decode 使用 TEA 解密 -func (l *Link) Decode(b []byte) (db []byte, err error) { +func (l *Link) Decode(b []byte) (db []byte) { if b == nil { return } diff --git a/gold/link/link.go b/gold/link/link.go index afc1cf7..1ec9c78 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -69,14 +69,13 @@ func (l *Link) Read() *head.Packet { // Write 向 peer 发包 func (l *Link) Write(p *head.Packet) (n int, err error) { - p.Data, err = l.Encode(p.Data) + p.FillHash() + p.Data = l.Encode(p.Data) + var d []byte + d, err = p.Mashal(l.me.me.String(), l.peerip.String()) + logrus.Debugln("[link] write data", string(d)) if err == nil { - var d []byte - d, err = p.Mashal(l.me.me.String(), l.peerip.String()) - logrus.Debugln("[link] write data", string(d)) - if err == nil { - n, err = l.me.myconn.WriteToUDP(d, l.NextHop(l.peerip).endpoint) - } + n, err = l.me.myconn.WriteToUDP(d, l.NextHop(l.peerip).endpoint) } return } diff --git a/gold/link/listen.go b/gold/link/listen.go index a59306b..b460015 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -3,8 +3,9 @@ package link import ( "net" - "github.com/fumiama/WireGold/gold/head" "github.com/sirupsen/logrus" + + "github.com/fumiama/WireGold/gold/head" ) // 监听本机 endpoint @@ -38,8 +39,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { } if ok { if p.IsToMe(net.ParseIP(packet.Dst)) { - packet.Data, err = p.Decode(packet.Data) - if err == nil { + packet.Data = p.Decode(packet.Data) + if packet.IsVaildHash() { switch packet.Proto { case head.ProtoHello: switch p.status { @@ -64,6 +65,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { default: break } + } else { + logrus.Infoln("[link] drop invalid packet") } } else if p.Accept(net.ParseIP(packet.Dst)) && p.allowtrans { // 转发 diff --git a/gold/link/peer.go b/gold/link/peer.go index e89b72e..cedd117 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -1,7 +1,6 @@ package link import ( - "fmt" "net" "unsafe" @@ -30,7 +29,6 @@ func (m *Me) AddPeer(peerip string, pubicKey *[32]byte, endPoint string, allowed c := curve.Get(m.privKey[:]) k, err := c.Shared(pubicKey) if err == nil { - fmt.Println(len(k)) l.key = (*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&k))) } } diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 9cb9fd2..6232e09 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -16,9 +16,10 @@ type Tunnel struct { outcache []byte src uint16 dest uint16 + mtu uint16 } -func Create(me *link.Me, peer string, srcport uint16, destport uint16) (s Tunnel, err error) { +func Create(me *link.Me, peer string, srcport, destport, mtu uint16) (s Tunnel, err error) { logrus.Infoln("[tunnel] create from", srcport, "to", destport) s.l, err = me.Connect(peer) if err == nil { @@ -26,6 +27,7 @@ func Create(me *link.Me, peer string, srcport uint16, destport uint16) (s Tunnel s.out = make(chan []byte, 4) s.src = srcport s.dest = destport + s.mtu = mtu go s.handleWrite() go s.handleRead() } else { @@ -72,13 +74,22 @@ func (s *Tunnel) handleWrite() { break } logrus.Debugln("[tunnel] writing", len(b), "bytes...") + for len(b) > int(s.mtu) { + logrus.Infoln("[tunnel] split buffer") + _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.dest, b[:s.mtu])) + if err != nil { + logrus.Errorln("[tunnel] write err:", err) + return + } + logrus.Debugln("[tunnel] write succeeded") + b = b[s.mtu:] + } _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.dest, b)) if err != nil { logrus.Errorln("[tunnel] write err:", err) break - } else { - logrus.Debugln("[tunnel] write succeeded") } + logrus.Debugln("[tunnel] write succeeded") } } diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index de80e11..e600b23 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -1,6 +1,8 @@ package tunnel import ( + "crypto/rand" + "encoding/hex" "testing" curve "github.com/fumiama/go-x25519" @@ -11,6 +13,7 @@ import ( func TestTunnel(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) + selfpk, err := curve.New(nil) if err != nil { panic(err) @@ -19,18 +22,24 @@ func TestTunnel(t *testing.T) { if err != nil { panic(err) } + t.Log("my priv key:", hex.EncodeToString(selfpk.Private()[:])) + t.Log("my publ key:", hex.EncodeToString(selfpk.Public()[:])) + t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:])) + t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:])) + m := link.NewMe(selfpk.Private(), "192.168.1.2", "127.0.0.1:1236") m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", nil, 0, false) p := link.NewMe(peerpk.Private(), "192.168.1.3", "127.0.0.1:1237") p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", nil, 0, false) - tunnme, err := Create(&m, "192.168.1.3", 1, 1) + tunnme, err := Create(&m, "192.168.1.3", 1, 1, 4096) if err != nil { t.Fatal(err) } - tunnpeer, err := Create(&p, "192.168.1.2", 1, 1) + tunnpeer, err := Create(&p, "192.168.1.2", 1, 1, 4096) if err != nil { t.Fatal(err) } + sendb := ([]byte)("1234") tunnme.Write(sendb) buf := make([]byte, 4) @@ -39,4 +48,24 @@ func TestTunnel(t *testing.T) { t.Log("error: recv", buf) t.Fail() } + + sendb = make([]byte, 4096) + rand.Read(sendb) + tunnme.Write(sendb) + buf = make([]byte, 4096) + tunnpeer.Read(buf) + if string(sendb) != string(buf) { + t.Fatal("error: recv 4096 bytes data") + } + + sendb = make([]byte, 131072) + rand.Read(sendb) + tunnme.Write(sendb) + buf = make([]byte, 131072) + for i := 0; i < 32; i++ { + tunnpeer.Read(buf[i*4096:]) + } + if string(sendb) != string(buf) { + t.Fatal("error: recv 131072 bytes data") + } }