mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-28 14:50:26 +08:00
optimize(head): packet encapsuling
This commit is contained in:
@@ -12,13 +12,44 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PacketFlags uint16
|
||||||
|
|
||||||
|
func (pf PacketFlags) IsValid() bool {
|
||||||
|
return pf&0x8000 == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf PacketFlags) DontFrag() bool {
|
||||||
|
return pf&0x4000 == 0x4000
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf PacketFlags) NoFrag() bool {
|
||||||
|
return pf == 0x4000
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf PacketFlags) IsSingle() bool {
|
||||||
|
return pf == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf PacketFlags) ZeroOffset() bool {
|
||||||
|
return pf&0x1fff == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf PacketFlags) Offset() uint16 {
|
||||||
|
return uint16(pf << 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flags extract flags from raw data
|
||||||
|
func Flags(data []byte) PacketFlags {
|
||||||
|
return PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
|
||||||
|
}
|
||||||
|
|
||||||
// Packet 是发送和接收的最小单位
|
// Packet 是发送和接收的最小单位
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
// TeaTypeDataSZ len(Data)
|
// idxdatsz len(Data)
|
||||||
// 高 5 位指定加密所用 key index
|
// 高 5 位指定加密所用 key index
|
||||||
// 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData
|
// 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData
|
||||||
// 不得超过 65507-head 字节
|
// 不得超过 65507-head 字节
|
||||||
TeaTypeDataSZ uint32
|
idxdatsz uint32
|
||||||
// Proto 详见 head
|
// Proto 详见 head
|
||||||
Proto uint8
|
Proto uint8
|
||||||
// TTL is time to live
|
// TTL is time to live
|
||||||
@@ -28,7 +59,7 @@ type Packet struct {
|
|||||||
// DstPort 目的端口
|
// DstPort 目的端口
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
// Flags 高3位为标志(xDM),低13位为分片偏移
|
// Flags 高3位为标志(xDM),低13位为分片偏移
|
||||||
Flags uint16
|
Flags PacketFlags
|
||||||
// Src 源 ip (ipv4)
|
// Src 源 ip (ipv4)
|
||||||
Src net.IP
|
Src net.IP
|
||||||
// Dst 目的 ip (ipv4)
|
// Dst 目的 ip (ipv4)
|
||||||
@@ -37,8 +68,8 @@ type Packet struct {
|
|||||||
// 生成时 Hash 全 0
|
// 生成时 Hash 全 0
|
||||||
// https://github.com/fumiama/blake2b-simd
|
// https://github.com/fumiama/blake2b-simd
|
||||||
Hash [32]byte
|
Hash [32]byte
|
||||||
// CRC64 包头字段的 checksum 值,可以认为在一定时间内唯一
|
// crc64 包头字段的 checksum 值,可以认为在一定时间内唯一
|
||||||
CRC64 uint64
|
crc64 uint64
|
||||||
// Data 承载的数据
|
// Data 承载的数据
|
||||||
Data []byte
|
Data []byte
|
||||||
// 记录还有多少字节未到达
|
// 记录还有多少字节未到达
|
||||||
@@ -64,15 +95,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
|||||||
err = errors.New("data len < 60")
|
err = errors.New("data len < 60")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != binary.LittleEndian.Uint64(data[52:60]) {
|
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
|
||||||
|
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
|
||||||
err = errors.New("bad crc checksum")
|
err = errors.New("bad crc checksum")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sz := p.TeaTypeDataSZ & 0x0000ffff
|
sz := p.idxdatsz & 0x0000ffff
|
||||||
if sz == 0 && len(p.Data) == 0 {
|
if sz == 0 && len(p.Data) == 0 {
|
||||||
p.TeaTypeDataSZ = binary.LittleEndian.Uint32(data[:4])
|
p.idxdatsz = binary.LittleEndian.Uint32(data[:4])
|
||||||
sz = p.TeaTypeDataSZ & 0x0000ffff
|
sz = p.idxdatsz & 0x0000ffff
|
||||||
if int(sz)+52 == len(data) {
|
if int(sz)+52 == len(data) {
|
||||||
p.Data = data[52:]
|
p.Data = data[52:]
|
||||||
p.rembytes = 0
|
p.rembytes = 0
|
||||||
@@ -87,20 +119,19 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
|||||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||||
}
|
}
|
||||||
|
|
||||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
|
||||||
|
|
||||||
if flags&0x1fff == 0 {
|
if flags.ZeroOffset() {
|
||||||
p.Flags = flags
|
p.Flags = flags
|
||||||
p.Src = make(net.IP, 4)
|
p.Src = make(net.IP, 4)
|
||||||
copy(p.Src, data[12:16])
|
copy(p.Src, data[12:16])
|
||||||
p.Dst = make(net.IP, 4)
|
p.Dst = make(net.IP, 4)
|
||||||
copy(p.Dst, data[16:20])
|
copy(p.Dst, data[16:20])
|
||||||
copy(p.Hash[:], data[20:52])
|
copy(p.Hash[:], data[20:52])
|
||||||
p.CRC64 = binary.LittleEndian.Uint64(data[52:60])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.rembytes > 0 {
|
if p.rembytes > 0 {
|
||||||
p.rembytes -= copy(p.Data[flags<<3:], data[60:])
|
p.rembytes -= copy(p.Data[flags.Offset():], data[60:])
|
||||||
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
|
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +149,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
|
|||||||
}
|
}
|
||||||
|
|
||||||
if src != nil {
|
if src != nil {
|
||||||
p.TeaTypeDataSZ = uint32(teatype)<<27 | (uint32(additional&0x07ff) << 16) | datasz&0xffff
|
p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
|
||||||
p.Src = src
|
p.Src = src
|
||||||
offset &= 0x1fff
|
offset &= 0x1fff
|
||||||
if dontfrag {
|
if dontfrag {
|
||||||
@@ -127,15 +158,15 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
|
|||||||
if hasmore {
|
if hasmore {
|
||||||
offset |= 0x2000
|
offset |= 0x2000
|
||||||
}
|
}
|
||||||
p.Flags = offset
|
p.Flags = PacketFlags(offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
return helper.OpenWriterF(func(w *helper.Writer) {
|
return helper.OpenWriterF(func(w *helper.Writer) {
|
||||||
w.WriteUInt32(p.TeaTypeDataSZ)
|
w.WriteUInt32(p.idxdatsz)
|
||||||
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
|
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
|
||||||
w.WriteUInt16(p.SrcPort)
|
w.WriteUInt16(p.SrcPort)
|
||||||
w.WriteUInt16(p.DstPort)
|
w.WriteUInt16(p.DstPort)
|
||||||
w.WriteUInt16(p.Flags)
|
w.WriteUInt16(uint16(p.Flags))
|
||||||
w.Write(p.Src.To4())
|
w.Write(p.Src.To4())
|
||||||
w.Write(p.Dst.To4())
|
w.Write(p.Dst.To4())
|
||||||
w.Write(p.Hash[:])
|
w.Write(p.Hash[:])
|
||||||
@@ -171,7 +202,17 @@ func (p *Packet) IsVaildHash() bool {
|
|||||||
|
|
||||||
// AdditionalData 获得 packet 的 additionalData
|
// AdditionalData 获得 packet 的 additionalData
|
||||||
func (p *Packet) AdditionalData() uint16 {
|
func (p *Packet) AdditionalData() uint16 {
|
||||||
return uint16((p.TeaTypeDataSZ >> 16) & 0x07ff)
|
return uint16((p.idxdatsz >> 16) & 0x07ff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CipherIndex packet 加密使用的密钥集目录
|
||||||
|
func (p *Packet) CipherIndex() uint8 {
|
||||||
|
return uint8(p.idxdatsz >> 27)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len is packet size
|
||||||
|
func (p *Packet) Len() int {
|
||||||
|
return int(p.idxdatsz & 0xffff)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put 将自己放回池中
|
// Put 将自己放回池中
|
||||||
|
|||||||
68
gold/head/packet_test.go
Normal file
68
gold/head/packet_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package head
|
||||||
|
|
||||||
|
import (
|
||||||
|
crand "crypto/rand"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMarshalUnmarshal(t *testing.T) {
|
||||||
|
data := make([]byte, 4096)
|
||||||
|
_, err := crand.Read(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for i := 0; i < 0x7ff; i++ {
|
||||||
|
proto := uint8(rand.Intn(255))
|
||||||
|
teatype := uint8(rand.Intn(32))
|
||||||
|
srcPort := uint16(rand.Intn(65535))
|
||||||
|
dstPort := uint16(rand.Intn(65535))
|
||||||
|
src := make(net.IP, 4)
|
||||||
|
_, err = crand.Read(src)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
dst := make(net.IP, 4)
|
||||||
|
_, err = crand.Read(dst)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p := NewPacket(proto, srcPort, dst, dstPort, data)
|
||||||
|
p.FillHash()
|
||||||
|
d, cl := p.Marshal(src, teatype, uint16(i), uint32(len(data)), 0, true, false)
|
||||||
|
p = SelectPacket()
|
||||||
|
ok, err := p.Unmarshal(d)
|
||||||
|
cl()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !p.IsVaildHash() {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if p.Proto != proto {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if p.CipherIndex() != teatype {
|
||||||
|
t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex())
|
||||||
|
}
|
||||||
|
if p.SrcPort != srcPort {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if p.DstPort != dstPort {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if !p.Src.Equal(src) {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if !p.Dst.Equal(dst) {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
if p.AdditionalData() != uint16(i) {
|
||||||
|
t.Fatal("index", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,7 +15,7 @@ func SelectPacket() *Packet {
|
|||||||
|
|
||||||
// PutPacket 将 Packet 放回池中
|
// PutPacket 将 Packet 放回池中
|
||||||
func PutPacket(p *Packet) {
|
func PutPacket(p *Packet) {
|
||||||
p.TeaTypeDataSZ = 0
|
p.idxdatsz = 0
|
||||||
p.Data = nil
|
p.Data = nil
|
||||||
packetPool.Put(p)
|
packetPool.Put(p)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,8 +90,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
|
|||||||
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
|
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
|
||||||
defer finish()
|
defer finish()
|
||||||
defer logrus.Debugln("[listen] unlock index", index)
|
defer logrus.Debugln("[listen] unlock index", index)
|
||||||
sz := packet.TeaTypeDataSZ & 0x0000ffff
|
r := packet.Len() - len(packet.Data)
|
||||||
r := int(sz) - len(packet.Data)
|
|
||||||
if r > 0 {
|
if r > 0 {
|
||||||
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it")
|
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it")
|
||||||
packet.Put()
|
packet.Put()
|
||||||
@@ -112,7 +111,7 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
|
|||||||
case p.IsToMe(packet.Dst):
|
case p.IsToMe(packet.Dst):
|
||||||
addt := packet.AdditionalData()
|
addt := packet.AdditionalData()
|
||||||
var err error
|
var err error
|
||||||
packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
|
packet.Data, err = p.Decode(packet.CipherIndex(), addt, packet.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
|
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
|
||||||
packet.Put()
|
packet.Put()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package link
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"strconv"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/gold/head"
|
"github.com/fumiama/WireGold/gold/head"
|
||||||
@@ -27,17 +28,18 @@ func (m *Me) wait(data []byte) *head.Packet {
|
|||||||
logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl)
|
logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl)
|
||||||
data = m.xordec(data)
|
data = m.xordec(data)
|
||||||
logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl)
|
logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl)
|
||||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
flags := head.Flags(data)
|
||||||
if flags&0x8000 != 0 { // not a valid packet
|
if !flags.IsValid() {
|
||||||
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
|
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
crc := binary.LittleEndian.Uint64(data[52:60])
|
crc := binary.LittleEndian.Uint64(data[52:60])
|
||||||
if m.recved.Get(crc) { // 是重放攻击
|
if m.recved.Get(crc) { // 是重放攻击
|
||||||
|
logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
|
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
|
||||||
if flags == 0 || flags == 0x4000 {
|
if flags.IsSingle() || flags.NoFrag() {
|
||||||
h := head.SelectPacket()
|
h := head.SelectPacket()
|
||||||
_, err := h.Unmarshal(data)
|
_, err := h.Unmarshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
|
|||||||
if len(p.Data) <= delta {
|
if len(p.Data) <= delta {
|
||||||
return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false)
|
return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false)
|
||||||
}
|
}
|
||||||
if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta {
|
if istransfer && p.Flags.DontFrag() && len(p.Data) > delta {
|
||||||
return 0, errors.New("drop don't fragmnet big trans packet")
|
return 0, errors.New("drop don't fragmnet big trans packet")
|
||||||
}
|
}
|
||||||
data := p.Data
|
data := p.Data
|
||||||
|
|||||||
Reference in New Issue
Block a user