mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-09 10:26:02 +08:00
optimize(head): packet encapsuling
This commit is contained in:
@@ -12,13 +12,44 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type PacketFlags uint16
|
||||
|
||||
func (pf PacketFlags) IsValid() bool {
|
||||
return pf&0x8000 == 0
|
||||
}
|
||||
|
||||
func (pf PacketFlags) DontFrag() bool {
|
||||
return pf&0x4000 == 0x4000
|
||||
}
|
||||
|
||||
func (pf PacketFlags) NoFrag() bool {
|
||||
return pf == 0x4000
|
||||
}
|
||||
|
||||
func (pf PacketFlags) IsSingle() bool {
|
||||
return pf == 0
|
||||
}
|
||||
|
||||
func (pf PacketFlags) ZeroOffset() bool {
|
||||
return pf&0x1fff == 0
|
||||
}
|
||||
|
||||
func (pf PacketFlags) Offset() uint16 {
|
||||
return uint16(pf << 3)
|
||||
}
|
||||
|
||||
// Flags extract flags from raw data
|
||||
func Flags(data []byte) PacketFlags {
|
||||
return PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
|
||||
}
|
||||
|
||||
// Packet 是发送和接收的最小单位
|
||||
type Packet struct {
|
||||
// TeaTypeDataSZ len(Data)
|
||||
// idxdatsz len(Data)
|
||||
// 高 5 位指定加密所用 key index
|
||||
// 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData
|
||||
// 不得超过 65507-head 字节
|
||||
TeaTypeDataSZ uint32
|
||||
idxdatsz uint32
|
||||
// Proto 详见 head
|
||||
Proto uint8
|
||||
// TTL is time to live
|
||||
@@ -28,7 +59,7 @@ type Packet struct {
|
||||
// DstPort 目的端口
|
||||
DstPort uint16
|
||||
// Flags 高3位为标志(xDM),低13位为分片偏移
|
||||
Flags uint16
|
||||
Flags PacketFlags
|
||||
// Src 源 ip (ipv4)
|
||||
Src net.IP
|
||||
// Dst 目的 ip (ipv4)
|
||||
@@ -37,8 +68,8 @@ type Packet struct {
|
||||
// 生成时 Hash 全 0
|
||||
// https://github.com/fumiama/blake2b-simd
|
||||
Hash [32]byte
|
||||
// CRC64 包头字段的 checksum 值,可以认为在一定时间内唯一
|
||||
CRC64 uint64
|
||||
// crc64 包头字段的 checksum 值,可以认为在一定时间内唯一
|
||||
crc64 uint64
|
||||
// Data 承载的数据
|
||||
Data []byte
|
||||
// 记录还有多少字节未到达
|
||||
@@ -64,15 +95,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
err = errors.New("data len < 60")
|
||||
return
|
||||
}
|
||||
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != binary.LittleEndian.Uint64(data[52:60]) {
|
||||
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
|
||||
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
|
||||
err = errors.New("bad crc checksum")
|
||||
return
|
||||
}
|
||||
|
||||
sz := p.TeaTypeDataSZ & 0x0000ffff
|
||||
sz := p.idxdatsz & 0x0000ffff
|
||||
if sz == 0 && len(p.Data) == 0 {
|
||||
p.TeaTypeDataSZ = binary.LittleEndian.Uint32(data[:4])
|
||||
sz = p.TeaTypeDataSZ & 0x0000ffff
|
||||
p.idxdatsz = binary.LittleEndian.Uint32(data[:4])
|
||||
sz = p.idxdatsz & 0x0000ffff
|
||||
if int(sz)+52 == len(data) {
|
||||
p.Data = data[52:]
|
||||
p.rembytes = 0
|
||||
@@ -87,20 +119,19 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||
}
|
||||
|
||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
||||
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
|
||||
|
||||
if flags&0x1fff == 0 {
|
||||
if flags.ZeroOffset() {
|
||||
p.Flags = flags
|
||||
p.Src = make(net.IP, 4)
|
||||
copy(p.Src, data[12:16])
|
||||
p.Dst = make(net.IP, 4)
|
||||
copy(p.Dst, data[16:20])
|
||||
copy(p.Hash[:], data[20:52])
|
||||
p.CRC64 = binary.LittleEndian.Uint64(data[52:60])
|
||||
}
|
||||
|
||||
if p.rembytes > 0 {
|
||||
p.rembytes -= copy(p.Data[flags<<3:], data[60:])
|
||||
p.rembytes -= copy(p.Data[flags.Offset():], data[60:])
|
||||
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
|
||||
}
|
||||
|
||||
@@ -118,7 +149,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
|
||||
}
|
||||
|
||||
if src != nil {
|
||||
p.TeaTypeDataSZ = uint32(teatype)<<27 | (uint32(additional&0x07ff) << 16) | datasz&0xffff
|
||||
p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
|
||||
p.Src = src
|
||||
offset &= 0x1fff
|
||||
if dontfrag {
|
||||
@@ -127,15 +158,15 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
|
||||
if hasmore {
|
||||
offset |= 0x2000
|
||||
}
|
||||
p.Flags = offset
|
||||
p.Flags = PacketFlags(offset)
|
||||
}
|
||||
|
||||
return helper.OpenWriterF(func(w *helper.Writer) {
|
||||
w.WriteUInt32(p.TeaTypeDataSZ)
|
||||
w.WriteUInt32(p.idxdatsz)
|
||||
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
|
||||
w.WriteUInt16(p.SrcPort)
|
||||
w.WriteUInt16(p.DstPort)
|
||||
w.WriteUInt16(p.Flags)
|
||||
w.WriteUInt16(uint16(p.Flags))
|
||||
w.Write(p.Src.To4())
|
||||
w.Write(p.Dst.To4())
|
||||
w.Write(p.Hash[:])
|
||||
@@ -171,7 +202,17 @@ func (p *Packet) IsVaildHash() bool {
|
||||
|
||||
// AdditionalData 获得 packet 的 additionalData
|
||||
func (p *Packet) AdditionalData() uint16 {
|
||||
return uint16((p.TeaTypeDataSZ >> 16) & 0x07ff)
|
||||
return uint16((p.idxdatsz >> 16) & 0x07ff)
|
||||
}
|
||||
|
||||
// CipherIndex packet 加密使用的密钥集目录
|
||||
func (p *Packet) CipherIndex() uint8 {
|
||||
return uint8(p.idxdatsz >> 27)
|
||||
}
|
||||
|
||||
// Len is packet size
|
||||
func (p *Packet) Len() int {
|
||||
return int(p.idxdatsz & 0xffff)
|
||||
}
|
||||
|
||||
// Put 将自己放回池中
|
||||
|
||||
68
gold/head/packet_test.go
Normal file
68
gold/head/packet_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package head
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalUnmarshal(t *testing.T) {
|
||||
data := make([]byte, 4096)
|
||||
_, err := crand.Read(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i < 0x7ff; i++ {
|
||||
proto := uint8(rand.Intn(255))
|
||||
teatype := uint8(rand.Intn(32))
|
||||
srcPort := uint16(rand.Intn(65535))
|
||||
dstPort := uint16(rand.Intn(65535))
|
||||
src := make(net.IP, 4)
|
||||
_, err = crand.Read(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dst := make(net.IP, 4)
|
||||
_, err = crand.Read(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p := NewPacket(proto, srcPort, dst, dstPort, data)
|
||||
p.FillHash()
|
||||
d, cl := p.Marshal(src, teatype, uint16(i), uint32(len(data)), 0, true, false)
|
||||
p = SelectPacket()
|
||||
ok, err := p.Unmarshal(d)
|
||||
cl()
|
||||
if !ok {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !p.IsVaildHash() {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if p.Proto != proto {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if p.CipherIndex() != teatype {
|
||||
t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex())
|
||||
}
|
||||
if p.SrcPort != srcPort {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if p.DstPort != dstPort {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if !p.Src.Equal(src) {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if !p.Dst.Equal(dst) {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
if p.AdditionalData() != uint16(i) {
|
||||
t.Fatal("index", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ func SelectPacket() *Packet {
|
||||
|
||||
// PutPacket 将 Packet 放回池中
|
||||
func PutPacket(p *Packet) {
|
||||
p.TeaTypeDataSZ = 0
|
||||
p.idxdatsz = 0
|
||||
p.Data = nil
|
||||
packetPool.Put(p)
|
||||
}
|
||||
|
||||
@@ -90,8 +90,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
|
||||
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
|
||||
defer finish()
|
||||
defer logrus.Debugln("[listen] unlock index", index)
|
||||
sz := packet.TeaTypeDataSZ & 0x0000ffff
|
||||
r := int(sz) - len(packet.Data)
|
||||
r := packet.Len() - len(packet.Data)
|
||||
if r > 0 {
|
||||
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it")
|
||||
packet.Put()
|
||||
@@ -112,7 +111,7 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
|
||||
case p.IsToMe(packet.Dst):
|
||||
addt := packet.AdditionalData()
|
||||
var err error
|
||||
packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
|
||||
packet.Data, err = p.Decode(packet.CipherIndex(), addt, packet.Data)
|
||||
if err != nil {
|
||||
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
|
||||
packet.Put()
|
||||
|
||||
@@ -3,6 +3,7 @@ package link
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/head"
|
||||
@@ -27,17 +28,18 @@ func (m *Me) wait(data []byte) *head.Packet {
|
||||
logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl)
|
||||
data = m.xordec(data)
|
||||
logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl)
|
||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
||||
if flags&0x8000 != 0 { // not a valid packet
|
||||
flags := head.Flags(data)
|
||||
if !flags.IsValid() {
|
||||
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
|
||||
return nil
|
||||
}
|
||||
crc := binary.LittleEndian.Uint64(data[52:60])
|
||||
if m.recved.Get(crc) { // 是重放攻击
|
||||
logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
|
||||
return nil
|
||||
}
|
||||
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
|
||||
if flags == 0 || flags == 0x4000 {
|
||||
if flags.IsSingle() || flags.NoFrag() {
|
||||
h := head.SelectPacket()
|
||||
_, err := h.Unmarshal(data)
|
||||
if err != nil {
|
||||
|
||||
@@ -35,7 +35,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
if len(p.Data) <= delta {
|
||||
return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false)
|
||||
}
|
||||
if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta {
|
||||
if istransfer && p.Flags.DontFrag() && len(p.Data) > delta {
|
||||
return 0, errors.New("drop don't fragmnet big trans packet")
|
||||
}
|
||||
data := p.Data
|
||||
|
||||
Reference in New Issue
Block a user