1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-10 19:50:30 +08:00

fix(link): transfer fragmented packet

This commit is contained in:
源文雨
2024-07-17 15:43:44 +09:00
parent 06853c6552
commit cb2fe9bd21
5 changed files with 77 additions and 40 deletions

View File

@@ -12,6 +12,8 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
const PacketHeadLen = 60
var ( var (
ErrBadCRCChecksum = errors.New("bad crc checksum") ErrBadCRCChecksum = errors.New("bad crc checksum")
ErrDataLenLT60 = errors.New("data len < 60") ErrDataLenLT60 = errors.New("data len < 60")
@@ -105,7 +107,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
err = ErrDataLenLT60 err = ErrDataLenLT60
return return
} }
p.crc64 = binary.LittleEndian.Uint64(data[52:60]) p.crc64 = binary.LittleEndian.Uint64(data[52:PacketHeadLen])
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 { if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
err = ErrBadCRCChecksum err = ErrBadCRCChecksum
return return
@@ -144,7 +146,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
} }
if p.rembytes > 0 { if p.rembytes > 0 {
p.rembytes -= copy(p.data[flags.Offset():], data[60:]) p.rembytes -= copy(p.data[flags.Offset():], data[PacketHeadLen:])
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes) logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
} }
@@ -162,18 +164,19 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
} }
if src != nil { if src != nil {
p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
p.Src = src p.Src = src
offset &= 0x1fff p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
if dontfrag {
offset |= 0x4000
}
if hasmore {
offset |= 0x2000
}
p.Flags = PacketFlags(offset)
} }
offset &= 0x1fff
if dontfrag {
offset |= 0x4000
}
if hasmore {
offset |= 0x2000
}
p.Flags = PacketFlags(offset)
return helper.OpenWriterF(func(w *helper.Writer) { return helper.OpenWriterF(func(w *helper.Writer) {
w.WriteUInt32(p.idxdatsz) w.WriteUInt32(p.idxdatsz)
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto)) w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))

View File

@@ -125,7 +125,7 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish
var err error var err error
data, err := p.Decode(packet.CipherIndex(), addt, packet.Body()) data, err := p.Decode(packet.CipherIndex(), addt, packet.Body())
if err != nil { if err != nil {
logrus.Debugln("[listen] @", index, "drop invalid packet", ", key idx:", packet.CipherIndex(), "addt:", addt, "err:", err) logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", packet.CipherIndex(), "addt:", addt, "err:", err)
packet.Put() packet.Put()
return return
} }

View File

@@ -16,7 +16,7 @@ func (l *Link) Read() *head.Packet {
} }
func (m *Me) wait(data []byte) *head.Packet { func (m *Me) wait(data []byte) *head.Packet {
if len(data) < 60 { // not a valid packet if len(data) < head.PacketHeadLen { // not a valid packet
return nil return nil
} }
bound := 64 bound := 64
@@ -33,7 +33,7 @@ func (m *Me) wait(data []byte) *head.Packet {
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
return nil return nil
} }
crc := binary.LittleEndian.Uint64(data[52:60]) crc := binary.LittleEndian.Uint64(data[52:head.PacketHeadLen])
if m.recved.Get(crc) { // 是重放攻击 if m.recved.Get(crc) { // 是重放攻击
logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
return nil return nil

View File

@@ -15,8 +15,8 @@ import (
) )
var ( var (
ErrDropBigDontFragTransPkt = errors.New("drop big don't fragmnet trans packet") ErrDropBigDontFragPkt = errors.New("drop big don't fragmnet packet")
ErrTTL = errors.New("ttl exceeded") ErrTTL = errors.New("ttl exceeded")
) )
// WriteAndPut 向 peer 发包并将包放回缓存池 // WriteAndPut 向 peer 发包并将包放回缓存池
@@ -32,7 +32,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
if !istransfer { if !istransfer {
l.encrypt(p, sndcnt, teatype) l.encrypt(p, sndcnt, teatype)
} }
delta := (int(mtu) - 60) & 0x0000fff8 delta := (int(mtu) - head.PacketHeadLen) & 0x0000fff8
if delta <= 0 { if delta <= 0 {
logrus.Warnln("[send] reset invalid data frag len", delta, "to 8") logrus.Warnln("[send] reset invalid data frag len", delta, "to 8")
delta = 8 delta = 8
@@ -42,7 +42,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false) return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false)
} }
if istransfer && p.Flags.DontFrag() && remlen > delta { if istransfer && p.Flags.DontFrag() && remlen > delta {
return 0, ErrDropBigDontFragTransPkt return 0, ErrDropBigDontFragPkt
} }
ttl := p.TTL ttl := p.TTL
totl := uint32(remlen) totl := uint32(remlen)
@@ -89,21 +89,25 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) {
} }
// write 向 peer 发一个包 // write 向 peer 发一个包
func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) { func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (int, error) {
peerep := l.endpoint
if peerep == nil {
return 0, errors.New("nil endpoint of " + p.Dst.String())
}
var d []byte var d []byte
var cl func() var cl func()
// TODO: now all packet allow frag, adapt to DF
if istransfer { if istransfer {
d, cl = p.Marshal(nil, teatype, additional, 0, 0, false, false) d, cl = p.Marshal(nil, 0, 0, 0, offset, false, hasmore)
} else { } else {
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore) d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
} }
if d == nil { if d == nil {
return 0, ErrTTL return 0, ErrTTL
} }
peerep := l.endpoint defer cl()
if peerep == nil {
return 0, errors.New("nil endpoint of " + p.Dst.String())
}
bound := 64 bound := 64
endl := "..." endl := "..."
if len(d) < bound { if len(d) < bound {
@@ -114,7 +118,5 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl)
d = l.me.xorenc(d) d = l.me.xorenc(d)
logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl)
n, err = l.me.conn.WriteToPeer(d, peerep) return l.me.conn.WriteToPeer(d, peerep)
cl()
return
} }

View File

@@ -16,7 +16,7 @@ import (
"github.com/fumiama/WireGold/helper" "github.com/fumiama/WireGold/helper"
) )
func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) { func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte, mtu uint16) {
selfpk, err := curve.New(nil) selfpk, err := curve.New(nil)
if err != nil { if err != nil {
panic(err) panic(err)
@@ -37,7 +37,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) {
PrivateKey: selfpk.Private(), PrivateKey: selfpk.Private(),
SrcPort: 1, SrcPort: 1,
DstPort: 1, DstPort: 1,
MTU: 4096, MTU: mtu,
}) })
defer m.Close() defer m.Close()
@@ -48,7 +48,7 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) {
PrivateKey: peerpk.Private(), PrivateKey: peerpk.Private(),
SrcPort: 1, SrcPort: 1,
DstPort: 1, DstPort: 1,
MTU: 4096, MTU: mtu,
}) })
defer p.Close() defer p.Close()
@@ -65,8 +65,8 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) {
AllowedIPs: []string{"192.168.1.3/32"}, AllowedIPs: []string{"192.168.1.3/32"},
PubicKey: ppp, PubicKey: ppp,
PresharedKey: pshk, PresharedKey: pshk,
MTU: 4096, MTU: mtu,
MTURandomRange: 1024, MTURandomRange: mtu / 2,
UseZstd: true, UseZstd: true,
}) })
p.AddPeer(&link.PeerConfig{ p.AddPeer(&link.PeerConfig{
@@ -75,8 +75,8 @@ func testTunnel(t *testing.T, nw string, isplain bool, pshk *[32]byte) {
AllowedIPs: []string{"192.168.1.2/32"}, AllowedIPs: []string{"192.168.1.2/32"},
PubicKey: spp, PubicKey: spp,
PresharedKey: pshk, PresharedKey: pshk,
MTU: 4096, MTU: mtu,
MTURandomRange: 1024, MTURandomRange: mtu / 2,
UseZstd: true, UseZstd: true,
}) })
tunnme, err := Create(&m, "192.168.1.3") tunnme, err := Create(&m, "192.168.1.3")
@@ -152,32 +152,64 @@ func TestTunnelUDP(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel) logrus.SetLevel(logrus.DebugLevel)
logrus.SetFormatter(&logFormat{enableColor: false}) logrus.SetFormatter(&logFormat{enableColor: false})
testTunnel(t, "udp", true, nil) // test plain text testTunnel(t, "udp", true, nil, 4096) // test plain text
testTunnel(t, "udp", false, nil) // test normal testTunnel(t, "udp", false, nil, 4096) // test normal
var buf [32]byte var buf [32]byte
_, err := rand.Read(buf[:]) _, err := rand.Read(buf[:])
if err != nil { if err != nil {
panic(err) panic(err)
} }
testTunnel(t, "udp", false, &buf) // test preshared testTunnel(t, "udp", false, &buf, 4096) // test preshared
}
func TestTunnelUDPSmallMTU(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
logrus.SetFormatter(&logFormat{enableColor: false})
testTunnel(t, "udp", true, nil, 1024) // test plain text
testTunnel(t, "udp", false, nil, 1024) // test normal
var buf [32]byte
_, err := rand.Read(buf[:])
if err != nil {
panic(err)
}
testTunnel(t, "udp", false, &buf, 1024) // test preshared
} }
func TestTunnelTCP(t *testing.T) { func TestTunnelTCP(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel) logrus.SetLevel(logrus.DebugLevel)
logrus.SetFormatter(&logFormat{enableColor: false}) logrus.SetFormatter(&logFormat{enableColor: false})
testTunnel(t, "tcp", true, nil) // test plain text testTunnel(t, "tcp", true, nil, 4096) // test plain text
testTunnel(t, "tcp", false, nil) // test normal testTunnel(t, "tcp", false, nil, 4096) // test normal
var buf [32]byte var buf [32]byte
_, err := rand.Read(buf[:]) _, err := rand.Read(buf[:])
if err != nil { if err != nil {
panic(err) panic(err)
} }
testTunnel(t, "tcp", false, &buf) // test preshared testTunnel(t, "tcp", false, &buf, 4096) // test preshared
}
func TestTunnelTCPSmallMTU(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
logrus.SetFormatter(&logFormat{enableColor: false})
testTunnel(t, "tcp", true, nil, 1024) // test plain text
testTunnel(t, "tcp", false, nil, 1024) // test normal
var buf [32]byte
_, err := rand.Read(buf[:])
if err != nil {
panic(err)
}
testTunnel(t, "tcp", false, &buf, 1024) // test preshared
} }
// logFormat specialize for go-cqhttp // logFormat specialize for go-cqhttp