From 6ede65bf28b18c88f7214b73e4011d6634a58a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 13 Jul 2024 00:22:20 +0900 Subject: [PATCH] optimize(head): packet encapsuling --- gold/head/packet.go | 77 ++++++++++++++++++++++++++++++---------- gold/head/packet_test.go | 68 +++++++++++++++++++++++++++++++++++ gold/head/pool.go | 2 +- gold/link/listen.go | 5 ++- gold/link/recv.go | 8 +++-- gold/link/send.go | 2 +- 6 files changed, 136 insertions(+), 26 deletions(-) create mode 100644 gold/head/packet_test.go diff --git a/gold/head/packet.go b/gold/head/packet.go index 47adb72..3d7d1fd 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -12,13 +12,44 @@ import ( "github.com/sirupsen/logrus" ) +type PacketFlags uint16 + +func (pf PacketFlags) IsValid() bool { + return pf&0x8000 == 0 +} + +func (pf PacketFlags) DontFrag() bool { + return pf&0x4000 == 0x4000 +} + +func (pf PacketFlags) NoFrag() bool { + return pf == 0x4000 +} + +func (pf PacketFlags) IsSingle() bool { + return pf == 0 +} + +func (pf PacketFlags) ZeroOffset() bool { + return pf&0x1fff == 0 +} + +func (pf PacketFlags) Offset() uint16 { + return uint16(pf << 3) +} + +// Flags extract flags from raw data +func Flags(data []byte) PacketFlags { + return PacketFlags(binary.LittleEndian.Uint16(data[10:12])) +} + // Packet 是发送和接收的最小单位 type Packet struct { - // TeaTypeDataSZ len(Data) + // idxdatsz len(Data) // 高 5 位指定加密所用 key index // 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData // 不得超过 65507-head 字节 - TeaTypeDataSZ uint32 + idxdatsz uint32 // Proto 详见 head Proto uint8 // TTL is time to live @@ -28,7 +59,7 @@ type Packet struct { // DstPort 目的端口 DstPort uint16 // Flags 高3位为标志(xDM),低13位为分片偏移 - Flags uint16 + Flags PacketFlags // Src 源 ip (ipv4) Src net.IP // Dst 目的 ip (ipv4) @@ -37,8 +68,8 @@ type Packet struct { // 生成时 Hash 全 0 // https://github.com/fumiama/blake2b-simd Hash [32]byte - // CRC64 包头字段的 checksum 值,可以认为在一定时间内唯一 - CRC64 uint64 + // crc64 包头字段的 checksum 值,可以认为在一定时间内唯一 + crc64 uint64 // Data 承载的数据 Data []byte // 记录还有多少字节未到达 @@ -64,15 +95,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { err = errors.New("data len < 60") return } - if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != binary.LittleEndian.Uint64(data[52:60]) { + p.crc64 = binary.LittleEndian.Uint64(data[52:60]) + if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 { err = errors.New("bad crc checksum") return } - sz := p.TeaTypeDataSZ & 0x0000ffff + sz := p.idxdatsz & 0x0000ffff if sz == 0 && len(p.Data) == 0 { - p.TeaTypeDataSZ = binary.LittleEndian.Uint32(data[:4]) - sz = p.TeaTypeDataSZ & 0x0000ffff + p.idxdatsz = binary.LittleEndian.Uint32(data[:4]) + sz = p.idxdatsz & 0x0000ffff if int(sz)+52 == len(data) { p.Data = data[52:] p.rembytes = 0 @@ -87,20 +119,19 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { p.DstPort = binary.LittleEndian.Uint16(data[8:10]) } - flags := binary.LittleEndian.Uint16(data[10:12]) + flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12])) - if flags&0x1fff == 0 { + if flags.ZeroOffset() { p.Flags = flags p.Src = make(net.IP, 4) copy(p.Src, data[12:16]) p.Dst = make(net.IP, 4) copy(p.Dst, data[16:20]) copy(p.Hash[:], data[20:52]) - p.CRC64 = binary.LittleEndian.Uint64(data[52:60]) } if p.rembytes > 0 { - p.rembytes -= copy(p.Data[flags<<3:], data[60:]) + p.rembytes -= copy(p.Data[flags.Offset():], data[60:]) logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes) } @@ -118,7 +149,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui } if src != nil { - p.TeaTypeDataSZ = uint32(teatype)<<27 | (uint32(additional&0x07ff) << 16) | datasz&0xffff + p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff p.Src = src offset &= 0x1fff if dontfrag { @@ -127,15 +158,15 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui if hasmore { offset |= 0x2000 } - p.Flags = offset + p.Flags = PacketFlags(offset) } return helper.OpenWriterF(func(w *helper.Writer) { - w.WriteUInt32(p.TeaTypeDataSZ) + w.WriteUInt32(p.idxdatsz) w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto)) w.WriteUInt16(p.SrcPort) w.WriteUInt16(p.DstPort) - w.WriteUInt16(p.Flags) + w.WriteUInt16(uint16(p.Flags)) w.Write(p.Src.To4()) w.Write(p.Dst.To4()) w.Write(p.Hash[:]) @@ -171,7 +202,17 @@ func (p *Packet) IsVaildHash() bool { // AdditionalData 获得 packet 的 additionalData func (p *Packet) AdditionalData() uint16 { - return uint16((p.TeaTypeDataSZ >> 16) & 0x07ff) + return uint16((p.idxdatsz >> 16) & 0x07ff) +} + +// CipherIndex packet 加密使用的密钥集目录 +func (p *Packet) CipherIndex() uint8 { + return uint8(p.idxdatsz >> 27) +} + +// Len is packet size +func (p *Packet) Len() int { + return int(p.idxdatsz & 0xffff) } // Put 将自己放回池中 diff --git a/gold/head/packet_test.go b/gold/head/packet_test.go new file mode 100644 index 0000000..b19de4c --- /dev/null +++ b/gold/head/packet_test.go @@ -0,0 +1,68 @@ +package head + +import ( + crand "crypto/rand" + "math/rand" + "net" + "testing" +) + +func TestMarshalUnmarshal(t *testing.T) { + data := make([]byte, 4096) + _, err := crand.Read(data) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 0x7ff; i++ { + proto := uint8(rand.Intn(255)) + teatype := uint8(rand.Intn(32)) + srcPort := uint16(rand.Intn(65535)) + dstPort := uint16(rand.Intn(65535)) + src := make(net.IP, 4) + _, err = crand.Read(src) + if err != nil { + t.Fatal(err) + } + dst := make(net.IP, 4) + _, err = crand.Read(dst) + if err != nil { + t.Fatal(err) + } + p := NewPacket(proto, srcPort, dst, dstPort, data) + p.FillHash() + d, cl := p.Marshal(src, teatype, uint16(i), uint32(len(data)), 0, true, false) + p = SelectPacket() + ok, err := p.Unmarshal(d) + cl() + if !ok { + t.Fatal("index", i) + } + if err != nil { + t.Fatal(err) + } + if !p.IsVaildHash() { + t.Fatal("index", i) + } + if p.Proto != proto { + t.Fatal("index", i) + } + if p.CipherIndex() != teatype { + t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex()) + } + if p.SrcPort != srcPort { + t.Fatal("index", i) + } + if p.DstPort != dstPort { + t.Fatal("index", i) + } + if !p.Src.Equal(src) { + t.Fatal("index", i) + } + if !p.Dst.Equal(dst) { + t.Fatal("index", i) + } + if p.AdditionalData() != uint16(i) { + t.Fatal("index", i) + } + } +} diff --git a/gold/head/pool.go b/gold/head/pool.go index 0a90123..98763cb 100644 --- a/gold/head/pool.go +++ b/gold/head/pool.go @@ -15,7 +15,7 @@ func SelectPacket() *Packet { // PutPacket 将 Packet 放回池中 func PutPacket(p *Packet) { - p.TeaTypeDataSZ = 0 + p.idxdatsz = 0 p.Data = nil packetPool.Put(p) } diff --git a/gold/link/listen.go b/gold/link/listen.go index e8933eb..8c34cef 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -90,8 +90,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) { func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) { defer finish() defer logrus.Debugln("[listen] unlock index", index) - sz := packet.TeaTypeDataSZ & 0x0000ffff - r := int(sz) - len(packet.Data) + r := packet.Len() - len(packet.Data) if r > 0 { logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it") packet.Put() @@ -112,7 +111,7 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin case p.IsToMe(packet.Dst): addt := packet.AdditionalData() var err error - packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data) + packet.Data, err = p.Decode(packet.CipherIndex(), addt, packet.Data) if err != nil { logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err) packet.Put() diff --git a/gold/link/recv.go b/gold/link/recv.go index 50633a7..591a0ab 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -3,6 +3,7 @@ package link import ( "encoding/binary" "encoding/hex" + "strconv" "unsafe" "github.com/fumiama/WireGold/gold/head" @@ -27,17 +28,18 @@ func (m *Me) wait(data []byte) *head.Packet { logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl) data = m.xordec(data) logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl) - flags := binary.LittleEndian.Uint16(data[10:12]) - if flags&0x8000 != 0 { // not a valid packet + flags := head.Flags(data) + if !flags.IsValid() { logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) return nil } crc := binary.LittleEndian.Uint64(data[52:60]) if m.recved.Get(crc) { // 是重放攻击 + logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) return nil } logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) - if flags == 0 || flags == 0x4000 { + if flags.IsSingle() || flags.NoFrag() { h := head.SelectPacket() _, err := h.Unmarshal(data) if err != nil { diff --git a/gold/link/send.go b/gold/link/send.go index 791f497..f7d04fa 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -35,7 +35,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { if len(p.Data) <= delta { return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false) } - if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta { + if istransfer && p.Flags.DontFrag() && len(p.Data) > delta { return 0, errors.New("drop don't fragmnet big trans packet") } data := p.Data