diff --git a/gold/head/packet.go b/gold/head/packet.go index 17cd7cd..7557e29 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -12,6 +12,8 @@ import ( "github.com/sirupsen/logrus" ) +const PacketHeadLen = 60 + var ( ErrBadCRCChecksum = errors.New("bad crc checksum") ErrDataLenLT60 = errors.New("data len < 60") @@ -105,7 +107,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { err = ErrDataLenLT60 return } - p.crc64 = binary.LittleEndian.Uint64(data[52:60]) + p.crc64 = binary.LittleEndian.Uint64(data[52:PacketHeadLen]) if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 { err = ErrBadCRCChecksum return @@ -144,7 +146,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { } if p.rembytes > 0 { - p.rembytes -= copy(p.data[flags.Offset():], data[60:]) + p.rembytes -= copy(p.data[flags.Offset():], data[PacketHeadLen:]) logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes) } @@ -162,18 +164,19 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui } if src != nil { - p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff p.Src = src - offset &= 0x1fff - if dontfrag { - offset |= 0x4000 - } - if hasmore { - offset |= 0x2000 - } - p.Flags = PacketFlags(offset) + p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff } + offset &= 0x1fff + if dontfrag { + offset |= 0x4000 + } + if hasmore { + offset |= 0x2000 + } + p.Flags = PacketFlags(offset) + return helper.OpenWriterF(func(w *helper.Writer) { w.WriteUInt32(p.idxdatsz) w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto)) diff --git a/gold/link/listen.go b/gold/link/listen.go index 43e8f37..225a415 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -125,7 +125,7 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish var err error data, err := p.Decode(packet.CipherIndex(), addt, packet.Body()) if err != nil { - logrus.Debugln("[listen] @", index, "drop invalid packet", ", key idx:", packet.CipherIndex(), "addt:", addt, "err:", err) + logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", packet.CipherIndex(), "addt:", addt, "err:", err) packet.Put() return } diff --git a/gold/link/recv.go b/gold/link/recv.go index 591a0ab..c19ee26 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -16,7 +16,7 @@ func (l *Link) Read() *head.Packet { } func (m *Me) wait(data []byte) *head.Packet { - if len(data) < 60 { // not a valid packet + if len(data) < head.PacketHeadLen { // not a valid packet return nil } bound := 64 @@ -33,7 +33,7 @@ func (m *Me) wait(data []byte) *head.Packet { logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) return nil } - crc := binary.LittleEndian.Uint64(data[52:60]) + crc := binary.LittleEndian.Uint64(data[52:head.PacketHeadLen]) if m.recved.Get(crc) { // 是重放攻击 logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) return nil diff --git a/gold/link/send.go b/gold/link/send.go index 3215e00..9b33f09 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -15,8 +15,8 @@ import ( ) var ( - ErrDropBigDontFragTransPkt = errors.New("drop big don't fragmnet trans packet") - ErrTTL = errors.New("ttl exceeded") + ErrDropBigDontFragPkt = errors.New("drop big don't fragmnet packet") + ErrTTL = errors.New("ttl exceeded") ) // WriteAndPut 向 peer 发包并将包放回缓存池 @@ -32,7 +32,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { if !istransfer { l.encrypt(p, sndcnt, teatype) } - delta := (int(mtu) - 60) & 0x0000fff8 + delta := (int(mtu) - head.PacketHeadLen) & 0x0000fff8 if delta <= 0 { logrus.Warnln("[send] reset invalid data frag len", delta, "to 8") delta = 8 @@ -42,7 +42,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false) } if istransfer && p.Flags.DontFrag() && remlen > delta { - return 0, ErrDropBigDontFragTransPkt + return 0, ErrDropBigDontFragPkt } ttl := p.TTL totl := uint32(remlen) @@ -89,21 +89,25 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) { } // write 向 peer 发一个包 -func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) { +func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (int, error) { + peerep := l.endpoint + if peerep == nil { + return 0, errors.New("nil endpoint of " + p.Dst.String()) + } + var d []byte var cl func() + // TODO: now all packet allow frag, adapt to DF if istransfer { - d, cl = p.Marshal(nil, teatype, additional, 0, 0, false, false) + d, cl = p.Marshal(nil, 0, 0, 0, offset, false, hasmore) } else { d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore) } if d == nil { return 0, ErrTTL } - peerep := l.endpoint - if peerep == nil { - return 0, errors.New("nil endpoint of " + p.Dst.String()) - } + defer cl() + bound := 64 endl := "..." if len(d) < bound { @@ -114,7 +118,5 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) d = l.me.xorenc(d) logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) - n, err = l.me.conn.WriteToPeer(d, peerep) - cl() - return + return l.me.conn.WriteToPeer(d, peerep) } diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index d381b52..e3d4802 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -16,7 +16,7 @@ import ( "github.com/fumiama/WireGold/helper" ) -func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { +func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint16) { selfpk, err := curve.New(nil) if err != nil { panic(err) @@ -37,7 +37,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { PrivateKey: selfpk.Private(), SrcPort: 1, DstPort: 1, - MTU: 4096, + MTU: mtu, }) defer m.Close() @@ -48,7 +48,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { PrivateKey: peerpk.Private(), SrcPort: 1, DstPort: 1, - MTU: 4096, + MTU: mtu, }) defer p.Close() @@ -65,8 +65,8 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { AllowedIPs: []string{"192.168.1.3/32"}, PubicKey: ppp, PresharedKey: pshk, - MTU: 4096, - MTURandomRange: 1024, + MTU: mtu, + MTURandomRange: mtu / 2, UseZstd: true, }) p.AddPeer(&link.PeerConfig{ @@ -75,8 +75,8 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { AllowedIPs: []string{"192.168.1.2/32"}, PubicKey: spp, PresharedKey: pshk, - MTU: 4096, - MTURandomRange: 1024, + MTU: mtu, + MTURandomRange: mtu / 2, UseZstd: true, }) tunnme, err := Create(&m, "192.168.1.3") @@ -152,32 +152,64 @@ func TestTunnelUDP(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) logrus.SetFormatter(&logFormat{enableColor: false}) - testTunnel(t, "udp", true, nil) // test plain text + testTunnel(t, "udp", true, nil, 4096) // test plain text - testTunnel(t, "udp", false, nil) // test normal + testTunnel(t, "udp", false, nil, 4096) // test normal var buf [32]byte _, err := rand.Read(buf[:]) if err != nil { panic(err) } - testTunnel(t, "udp", false, &buf) // test preshared + testTunnel(t, "udp", false, &buf, 4096) // test preshared +} + +func TestTunnelUDPSmallMTU(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetFormatter(&logFormat{enableColor: false}) + + testTunnel(t, "udp", true, nil, 1024) // test plain text + + testTunnel(t, "udp", false, nil, 1024) // test normal + + var buf [32]byte + _, err := rand.Read(buf[:]) + if err != nil { + panic(err) + } + testTunnel(t, "udp", false, &buf, 1024) // test preshared } func TestTunnelTCP(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) logrus.SetFormatter(&logFormat{enableColor: false}) - testTunnel(t, "tcp", true, nil) // test plain text + testTunnel(t, "tcp", true, nil, 4096) // test plain text - testTunnel(t, "tcp", false, nil) // test normal + testTunnel(t, "tcp", false, nil, 4096) // test normal var buf [32]byte _, err := rand.Read(buf[:]) if err != nil { panic(err) } - testTunnel(t, "tcp", false, &buf) // test preshared + testTunnel(t, "tcp", false, &buf, 4096) // test preshared +} + +func TestTunnelTCPSmallMTU(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetFormatter(&logFormat{enableColor: false}) + + testTunnel(t, "tcp", true, nil, 1024) // test plain text + + testTunnel(t, "tcp", false, nil, 1024) // test normal + + var buf [32]byte + _, err := rand.Read(buf[:]) + if err != nil { + panic(err) + } + testTunnel(t, "tcp", false, &buf, 1024) // test preshared } // logFormat specialize for go-cqhttp