1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-26 22:00:27 +08:00

feat: impl. new protol design & new head

This commit is contained in:
源文雨
2025-03-12 22:20:02 +09:00
parent 60209117b7
commit f4fd9b1423
49 changed files with 1643 additions and 1137 deletions

73
gold/head/box.go Normal file
View File

@@ -0,0 +1,73 @@
package head
import (
"bytes"
"encoding/binary"
"encoding/hex"
"unsafe"
"github.com/fumiama/orbyte/pbuf"
"github.com/sirupsen/logrus"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/internal/algo"
"github.com/fumiama/WireGold/internal/bin"
)
// PreCRC64 calculate crc64 checksum without idxdatsz.
func (p *Packet) PreCRC64() (crc uint64) {
w := bin.SelectWriter()
if bin.IsLittleEndian {
w.Write((*[PacketHeadNoCRCLen]byte)(
(unsafe.Pointer)(p),
)[:])
} else {
w.WriteUInt32(p.idxdatsz)
w.WriteUInt32(uint32(p.randn))
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
w.WriteUInt16(p.SrcPort)
w.WriteUInt16(p.DstPort)
w.WriteUInt16(p.Offset)
w.Write(p.src[:])
w.Write(p.dst[:])
}
w.P(func(b *pbuf.Buffer) {
crc = algo.MD5Hash8(b.Bytes()[PacketHeadPreCRCIdx:])
if config.ShowDebugLog {
logrus.Debugf(
"[box] calc pre-crc64 %016x, dat %s", crc,
hex.EncodeToString(b.Bytes()[PacketHeadPreCRCIdx:]),
)
}
})
return
}
// WriteHeaderTo write header bytes to buf
// with crc64 checksum.
func (p *Packet) WriteHeaderTo(buf *bytes.Buffer) {
if bin.IsLittleEndian {
buf.Write((*[PacketHeadNoCRCLen]byte)(
(unsafe.Pointer)(p),
)[:])
p.md5h8rem = int64(algo.MD5Hash8(buf.Bytes()))
binary.Write(buf, binary.LittleEndian, p.md5h8rem)
return
}
w := bin.SelectWriter()
w.WriteUInt32(p.idxdatsz)
w.WriteUInt32(uint32(p.randn))
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
w.WriteUInt16(p.SrcPort)
w.WriteUInt16(p.DstPort)
w.WriteUInt16(p.Offset)
w.Write(p.src[:])
w.Write(p.dst[:])
w.P(func(b *pbuf.Buffer) {
p.md5h8rem = int64(algo.MD5Hash8(b.Bytes()))
})
w.WriteUInt64(uint64(p.md5h8rem))
w.P(func(b *pbuf.Buffer) {
buf.ReadFrom(b)
})
}

242
gold/head/builder.go Normal file
View File

@@ -0,0 +1,242 @@
package head
import (
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"net"
"strconv"
"github.com/fumiama/orbyte/pbuf"
"github.com/sirupsen/logrus"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/internal/algo"
"github.com/fumiama/WireGold/internal/bin"
"github.com/fumiama/WireGold/internal/file"
)
type (
HeaderBuilder PacketItem
DataBuilder PacketItem
PacketBuilder PacketItem
)
func NewPacketBuilder() *HeaderBuilder {
p := selectPacket()
p.P(func(ub *PacketBuf) {
err := binary.Read(
rand.Reader, binary.LittleEndian, &ub.DAT.randn,
)
if err != nil {
panic(err)
}
})
return (*HeaderBuilder)(p)
}
func (pb *HeaderBuilder) p(f func(*PacketBuf)) *HeaderBuilder {
(*PacketItem)(pb).P(f)
return pb
}
func (pb *HeaderBuilder) Proto(proto uint8) *HeaderBuilder {
return pb.p(func(ub *PacketBuf) {
ub.DAT.Proto |= FlagsProto(proto) & protobit
})
}
func (pb *HeaderBuilder) TTL(ttl uint8) *HeaderBuilder {
return pb.p(func(ub *PacketBuf) {
ub.DAT.TTL = ttl
})
}
func (pb *HeaderBuilder) Src(ip net.IP, p uint16) *HeaderBuilder {
return pb.p(func(ub *PacketBuf) {
copy(ub.DAT.src[:], ip.To4())
ub.DAT.SrcPort = p
})
}
func (pb *HeaderBuilder) Dst(ip net.IP, p uint16) *HeaderBuilder {
return pb.p(func(ub *PacketBuf) {
copy(ub.DAT.dst[:], ip.To4())
ub.DAT.DstPort = p
})
}
func (pb *HeaderBuilder) With(data []byte) *DataBuilder {
return (*DataBuilder)(pb.p(func(ub *PacketBuf) {
// header crc64 except idxdatasz
ub.DAT.md5h8rem = int64(ub.DAT.PreCRC64())
// plain data
ub.Buffer.Write(data)
if config.ShowDebugLog {
logrus.Debugln(file.Header(), strconv.FormatUint(uint64(ub.DAT.md5h8rem), 16), "build with data", file.ToLimitHexString(data, 64))
}
}))
}
func (pb *DataBuilder) p(f func(*PacketBuf)) *DataBuilder {
(*PacketItem)(pb).P(f)
return pb
}
func (pb *DataBuilder) Zstd() *DataBuilder {
return pb.p(func(ub *PacketBuf) {
data := algo.EncodeZstd(ub.Bytes())
ub.Reset()
data.V(func(b []byte) { ub.Write(b) })
if config.ShowDebugLog {
logrus.Debugln(file.Header(), strconv.FormatUint(uint64(ub.DAT.md5h8rem), 16), "data after zstd", file.ToLimitHexString(ub.Bytes(), 64))
}
})
}
func (pb *DataBuilder) Hash() *DataBuilder {
return (*DataBuilder)(pb.p(func(ub *PacketBuf) {
ub.DAT.hash = algo.Blake2bHash8(
uint64(ub.DAT.md5h8rem), ub.Bytes(),
)
}))
}
func (pb *DataBuilder) tea(typ uint8) *DataBuilder {
return pb.p(func(ub *PacketBuf) {
ub.DAT.idxdatsz |= (uint32(typ) << 27)
})
}
func (pb *DataBuilder) additional(additional uint16) *DataBuilder {
return pb.p(func(ub *PacketBuf) {
ub.DAT.idxdatsz |= (uint32(additional&0x07ff) << 16)
})
}
func (pb *DataBuilder) Seal(aead cipher.AEAD, teatyp uint8, additional uint16) *PacketBuilder {
return (*PacketBuilder)(pb.tea(teatyp).additional(additional).
p(func(ub *PacketBuf) {
// encrypted data: chacha20(hash + plain)
w := bin.SelectWriter()
w.WriteUInt64(ub.DAT.hash)
w.Write(ub.Bytes())
w.P(func(b *pbuf.Buffer) {
data := algo.EncodeAEAD(aead, additional, b.Bytes())
ub.Reset()
data.V(func(b []byte) { ub.Write(b) })
})
}))
}
func (pb *DataBuilder) Plain(teatyp uint8, additional uint16) *PacketBuilder {
return (*PacketBuilder)(pb.tea(teatyp).additional(additional).
p(func(ub *PacketBuf) {
w := bin.SelectWriter()
w.WriteUInt64(ub.DAT.hash)
w.Write(ub.Bytes())
w.P(func(b *pbuf.Buffer) {
ub.Reset()
ub.ReadFrom(b)
})
}))
}
func (pb *DataBuilder) Trans(teatyp uint8, additional uint16) *PacketBuilder {
return (*PacketBuilder)(pb.tea(teatyp).additional(additional))
}
func (pb *PacketBuilder) copy() *PacketBuilder {
return (*PacketBuilder)((*PacketItem)(pb).Copy())
}
func (pb *PacketBuilder) p(f func(*PacketBuf)) *PacketBuilder {
(*PacketItem)(pb).P(f)
return pb
}
// datasize fill encrypted datasize by calling data.Len().
func (pb *PacketBuilder) datasize() *PacketBuilder {
return pb.p(func(ub *PacketBuf) {
l := uint32(ub.Len()) & 0xffff
ub.DAT.idxdatsz |= l
})
}
func (pb *PacketBuilder) noFrag(on bool) *PacketBuilder {
return pb.p(func(ub *PacketBuf) {
if on {
ub.DAT.Proto |= nofragbit
} else {
ub.DAT.Proto &= ^nofragbit
}
})
}
func (pb *PacketBuilder) hasMore(on bool) *PacketBuilder {
return pb.p(func(ub *PacketBuf) {
if on {
ub.DAT.Proto |= hasmorebit
} else {
ub.DAT.Proto &= ^hasmorebit
}
})
}
func (pb *PacketBuilder) offset(off uint16) *PacketBuilder {
return pb.p(func(ub *PacketBuf) {
ub.DAT.Offset = off
})
}
// Split mtu based on the total len, which includes
// header and body and padding after outer xor.
func (pb *PacketBuilder) Split(mtu int, nofrag bool) (pbs []PacketBytes) {
pb.datasize().p(func(ub *PacketBuf) {
bodylen := ub.Len()
datalen := bodylen + int(PacketHeadLen)
udplen := algo.EncodeXORLen(datalen)
if udplen <= mtu { // can be sent in a single packet
pbs = []PacketBytes{
pbuf.BufferItemToBytes((*PacketItem)(
pb.copy().noFrag(nofrag).hasMore(false).offset(0),
)),
}
return
}
if nofrag { // drop oversized packet
return
}
pb.noFrag(false).hasMore(true)
datalim := mtu - 9 - int(PacketHeadLen)
n := bodylen / datalim
r := bodylen % datalim
if r > 0 {
n++
}
pbs = make([]PacketBytes, n)
for i := 0; i < n; i++ {
a, b := i*datalim, (i+1)*datalim
if b > bodylen {
b = bodylen
}
pbs[i] = pbuf.BufferItemToBytes((*PacketItem)(
pb.copy().offset(uint16(i*datalim)),
)).Slice(a, b)
}
})
return
}
func BuildPacketFromBytes(pb PacketBytes) pbuf.Bytes {
w := bin.SelectWriter()
pb.B(func(_ []byte, p *Packet) {
w.P(func(b *pbuf.Buffer) {
p.WriteHeaderTo(&b.Buffer)
})
})
pb.V(func(b []byte) {
w.Write(b)
})
return w.ToBytes()
}

View File

@@ -1,41 +1,37 @@
package head
import (
"encoding/binary"
"fmt"
)
type PacketFlags uint16
const (
hasmorebit FlagsProto = 0x20 << iota
nofragbit
topbit //TODO: 改为 trans 标记
)
func (pf PacketFlags) String() string {
return fmt.Sprintf("%04x", uint16(pf))
const (
impossiblebit = hasmorebit | nofragbit
flagsbit = topbit | impossiblebit
protobit = ^flagsbit
)
type FlagsProto uint8
func (pf FlagsProto) String() string {
return fmt.Sprintf("%02x", uint8(pf))
}
func (pf PacketFlags) IsValid() bool {
return pf&0x8000 == 0
func (pf FlagsProto) IsValid() bool {
return pf&topbit == 0 &&
pf&impossiblebit != impossiblebit &&
pf.Proto() < ProtoTop
}
func (pf PacketFlags) DontFrag() bool {
return pf&0x4000 == 0x4000
func (pf FlagsProto) HasMore() bool {
return pf&hasmorebit != 0
}
func (pf PacketFlags) NoFrag() bool {
return pf == 0x4000
}
func (pf PacketFlags) IsSingle() bool {
return pf == 0
}
func (pf PacketFlags) ZeroOffset() bool {
return pf&0x1fff == 0
}
func (pf PacketFlags) Offset() uint16 {
return uint16(pf << 3)
}
// Flags extract flags from raw data
func Flags(data []byte) PacketFlags {
return PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
func (pf FlagsProto) NoFrag() bool {
return pf&nofragbit != 0
}

View File

@@ -1,8 +0,0 @@
package head
type Hello uint8
const (
HelloPing Hello = iota
HelloPong
)

View File

@@ -1,7 +0,0 @@
package head
// Notify 是 map[peerip]{network, endpoint}
type Notify = map[string][2]string
// Query 是 peerips 组成的数组
type Query = []string

View File

@@ -1,238 +1,78 @@
package head
import (
"encoding/binary"
"encoding/hex"
"errors"
"net"
"sync/atomic"
"unsafe"
blake2b "github.com/fumiama/blake2b-simd"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
"github.com/sirupsen/logrus"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/helper"
)
const PacketHeadLen = 60
const (
// PacketHeadPreCRCIdx skip idxdatsz, which will be set at Seal().
PacketHeadPreCRCIdx = unsafe.Offsetof(Packet{}.randn)
// PacketHeadNoCRCLen without final crc
PacketHeadNoCRCLen = unsafe.Offsetof(Packet{}.md5h8rem)
PacketHeadLen = unsafe.Offsetof(Packet{}.hash)
)
var (
ErrBadCRCChecksum = errors.New("bad crc checksum")
ErrDataLenLT60 = errors.New("data len < 60")
ErrBadCRCChecksum = errors.New("bad crc checksum")
ErrDataLenLEHeader = errors.New("data len <= header len")
ErrInvalidOffset = errors.New("invalid offset")
)
type (
PacketBuf = pbuf.UserBuffer[Packet]
PacketItem = orbyte.Item[PacketBuf]
PacketBytes = pbuf.UserBytes[Packet]
)
// Packet 是发送和接收的最小单位
type Packet struct {
// idxdatsz len(Data)
// idxdatsz
//
// idx
// 高 5 位指定加密所用 key index
// 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData
//
// datsz
// 不得超过 65507-head 字节
idxdatsz uint32
// Proto 详见 head
Proto uint8
// randn
// 在发送报文时填入随机值.
randn int32
// Proto 高3位为标志(xDM)低5位为协议类型
Proto FlagsProto
// TTL is time to live
TTL uint8
// SrcPort 源端口
SrcPort uint16
// DstPort 目的端口
DstPort uint16
// Flags 高3位为标志(xDM)低13位为分片偏移
Flags PacketFlags
// 记录还有多少字节未到达
rembytes int32
// Src 源 ip (ipv4)
Src net.IP
// Dst 目的 ip (ipv4)
Dst net.IP
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要
// 生成时 Hash 全 0
// Offset 分片偏移
Offset uint16
// src 源 ip (ipv4)
src [4]byte
// dst 目的 ip (ipv4)
dst [4]byte
// md5h8rem 发送时记录包头字段除自身外的 checksum 值,
// 接收时记录剩余字节数.
//
// 可以认为在一定时间内唯一 (现已更改算法为 md5 但名字未变)。
md5h8rem int64
// 以下字段为包体, 与 data 一起加密
// hash 使用 BLAKE2B 生成加密前 packet data+crc64 的摘要,
// 取其前 8 字节, 小端序读写.
//
// https://github.com/fumiama/blake2b-simd
Hash [32]byte
// crc64 包头字段的 checksum 值,可以认为在一定时间内唯一 (现已更改算法为 md5 但名字未变)
crc64 uint64
// data 承载的数据
data pbuf.Bytes
// Data 当前的偏移
a, b int
}
hash uint64
// NewPacketPartial 从一些预设参数生成一个新包
func NewPacketPartial(
proto uint8, srcPort uint16,
dst net.IP, dstPort uint16,
data pbuf.Bytes,
) *orbyte.Item[Packet] {
p := selectPacket()
pp := p.Pointer()
pp.Proto = proto
pp.TTL = 16
pp.SrcPort = srcPort
pp.DstPort = dstPort
pp.Dst = dst
pp.data = data
pp.b = data.Len()
return p
}
func ParsePacket(p Packet) *orbyte.Item[Packet] {
return packetPool.Parse(nil, p)
}
func ParsePacketHeader(data []byte) (p *orbyte.Item[Packet], err error) {
if len(data) < 60 {
err = ErrDataLenLT60
return
}
p = selectPacket()
pp := p.Pointer()
pp.crc64 = CRC64(data)
if CalcCRC64(data) != pp.crc64 {
err = ErrBadCRCChecksum
return
}
pp.idxdatsz = binary.LittleEndian.Uint32(data[:4])
sz := pp.Len()
if config.ShowDebugLog {
logrus.Debugln("[packet] header data len", sz, "read data len", len(data))
}
pt := binary.LittleEndian.Uint16(data[4:6])
pp.Proto = uint8(pt)
pp.TTL = uint8(pt >> 8)
pp.SrcPort = binary.LittleEndian.Uint16(data[6:8])
pp.DstPort = binary.LittleEndian.Uint16(data[8:10])
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
pp.Flags = flags
pp.Src = make(net.IP, 4)
copy(pp.Src, data[12:16])
pp.Dst = make(net.IP, 4)
copy(pp.Dst, data[16:20])
copy(pp.Hash[:], data[20:52])
switch {
case sz+PacketHeadLen == len(data):
pp.b = sz
pp.rembytes = -1
case pp.rembytes == 0:
pp.data = pbuf.NewBytes(sz)
pp.b = sz
pp.rembytes = int32(sz)
}
return
}
// ParseData 将 data 的数据解码到自身
//
// 必须先调用 ParsePacketHeader
func (p *Packet) ParseData(data []byte) (complete bool) {
sz := p.Len()
if sz+PacketHeadLen == len(data) {
p.data = pbuf.ParseBytes(data[PacketHeadLen:]...).Copy()
return true
}
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
if config.ShowDebugLog {
logrus.Debugln("[packet] parse data flags", flags, "off", flags.Offset())
}
if flags.ZeroOffset() {
p.Flags = flags
if config.ShowDebugLog {
logrus.Debugln("[packet] parse data set zero offset flags", flags)
}
}
rembytes := atomic.LoadInt32(&p.rembytes)
if rembytes > 0 {
n := int32(copy(p.data.Bytes()[flags.Offset():], data[PacketHeadLen:]))
newrem := rembytes - n
for !atomic.CompareAndSwapInt32(&p.rembytes, rembytes, newrem) {
rembytes = atomic.LoadInt32(&p.rembytes)
newrem = rembytes - n
}
if config.ShowDebugLog {
logrus.Debugln("[packet] copied frag", hex.EncodeToString(data[20:52]), "rembytes:", p.rembytes)
}
}
return p.rembytes <= 0
}
// DecreaseAndGetTTL TTL 自减后返回
func (p *Packet) DecreaseAndGetTTL() uint8 {
p.TTL--
return p.TTL
}
// MarshalWith 补全剩余参数, 将自身数据编码为 []byte
// offset 必须为 8 的倍数,表示偏移的 8 位
func (p *Packet) MarshalWith(
src net.IP, teatype uint8, additional uint16,
datasz uint32, offset uint16,
dontfrag, hasmore bool,
) pbuf.Bytes {
if src != nil {
p.Src = src
p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
}
offset &= 0x1fff
if dontfrag {
offset |= 0x4000
}
if hasmore {
offset |= 0x2000
}
p.Flags = PacketFlags(offset)
return helper.NewWriterF(func(w *helper.Writer) {
w.WriteUInt32(p.idxdatsz)
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
w.WriteUInt16(p.SrcPort)
w.WriteUInt16(p.DstPort)
w.WriteUInt16(uint16(p.Flags))
w.Write(p.Src.To4())
w.Write(p.Dst.To4())
w.Write(p.Hash[:])
p.crc64 = CalcCRC64(w.UnsafeBytes())
w.WriteUInt64(p.crc64)
w.Write(p.UnsafeBody())
})
}
// FillHash 生成 p.Data 的 Hash
func (p *Packet) FillHash() {
h := blake2b.New256()
_, err := h.Write(p.UnsafeBody())
if err != nil {
logrus.Errorln("[packet] err when fill hash:", err)
return
}
hsh := h.Sum(p.Hash[:0])
if config.ShowDebugLog {
logrus.Debugln("[packet] sum calulated:", hex.EncodeToString(hsh))
}
}
// IsVaildHash 验证 packet 合法性
func (p *Packet) IsVaildHash() bool {
h := blake2b.New256()
_, err := h.Write(p.UnsafeBody())
if err != nil {
logrus.Errorln("[packet] err when check hash:", err)
return false
}
var sum [32]byte
_ = h.Sum(sum[:0])
if config.ShowDebugLog {
logrus.Debugln("[packet] sum data len:", len(p.UnsafeBody()))
logrus.Debugln("[packet] sum calulated:", hex.EncodeToString(sum[:]))
logrus.Debugln("[packet] sum in packet:", hex.EncodeToString(p.Hash[:]))
}
return sum == p.Hash
// Buffer 用于 builder with 暂存原始包体数据
// 以及接收时保存 body, 通过 PacketBytes 截取偏移.
}
// AdditionalData 获得 packet 的 additionalData
@@ -246,48 +86,19 @@ func (p *Packet) CipherIndex() uint8 {
}
// Len is packet size
func (p *Packet) Len() int {
func (p *Packet) Size() int {
return int(p.idxdatsz & 0xffff)
}
// CRC64 extract md5h8rem field
func (p *Packet) CRC64() uint64 {
return p.crc64
return uint64(p.md5h8rem)
}
// TransBody returns item.Trans().Slice()
func (p *Packet) TransBody() pbuf.Bytes {
d := p.data.Trans().Slice(p.a, p.b)
p.data = pbuf.Bytes{}
return d
func (p *Packet) Src() net.IP {
return append(net.IP{}, p.src[:]...)
}
// UnsafeBody returns data
func (p *Packet) UnsafeBody() []byte {
return p.data.Bytes()[p.a:p.b]
}
func (p *Packet) BodyLen() int {
return p.b - p.a
}
func (p *Packet) SetBody(b pbuf.Bytes) {
p.a = 0
p.b = b.Len()
p.data = b
}
func (p *Packet) CropBody(a, b int) {
if b > p.data.Len() {
b = p.data.Len()
}
if a < 0 || b < 0 || a > b {
return
}
p.a, p.b = a, b
}
func (p *Packet) ShallowCopy() (newp Packet) {
newp = *p
newp.data = p.data.Ref()
return newp
func (p *Packet) Dst() net.IP {
return append(net.IP{}, p.dst[:]...)
}

View File

@@ -1,31 +1,157 @@
package head
import (
"bytes"
crand "crypto/rand"
"encoding/hex"
"math/rand"
"net"
"runtime"
"sync"
"testing"
"github.com/fumiama/orbyte/pbuf"
"github.com/fumiama/WireGold/internal/algo"
"github.com/fumiama/WireGold/internal/bin"
)
func TestBuilderNative(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(4096)
for i := 0; i < 4096; i++ {
go func(i int) {
defer runtime.GC()
defer wg.Done()
dat := BuildPacketFromBytes(NewPacketBuilder().Proto(3).TTL(0xff).
Src(net.IPv4(1, 2, 3, 4), 5).Dst(net.IPv4(6, 7, 8, 9), 10).
With([]byte("0123456789")).Hash().Plain(0x12, 0x0345).
Split(16384, false)[0]).Trans()
s := hex.EncodeToString(dat)
if s[:8] != "12004593" {
panic("1")
}
if s[16:48] != "03ff05000a0000000102030406070809" {
panic("2")
}
if s[80:] != "30313233343536373839" {
panic("3")
}
p, err := ParsePacketHeader(dat)
if err != nil {
panic(err)
}
p.B(func(buf []byte, p *Packet) {
ok := p.WriteDataSegment(dat, buf)
if !ok {
panic(i)
}
if !algo.IsVaildBlake2bHash8(p.PreCRC64(), buf) {
panic(i)
}
if p.Proto != 3 {
panic(i)
}
if p.CipherIndex() != 0x12 {
panic(i)
}
if p.SrcPort != 5 {
panic(i)
}
if p.DstPort != 10 {
panic(i)
}
if !bytes.Equal(p.src[:], net.IPv4(1, 2, 3, 4).To4()) {
panic(i)
}
if !bytes.Equal(p.dst[:], net.IPv4(6, 7, 8, 9).To4()) {
panic(i)
}
if p.AdditionalData() != 0x0345 {
panic(i)
}
})
}(i)
}
wg.Wait()
}
func TestBuilderBE(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(4096)
bin.IsLittleEndian = false
for i := 0; i < 4096; i++ {
go func(i int) {
defer runtime.GC()
defer wg.Done()
dat := BuildPacketFromBytes(NewPacketBuilder().Proto(3).TTL(0xff).
Src(net.IPv4(1, 2, 3, 4), 5).Dst(net.IPv4(6, 7, 8, 9), 10).
With([]byte("0123456789")).Hash().Plain(0x12, 0x0345).
Split(16384, false)[0]).Trans()
s := hex.EncodeToString(dat)
if s[:8] != "12004593" {
panic("1")
}
if s[16:48] != "03ff05000a0000000102030406070809" {
panic("2")
}
if s[80:] != "30313233343536373839" {
panic("3")
}
p, err := ParsePacketHeader(dat)
if err != nil {
panic(err)
}
p.B(func(buf []byte, p *Packet) {
ok := p.WriteDataSegment(dat, buf)
if !ok {
panic(i)
}
if !algo.IsVaildBlake2bHash8(p.PreCRC64(), buf) {
panic(i)
}
if p.Proto != 3 {
panic(i)
}
if p.CipherIndex() != 0x12 {
panic(i)
}
if p.SrcPort != 5 {
panic(i)
}
if p.DstPort != 10 {
panic(i)
}
if !bytes.Equal(p.src[:], net.IPv4(1, 2, 3, 4).To4()) {
panic(i)
}
if !bytes.Equal(p.dst[:], net.IPv4(6, 7, 8, 9).To4()) {
panic(i)
}
if p.AdditionalData() != 0x0345 {
panic(i)
}
})
}(i)
}
wg.Wait()
}
func TestMarshalUnmarshal(t *testing.T) {
data := pbuf.NewBytes(4096)
n, err := crand.Read(data.Bytes())
// logrus.SetLevel(logrus.DebugLevel)
data := make([]byte, 4096)
n, err := crand.Read(data)
if n != 4096 {
t.Fatal("unexpected")
}
if err != nil {
t.Fatal(err)
}
for i := 0; i < 0x7ff; i++ {
proto := uint8(rand.Intn(255))
for i := 0; i < 4096; i++ {
proto := uint8(rand.Intn(int(ProtoTop)))
teatype := uint8(rand.Intn(32))
srcPort := uint16(rand.Intn(65535))
dstPort := uint16(rand.Intn(65535))
src := make(net.IP, 4)
_, err = crand.Read(src)
_, err := crand.Read(src)
if err != nil {
t.Fatal(err)
}
@@ -34,41 +160,47 @@ func TestMarshalUnmarshal(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p := NewPacketPartial(proto, srcPort, dst, dstPort, data.SliceTo(i))
p.Pointer().FillHash()
d := p.Pointer().MarshalWith(src, teatype, uint16(i), uint32(i), 0, true, false)
t.Log("data:", hex.EncodeToString(d.Bytes()))
p, err := ParsePacketHeader(d.Bytes())
dat := BuildPacketFromBytes(NewPacketBuilder().Proto(proto).
Src(src, srcPort).Dst(dst, dstPort).
With(data[:i]).Hash().Plain(teatype, uint16(i&0x7ff)).
Split(16384, false)[0]).Trans()
t.Log("pkt:", hex.EncodeToString(dat))
p, err := ParsePacketHeader(dat)
if err != nil {
t.Fatal("index", i, err)
}
ok := p.Pointer().ParseData(d.Bytes())
if !ok {
t.Fatal("index", i)
}
if !p.Pointer().IsVaildHash() {
t.Fatal("index", i, "expect body", hex.EncodeToString(data.SliceTo(i).Bytes()), "got", hex.EncodeToString(p.Pointer().UnsafeBody()))
}
if p.Pointer().Proto != proto {
t.Fatal("index", i)
}
if p.Pointer().CipherIndex() != teatype {
t.Fatal("index", i, "expect", teatype, "got", p.Pointer().CipherIndex())
}
if p.Pointer().SrcPort != srcPort {
t.Fatal("index", i)
}
if p.Pointer().DstPort != dstPort {
t.Fatal("index", i)
}
if !p.Pointer().Src.Equal(src) {
t.Fatal("index", i)
}
if !p.Pointer().Dst.Equal(dst) {
t.Fatal("index", i)
}
if p.Pointer().AdditionalData() != uint16(i) {
t.Fatal("index", i)
}
p.B(func(buf []byte, p *Packet) {
ok := p.WriteDataSegment(dat, buf)
if !ok {
t.Fatal("index", i)
}
if !algo.IsVaildBlake2bHash8(p.PreCRC64(), buf) {
t.Fatal("index", i, "expect body", hex.EncodeToString(data[:i]), "got", hex.EncodeToString(buf[8:]))
}
if p.Proto != FlagsProto(proto) {
t.Fatal("index", i)
}
if p.CipherIndex() != teatype {
t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex())
}
if p.SrcPort != srcPort {
t.Fatal("index", i)
}
if p.DstPort != dstPort {
t.Fatal("index", i)
}
if !bytes.Equal(p.src[:], src) {
t.Fatal("index", i)
}
if !bytes.Equal(p.dst[:], dst) {
t.Fatal("index", i)
}
if p.AdditionalData() != uint16(i&0x7ff) {
t.Fatal("index", i)
}
if !bytes.Equal(buf[8:], data[:i]) {
t.Fatal("index", i)
}
})
}
}

View File

@@ -1,37 +1,12 @@
package head
import (
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
)
type packetPooler struct {
orbyte.Pooler[Packet]
}
func (packetPooler) New(_ any, pooled Packet) Packet {
return pooled
}
func (packetPooler) Parse(obj any, _ Packet) Packet {
return obj.(Packet)
}
func (packetPooler) Reset(p *Packet) {
p.idxdatsz = 0
p.data = pbuf.Bytes{}
p.a, p.b = 0, 0
p.rembytes = 0
}
func (packetPooler) Copy(dst, src *Packet) {
*dst = *src
dst.data = src.data.Copy()
}
var packetPool = orbyte.NewPool[Packet](packetPooler{})
var packetPool = pbuf.NewBufferPool[Packet]()
// selectPacket 从池中取出一个 Packet
func selectPacket() *orbyte.Item[Packet] {
return packetPool.New(nil)
func selectPacket(buf ...byte) *PacketItem {
return (*PacketItem)(packetPool.NewBuffer(buf))
}

View File

@@ -6,4 +6,24 @@ const (
ProtoNotify
ProtoQuery
ProtoData
ProtoTrans
)
const ProtoTop = uint8(protobit + 1)
func (pf FlagsProto) Proto() uint8 {
return uint8(pf & protobit)
}
type Hello uint8
const (
HelloPing Hello = iota
HelloPong
)
// Notify 是 map[peerip]{network, endpoint}
type Notify = map[string][2]string
// Query 是 peerips 组成的数组
type Query = []string

View File

@@ -1,22 +0,0 @@
package head
import (
"crypto/md5"
"encoding/binary"
)
// CRC64 extract packet header checksum
func CRC64(data []byte) uint64 {
return binary.LittleEndian.Uint64(data[52:PacketHeadLen])
}
// CalcCRC64 calculate packet header checksum
func CalcCRC64(data []byte) uint64 {
m := md5.Sum(data[:52])
return binary.LittleEndian.Uint64(m[:8])
}
// Hash extract 32 bytes blake2b hash from raw bytes
func Hash(data []byte) []byte {
return data[20:52]
}

118
gold/head/unbox.go Normal file
View File

@@ -0,0 +1,118 @@
package head
import (
"encoding/binary"
"errors"
"sync/atomic"
"unsafe"
"github.com/sirupsen/logrus"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/internal/algo"
"github.com/fumiama/WireGold/internal/bin"
"github.com/fumiama/orbyte/pbuf"
)
func ParsePacketHeader(data []byte) (pbytes PacketBytes, err error) {
if len(data) <= int(PacketHeadLen) {
err = ErrDataLenLEHeader
return
}
p := selectPacket()
sz := 0
p.P(func(pb *PacketBuf) {
if bin.IsLittleEndian {
copy((*[PacketHeadLen]byte)(
(unsafe.Pointer)(&pb.DAT),
)[:], data)
} else {
pb.DAT.idxdatsz = binary.LittleEndian.Uint32(data[:4])
pb.DAT.randn = int32(binary.LittleEndian.Uint32(data[4:8]))
pt := binary.LittleEndian.Uint16(data[8:10])
pb.DAT.Proto = FlagsProto(pt)
pb.DAT.TTL = uint8(pt >> 8)
pb.DAT.SrcPort = binary.LittleEndian.Uint16(data[10:12])
pb.DAT.DstPort = binary.LittleEndian.Uint16(data[12:14])
pb.DAT.Offset = binary.LittleEndian.Uint16(data[14:16])
copy(pb.DAT.src[:], data[16:20])
copy(pb.DAT.dst[:], data[20:24])
pb.DAT.md5h8rem = int64(binary.LittleEndian.Uint64(data[24:32]))
}
sz = pb.DAT.Size()
if !pb.DAT.Proto.IsValid() {
err = errors.New("invalid proto " + pb.DAT.Proto.String())
return
}
if (!pb.DAT.Proto.HasMore() && (pb.DAT.Offset != 0 ||
sz+int(PacketHeadLen) != len(data))) ||
(pb.DAT.Proto.HasMore() && pb.DAT.Offset+
uint16(len(data[PacketHeadLen:])) > uint16(sz)) {
err = ErrInvalidOffset
if config.ShowDebugLog {
logrus.Warnf("[unbox] invalid offset %04x size %04x", pb.DAT.Offset, sz)
}
return
}
crc := algo.MD5Hash8(data[:PacketHeadNoCRCLen])
if crc != uint64(pb.DAT.md5h8rem) {
err = ErrBadCRCChecksum
if config.ShowDebugLog {
logrus.Warnf("[unbox] exp crc %016x but got %016x", pb.DAT.md5h8rem, crc)
}
return
}
if config.ShowDebugLog {
logrus.Debugln("[unbox] header data len", sz, "read data len", len(data)-int(PacketHeadLen))
}
if sz+int(PacketHeadLen) == len(data) {
pb.Buffer.Write(data[PacketHeadLen:])
pb.DAT.md5h8rem = -1
return
}
pb.Buffer.Grow(sz)
pb.Buffer.Write(make([]byte, sz))
pb.DAT.md5h8rem = int64(sz)
})
if err != nil {
return
}
pbytes = pbuf.BufferItemToBytes(p)
return
}
// WriteDataSegment 将 data 的数据并发解码到自身 buf.
//
// 必须先调用 ParsePacketHeader 获得 packet.
//
// return: complete.
func (p *Packet) WriteDataSegment(data, buf []byte) bool {
if atomic.LoadInt64(&p.md5h8rem) <= 0 {
return true
}
flags := FlagsProto(data[8])
offset := binary.LittleEndian.Uint16(data[14:16])
if config.ShowDebugLog {
logrus.Debugln("[unbox] parse data flags", flags, "off", offset)
}
if offset == 0 {
p.Proto = flags
p.Offset = 0
if config.ShowDebugLog {
logrus.Debugln("[unbox] parse data set zero offset flags", flags)
}
}
rembytes := atomic.LoadInt64(&p.md5h8rem)
if rembytes > 0 {
n := int64(copy(buf[offset:], data[PacketHeadLen:]))
newrem := rembytes - n
for !atomic.CompareAndSwapInt64(&p.md5h8rem, rembytes, newrem) {
rembytes = atomic.LoadInt64(&p.md5h8rem)
newrem = rembytes - n
}
}
return atomic.LoadInt64(&p.md5h8rem) <= 0
}