diff --git a/config/global.go b/config/global.go index df62328..4bcfe7d 100644 --- a/config/global.go +++ b/config/global.go @@ -1,3 +1,3 @@ package config -const ShowDebugLog = true +const ShowDebugLog = false diff --git a/go.mod b/go.mod index 503cfb1..ad99a3d 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,12 @@ module github.com/fumiama/WireGold go 1.20 require ( - github.com/FloatTech/ttl v0.0.0-20240716161252-965925764562 + github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 github.com/fumiama/blake2b-simd v0.0.0-20220412110131-4481822068bb github.com/fumiama/go-base16384 v1.7.0 github.com/fumiama/go-x25519 v1.0.0 + github.com/fumiama/orbyte v0.0.0-20250225103543-4a462a143731 github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac github.com/klauspost/compress v1.17.9 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index 16034d4..7a85165 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/FloatTech/ttl v0.0.0-20240716161252-965925764562 h1:snfw7FNFym1eNnLrQ/VCf80LiQo9C7jHgrunZDwiRcY= -github.com/FloatTech/ttl v0.0.0-20240716161252-965925764562/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= +github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d h1:mUQ/c3wXKsUGa4Sg9DBy01APXKB68PmobhxOyaJI7lY= +github.com/FloatTech/ttl v0.0.0-20250224045156-012b1463287d/go.mod h1:fHZFWGquNXuHttu9dUYoKuNbm3dzLETnIOnm1muSfDs= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7 h1:S/ferNiehVjNaBMNNBxUjLtVmP/YWD6Yh79RfPv4ehU= github.com/RomiChan/syncx v0.0.0-20240418144900-b7402ffdebc7/go.mod h1:vD7Ra3Q9onRtojoY5sMCLQ7JBgjUsrXDnDKyFxqpf9w= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -11,6 +11,8 @@ github.com/fumiama/go-base16384 v1.7.0 h1:6fep7XPQWxRlh4Hu+KsdH+6+YdUp+w6CwRXtMW github.com/fumiama/go-base16384 v1.7.0/go.mod h1:OEn+947GV5gsbTAnyuUW/SrfxJYUdYupSIQXOuGOcXM= github.com/fumiama/go-x25519 v1.0.0 h1:hiGg9EhseVmGCc8T1jECVkj8Keu/aJ1ZK05RM8Vuavo= github.com/fumiama/go-x25519 v1.0.0/go.mod h1:8VOhfyGZzw4IUs4nCjQFqW9cA3V/QpSCtP3fo2dLNg4= +github.com/fumiama/orbyte v0.0.0-20250225103543-4a462a143731 h1:FUL+OQCPz69Iizhrf2GxRvcGSbJgAQvV/pIAFF4g8OQ= +github.com/fumiama/orbyte v0.0.0-20250225103543-4a462a143731/go.mod h1:qkUllQ1+gTx5sGrmKvIsqUgsnOO21Hiq847YHJRifbk= github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac h1:A/5A0rODsg+EQHH61Ew5mMUtDpRXaSNqHhPvW+fN4C4= github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac/go.mod h1:BBnNY9PwK+UUn4trAU+H0qsMEypm7+3Bj1bVFuJItlo= github.com/fumiama/wintun v0.0.0-20211229152851-8bc97c8034c0 h1:WfrSFlIlCAtg6Rt2IGna0HhJYSDE45YVHiYqO4wwsEw= diff --git a/gold/head/flags.go b/gold/head/flags.go index de3cc56..7dfeb17 100644 --- a/gold/head/flags.go +++ b/gold/head/flags.go @@ -1,9 +1,16 @@ package head -import "encoding/binary" +import ( + "encoding/binary" + "fmt" +) type PacketFlags uint16 +func (pf PacketFlags) String() string { + return fmt.Sprintf("%04x", uint16(pf)) +} + func (pf PacketFlags) IsValid() bool { return pf&0x8000 == 0 } diff --git a/gold/head/packet.go b/gold/head/packet.go index 2b89385..3cb9ee2 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -5,8 +5,11 @@ import ( "encoding/hex" "errors" "net" + "sync/atomic" blake2b "github.com/fumiama/blake2b-simd" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" "github.com/sirupsen/logrus" "github.com/fumiama/WireGold/config" @@ -37,6 +40,8 @@ type Packet struct { DstPort uint16 // Flags 高3位为标志(xDM),低13位为分片偏移 Flags PacketFlags + // 记录还有多少字节未到达 + rembytes int32 // Src 源 ip (ipv4) Src net.IP // Dst 目的 ip (ipv4) @@ -48,82 +53,113 @@ type Packet struct { // crc64 包头字段的 checksum 值,可以认为在一定时间内唯一 (现已更改算法为 md5 但名字未变) crc64 uint64 // data 承载的数据 - data []byte + data pbuf.Bytes // Data 当前的偏移 a, b int - // 记录还有多少字节未到达 - rembytes int - // 是否经由 helper.MakeBytes 创建 Data - buffered bool } -// NewPacket 生成一个新包 -func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) (p *Packet) { - p = SelectPacket() - p.Proto = proto - p.TTL = 16 - p.SrcPort = srcPort - p.DstPort = dstPort - p.Dst = dst - p.data = data - p.b = len(data) +// NewPacketPartial 从一些预设参数生成一个新包 +func NewPacketPartial( + proto uint8, srcPort uint16, + dst net.IP, dstPort uint16, + data pbuf.Bytes, +) (p *orbyte.Item[Packet]) { + p = selectPacket() + pp := p.Pointer() + pp.Proto = proto + pp.TTL = 16 + pp.SrcPort = srcPort + pp.DstPort = dstPort + pp.Dst = dst + pp.data = data + pp.b = data.Len() return } -// Unmarshal 将 data 的数据解码到自身 -func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { +func ParsePacket(p Packet) *orbyte.Item[Packet] { + return packetPool.Parse(nil, p) +} + +func ParsePacketHeader(data []byte) (p *orbyte.Item[Packet], err error) { if len(data) < 60 { err = ErrDataLenLT60 return } - p.crc64 = CRC64(data) - if CalcCRC64(data) != p.crc64 { + p = selectPacket() + pp := p.Pointer() + pp.crc64 = CRC64(data) + if CalcCRC64(data) != pp.crc64 { err = ErrBadCRCChecksum return } + pp.idxdatsz = binary.LittleEndian.Uint32(data[:4]) + sz := pp.Len() + if config.ShowDebugLog { + logrus.Debugln("[packet] header data len", sz, "read data len", len(data)) + } + pt := binary.LittleEndian.Uint16(data[4:6]) + pp.Proto = uint8(pt) + pp.TTL = uint8(pt >> 8) + pp.SrcPort = binary.LittleEndian.Uint16(data[6:8]) + pp.DstPort = binary.LittleEndian.Uint16(data[8:10]) + + flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12])) + pp.Flags = flags + pp.Src = make(net.IP, 4) + copy(pp.Src, data[12:16]) + pp.Dst = make(net.IP, 4) + copy(pp.Dst, data[16:20]) + copy(pp.Hash[:], data[20:52]) + + switch { + case sz+PacketHeadLen == len(data): + pp.b = sz + pp.rembytes = -1 + case pp.rembytes == 0: + pp.data = pbuf.NewBytes(sz) + pp.b = sz + pp.rembytes = int32(sz) + } + + return +} + +// ParseData 将 data 的数据解码到自身 +// +// 必须先调用 ParsePacketHeader +func (p *Packet) ParseData(data []byte) (complete bool) { sz := p.Len() - if sz == 0 && len(p.data) == 0 { - p.idxdatsz = binary.LittleEndian.Uint32(data[:4]) - sz = p.Len() - if sz+52 == len(data) { - p.data = data[52:] - p.b = len(p.data) - p.rembytes = 0 - } else { - p.data = helper.MakeBytes(sz) - p.buffered = true - p.b = sz - p.rembytes = sz - } - pt := binary.LittleEndian.Uint16(data[4:6]) - p.Proto = uint8(pt) - p.TTL = uint8(pt >> 8) - p.SrcPort = binary.LittleEndian.Uint16(data[6:8]) - p.DstPort = binary.LittleEndian.Uint16(data[8:10]) + if sz+PacketHeadLen == len(data) { + p.data = pbuf.ParseBytes(data[PacketHeadLen:]...) + return true } flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12])) - + if config.ShowDebugLog { + logrus.Debugln("[packet] parse data flags", flags, "off", flags.Offset()) + } if flags.ZeroOffset() { p.Flags = flags - p.Src = make(net.IP, 4) - copy(p.Src, data[12:16]) - p.Dst = make(net.IP, 4) - copy(p.Dst, data[16:20]) - copy(p.Hash[:], data[20:52]) - } - - if p.rembytes > 0 { - p.rembytes -= copy(p.data[flags.Offset():], data[PacketHeadLen:]) if config.ShowDebugLog { - logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes) + logrus.Debugln("[packet] parse data set zero offset flags", flags) } } - complete = p.rembytes == 0 + rembytes := atomic.LoadInt32(&p.rembytes) + if rembytes > 0 { + n := int32(copy(p.data.Bytes()[flags.Offset():], data[PacketHeadLen:])) + newrem := rembytes - n + for !atomic.CompareAndSwapInt32(&p.rembytes, rembytes, newrem) { + rembytes = atomic.LoadInt32(&p.rembytes) + newrem = rembytes - n + } + if config.ShowDebugLog { + logrus.Debugln("[packet] copied frag", hex.EncodeToString(data[20:52]), "rembytes:", p.rembytes) + } + } - return + return p.rembytes <= 0 } // DecreaseAndGetTTL TTL 自减后返回 @@ -132,9 +168,13 @@ func (p *Packet) DecreaseAndGetTTL() uint8 { return p.TTL } -// Marshal 将自身数据编码为 []byte +// MarshalWith 补全剩余参数, 将自身数据编码为 []byte // offset 必须为 8 的倍数,表示偏移的 8 位 -func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) { +func (p *Packet) MarshalWith( + src net.IP, teatype uint8, additional uint16, + datasz uint32, offset uint16, + dontfrag, hasmore bool, +) pbuf.Bytes { if src != nil { p.Src = src p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff @@ -148,8 +188,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui offset |= 0x2000 } p.Flags = PacketFlags(offset) - - return helper.OpenWriterF(func(w *helper.Writer) { + return helper.NewWriterF(func(w *helper.Writer) { w.WriteUInt32(p.idxdatsz) w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto)) w.WriteUInt16(p.SrcPort) @@ -158,7 +197,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui w.Write(p.Src.To4()) w.Write(p.Dst.To4()) w.Write(p.Hash[:]) - p.crc64 = CalcCRC64(w.Bytes()) + p.crc64 = CalcCRC64(w.UnsafeBytes()) w.WriteUInt64(p.crc64) w.Write(p.Body()) }) @@ -189,6 +228,7 @@ func (p *Packet) IsVaildHash() bool { var sum [32]byte _ = h.Sum(sum[:0]) if config.ShowDebugLog { + logrus.Debugln("[packet] sum data len:", len(p.Body())) logrus.Debugln("[packet] sum calulated:", hex.EncodeToString(sum[:])) logrus.Debugln("[packet] sum in packet:", hex.EncodeToString(p.Hash[:])) } @@ -210,45 +250,28 @@ func (p *Packet) Len() int { return int(p.idxdatsz & 0xffff) } -// Put 将自己放回池中 -func (p *Packet) Put() { - PutPacket(p) -} - func (p *Packet) CRC64() uint64 { return p.crc64 } // Body returns data func (p *Packet) Body() []byte { - return p.data[p.a:p.b] + return p.data.Bytes()[p.a:p.b] } func (p *Packet) BodyLen() int { return p.b - p.a } -func (p *Packet) SetBody(b []byte, buffered bool) { +func (p *Packet) SetBody(b []byte) { p.a = 0 p.b = len(b) - if len(b) <= cap(p.data) { - p.data = p.data[:len(b)] - copy(p.data, b) - if buffered { - helper.PutBytes(b) - } - return - } - if p.buffered { - helper.PutBytes(p.data) - } - p.data = b - p.buffered = buffered + p.data = pbuf.ParseBytes(b...) } func (p *Packet) CropBody(a, b int) { - if b > len(p.data) { - b = len(p.data) + if b > p.data.Len() { + b = p.data.Len() } if a < 0 || b < 0 || a > b { return @@ -256,16 +279,8 @@ func (p *Packet) CropBody(a, b int) { p.a, p.b = a, b } -func (p *Packet) Copy() *Packet { - newp := SelectPacket() - *newp = *p - newp.buffered = false - return newp -} - -func (p *Packet) CopyWithBody() *Packet { - newp := p.Copy() - newp.data = helper.MakeBytes(len(p.data)) - copy(newp.data, p.data) +func (p *Packet) ShallowCopy() (newp Packet) { + newp = *p + newp.data = p.data.Ref() return newp } diff --git a/gold/head/packet_test.go b/gold/head/packet_test.go index b19de4c..581eda4 100644 --- a/gold/head/packet_test.go +++ b/gold/head/packet_test.go @@ -2,14 +2,20 @@ package head import ( crand "crypto/rand" + "encoding/hex" "math/rand" "net" "testing" + + "github.com/fumiama/orbyte/pbuf" ) func TestMarshalUnmarshal(t *testing.T) { - data := make([]byte, 4096) - _, err := crand.Read(data) + data := pbuf.NewBytes(4096) + n, err := crand.Read(data.Bytes()) + if n != 4096 { + t.Fatal("unexpected") + } if err != nil { t.Fatal(err) } @@ -28,40 +34,40 @@ func TestMarshalUnmarshal(t *testing.T) { if err != nil { t.Fatal(err) } - p := NewPacket(proto, srcPort, dst, dstPort, data) - p.FillHash() - d, cl := p.Marshal(src, teatype, uint16(i), uint32(len(data)), 0, true, false) - p = SelectPacket() - ok, err := p.Unmarshal(d) - cl() + p := NewPacketPartial(proto, srcPort, dst, dstPort, data.SliceTo(i)) + p.Pointer().FillHash() + d := p.Pointer().MarshalWith(src, teatype, uint16(i), uint32(i), 0, true, false) + t.Log("data:", hex.EncodeToString(d.Bytes())) + p, err := ParsePacketHeader(d.Bytes()) + if err != nil { + t.Fatal("index", i, err) + } + ok := p.Pointer().ParseData(d.Bytes()) if !ok { t.Fatal("index", i) } - if err != nil { - t.Fatal(err) + if !p.Pointer().IsVaildHash() { + t.Fatal("index", i, "expect body", hex.EncodeToString(data.SliceTo(i).Bytes()), "got", hex.EncodeToString(p.Pointer().Body())) } - if !p.IsVaildHash() { + if p.Pointer().Proto != proto { t.Fatal("index", i) } - if p.Proto != proto { + if p.Pointer().CipherIndex() != teatype { + t.Fatal("index", i, "expect", teatype, "got", p.Pointer().CipherIndex()) + } + if p.Pointer().SrcPort != srcPort { t.Fatal("index", i) } - if p.CipherIndex() != teatype { - t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex()) - } - if p.SrcPort != srcPort { + if p.Pointer().DstPort != dstPort { t.Fatal("index", i) } - if p.DstPort != dstPort { + if !p.Pointer().Src.Equal(src) { t.Fatal("index", i) } - if !p.Src.Equal(src) { + if !p.Pointer().Dst.Equal(dst) { t.Fatal("index", i) } - if !p.Dst.Equal(dst) { - t.Fatal("index", i) - } - if p.AdditionalData() != uint16(i) { + if p.Pointer().AdditionalData() != uint16(i) { t.Fatal("index", i) } } diff --git a/gold/head/pool.go b/gold/head/pool.go index 4635136..60233bd 100644 --- a/gold/head/pool.go +++ b/gold/head/pool.go @@ -1,31 +1,37 @@ package head import ( - "sync" - - "github.com/fumiama/WireGold/helper" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" ) -var packetPool = sync.Pool{ - New: func() interface{} { - return new(Packet) - }, +type packetPooler struct { + orbyte.Pooler[Packet] } -// SelectPacket 从池中取出一个 Packet -func SelectPacket() *Packet { - return packetPool.Get().(*Packet) +func (packetPooler) New(_ any, pooled Packet) Packet { + return pooled } -// PutPacket 将 Packet 放回池中 -func PutPacket(p *Packet) { +func (packetPooler) Parse(obj any, _ Packet) Packet { + return obj.(Packet) +} + +func (packetPooler) Reset(p *Packet) { p.idxdatsz = 0 - if p.buffered { - helper.PutBytes(p.data) - p.buffered = false - } + p.data = pbuf.Bytes{} p.a, p.b = 0, 0 - p.data = nil p.rembytes = 0 - packetPool.Put(p) +} + +func (packetPooler) Copy(dst, src *Packet) { + *dst = *src + dst.data = src.data.Copy() +} + +var packetPool = orbyte.NewPool[Packet](packetPooler{}) + +// selectPacket 从池中取出一个 Packet +func selectPacket() *orbyte.Item[Packet] { + return packetPool.New(nil) } diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 6f677da..6f92472 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -7,8 +7,9 @@ import ( "errors" "math/bits" mrand "math/rand" + "runtime" - "github.com/fumiama/WireGold/helper" + "github.com/fumiama/orbyte/pbuf" "github.com/sirupsen/logrus" ) @@ -53,14 +54,12 @@ func expandkeyunit(v1, v2 byte) (v uint16) { } // Encode by aead and put b into pool -func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) { +func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb pbuf.Bytes) { if len(b) == 0 || teatype >= 32 { return } if l.keys[0] == nil { - eb = helper.MakeBytes(len(b)) - copy(eb, b) - return + return pbuf.ParseBytes(b...) } aead := l.keys[teatype] if aead == nil { @@ -72,14 +71,12 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) { } // Decode by aead and put b into pool -func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) { +func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db pbuf.Bytes, err error) { if len(b) == 0 || teatype >= 32 { return } if l.keys[0] == nil { - db = helper.MakeBytes(len(b)) - copy(db, b) - return + return pbuf.ParseBytes(b...), nil } aead := l.keys[teatype] if aead == nil { @@ -88,59 +85,67 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, er return decode(aead, additional, b) } -func encode(aead cipher.AEAD, additional uint16, b []byte) []byte { +func encode(aead cipher.AEAD, additional uint16, b []byte) pbuf.Bytes { nsz := aead.NonceSize() // Accocate capacity for all the stuffs. - buf := helper.MakeBytes(2 + nsz + len(b) + aead.Overhead()) - binary.LittleEndian.PutUint16(buf[:2], additional) - nonce := buf[2 : 2+nsz] + buf := pbuf.NewBytes(2 + nsz + len(b) + aead.Overhead()) + binary.LittleEndian.PutUint16(buf.Bytes()[:2], additional) + nonce := buf.Bytes()[2 : 2+nsz] // Select a random nonce _, err := rand.Read(nonce) if err != nil { panic(err) } // Encrypt the message and append the ciphertext to the nonce. - eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf[:2]) - return nonce[:nsz+len(eb)] + eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf.Bytes()[:2]) + return buf.Trans().Slice(2, 2+nsz+len(eb)) } -func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { +func decode(aead cipher.AEAD, additional uint16, b []byte) (pbuf.Bytes, error) { nsz := aead.NonceSize() if len(b) < nsz { - return nil, ErrCipherTextTooShort + return pbuf.Bytes{}, ErrCipherTextTooShort } // Split nonce and ciphertext. nonce, ciphertext := b[:nsz], b[nsz:] if len(ciphertext) == 0 { - return nil, nil + return pbuf.Bytes{}, nil } // Decrypt the message and check it wasn't tampered with. var buf [2]byte binary.LittleEndian.PutUint16(buf[:], additional) - return aead.Open(helper.SelectWriter().Bytes(), nonce, ciphertext, buf[:]) + data, err := aead.Open( + pbuf.NewBytes(4096).Trans().Bytes()[:0], + nonce, ciphertext, buf[:], + ) + if err != nil { + return pbuf.Bytes{}, nil + } + return pbuf.ParseBytes(data...), nil } // xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data -func (m *Me) xorenc(data []byte, seq uint32) []byte { +func (m *Me) xorenc(data []byte, seq uint32) pbuf.Bytes { batchsz := len(data) / 8 remain := len(data) % 8 sum := m.mask - newdat := helper.MakeBytes(8 + batchsz*8 + 8) // seqrand dat tail - binary.LittleEndian.PutUint32(newdat[:4], seq) - _, _ = rand.Read(newdat[4:8]) // seqrand - sum ^= binary.LittleEndian.Uint64(newdat[:8]) // init from seqrand - binary.LittleEndian.PutUint64(newdat[:8], sum) + newdat := pbuf.NewBytes(8 + batchsz*8 + 8) // seqrand dat tail + binary.LittleEndian.PutUint32(newdat.Bytes()[:4], seq) + _, _ = rand.Read(newdat.Bytes()[4:8]) // seqrand + sum ^= binary.LittleEndian.Uint64(newdat.Bytes()[:8]) // init from seqrand + binary.LittleEndian.PutUint64(newdat.Bytes()[:8], sum) for i := 0; i < batchsz; i++ { // range on batch data a := i * 8 b := (i + 1) * 8 sum ^= binary.LittleEndian.Uint64(data[a:b]) - binary.LittleEndian.PutUint64(newdat[a+8:b+8], sum) + binary.LittleEndian.PutUint64(newdat.Bytes()[a+8:b+8], sum) } p := batchsz * 8 - copy(newdat[8+p:], data[p:]) - newdat[len(newdat)-1] = byte(remain) - sum ^= binary.LittleEndian.Uint64(newdat[8+p:]) - binary.LittleEndian.PutUint64(newdat[8+p:], sum) + copy(newdat.Bytes()[8+p:], data[p:]) + runtime.KeepAlive(data) + newdat.Bytes()[newdat.Len()-1] = byte(remain) + sum ^= binary.LittleEndian.Uint64(newdat.Bytes()[8+p:]) + binary.LittleEndian.PutUint64(newdat.Bytes()[8+p:], sum) return newdat } @@ -163,5 +168,6 @@ func (m *Me) xordec(data []byte) (uint32, []byte) { if remain >= 8 { return 0, nil } - return binary.LittleEndian.Uint32(data[:4]), data[8 : len(data)-8+int(remain)] + return binary.LittleEndian.Uint32(data[:4]), + data[8 : len(data)-8+int(remain)] } diff --git a/gold/link/crypto_test.go b/gold/link/crypto_test.go index cc85732..4285e11 100644 --- a/gold/link/crypto_test.go +++ b/gold/link/crypto_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/hex" "io" + "runtime" "testing" "golang.org/x/crypto/chacha20poly1305" @@ -27,10 +28,11 @@ func TestXOR(t *testing.T) { if err != nil { t.Fatal(err) } - seq, dec := m.xordec(m.xorenc(r1.Bytes(), uint32(i))) + seq, dec := m.xordec(m.xorenc(r1.Bytes(), uint32(i)).Trans().Bytes()) if !bytes.Equal(dec, r2.Bytes()) { t.Fatal("unexpected xor at", i, "except", hex.EncodeToString(r2.Bytes()), "got", hex.EncodeToString(dec)) } + runtime.KeepAlive(dec) if seq != uint32(i) { t.Fatal("unexpected xor at", i, "seq", seq) } @@ -53,11 +55,11 @@ func TestXChacha20(t *testing.T) { t.Fatal(err) } for i := 0; i < 4096; i++ { - db, err := decode(aead, uint16(i), encode(aead, uint16(i), data[:i])) + db, err := decode(aead, uint16(i), encode(aead, uint16(i), data[:i]).Trans().Bytes()) if err != nil { t.Fatal(err) } - if !bytes.Equal(db, data[:i]) { + if !bytes.Equal(db.Bytes(), data[:i]) { t.Fatal("unexpected preshared at idx(len)", i, "addt", uint16(i)) } } diff --git a/gold/link/link.go b/gold/link/link.go index 8e76748..8e8a7d0 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -11,6 +11,7 @@ import ( "github.com/fumiama/WireGold/gold/p2p" "github.com/fumiama/WireGold/helper" base14 "github.com/fumiama/go-base16384" + "github.com/fumiama/orbyte" ) var ( @@ -26,7 +27,7 @@ type Link struct { // 收到的包的队列 // 没有下层 nic 时 // 包会分发到此 - pipe chan *head.Packet + pipe chan *orbyte.Item[head.Packet] // peer 的虚拟 ip peerip net.IP // peer 的公网 endpoint diff --git a/gold/link/listen.go b/gold/link/listen.go index ea40365..bd1839d 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -19,56 +19,12 @@ import ( "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/p2p" "github.com/fumiama/WireGold/helper" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" ) const lstnbufgragsz = 65536 -type lstnq struct { - index int - addr p2p.EndPoint - buf []byte -} - -type listenqueue chan lstnq - -func (q listenqueue) listen(m *Me, hasntfinished []sync.Mutex) { - recvtotlcnt := uint64(0) - recvloopcnt := uint16(0) - recvlooptime := time.Now().UnixMilli() - for lstn := range q { - recvtotlcnt += uint64(len(lstn.buf)) - recvloopcnt++ - if recvloopcnt%m.speedloop == 0 { - now := time.Now().UnixMilli() - logrus.Infof("[listen] queue recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime)) - recvtotlcnt = 0 - recvlooptime = now - } - packet := m.wait(lstn.buf[:len(lstn.buf):lstnbufgragsz]) - if packet == nil { - if lstn.index < 0 { - if config.ShowDebugLog { - logrus.Debugln("[listen] queue waiting") - } - helper.PutBytes(lstn.buf) - continue - } - if config.ShowDebugLog { - logrus.Debugln("[listen] queue waiting, unlock index", lstn.index) - } - hasntfinished[lstn.index].Unlock() - continue - } - if lstn.index >= 0 { - go m.dispatch(packet, lstn.addr, lstn.index, hasntfinished[lstn.index].Unlock) - } else { - go m.dispatch(packet, lstn.addr, lstn.index, func() { - helper.PutBytes(lstn.buf) - }) - } - } -} - // 监听本机 endpoint func (m *Me) listen() (conn p2p.Conn, err error) { conn, err = m.ep.Listen() @@ -85,9 +41,6 @@ func (m *Me) listen() (conn p2p.Conn, err error) { logrus.Infoln("[listen] use cpu num:", n) listenbuf := make([]byte, lstnbufgragsz*n) hasntfinished := make([]sync.Mutex, n) - q := make(listenqueue, n) - defer close(q) - go q.listen(m, hasntfinished) for { usenewbuf := false i := uint(0) @@ -105,13 +58,16 @@ func (m *Me) listen() (conn p2p.Conn, err error) { if config.ShowDebugLog && !usenewbuf { logrus.Debugln("[listen] lock index", i) } - var lbf []byte + var lbf pbuf.Bytes if usenewbuf { - lbf = helper.MakeBytes(lstnbufgragsz) + lbf = pbuf.NewBytes(lstnbufgragsz) } else { - lbf = listenbuf[i*lstnbufgragsz : (i+1)*lstnbufgragsz] + if config.ShowDebugLog { + logrus.Debugln("[listen] take index", i, "slice", i*lstnbufgragsz, (i+1)*lstnbufgragsz, "cap", lstnbufgragsz) + } + lbf = pbuf.ParseBytes(listenbuf[i*lstnbufgragsz : (i+1)*lstnbufgragsz : (i+1)*lstnbufgragsz]...) } - n, addr, err := conn.ReadFromPeer(lbf) + n, addr, err := conn.ReadFromPeer(lbf.Bytes()) if m.connections == nil || errors.Is(err, net.ErrClosed) { logrus.Warnln("[listen] quit listening") return @@ -138,39 +94,70 @@ func (m *Me) listen() (conn p2p.Conn, err error) { } continue } - lq := lstnq{ - index: -1, - addr: addr, - buf: lbf[:n], - } + index := -1 if !usenewbuf { - lq.index = int(i) + index = int(i) } - q <- lq + go m.waitordispatch(index, addr, lbf.Trans().SliceTo(n), hasntfinished) } }() return } -func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish func()) { - defer finish() +func (m *Me) waitordispatch(index int, addr p2p.EndPoint, buf pbuf.Bytes, hasntfinished []sync.Mutex) { + recvtotlcnt := atomic.AddUint64(&m.recvtotlcnt, uint64(buf.Len())) + recvloopcnt := atomic.AddUintptr(&m.recvloopcnt, 1) + recvlooptime := atomic.LoadInt64(&m.recvlooptime) + if recvloopcnt%uintptr(m.speedloop) == 0 { + now := time.Now().UnixMilli() + logrus.Infof("[listen] queue recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime)) + atomic.StoreUint64(&m.recvtotlcnt, 0) + atomic.StoreInt64(&m.recvlooptime, now) + } + packet := m.wait(buf.Trans()) + if packet == nil { + if index < 0 { + if config.ShowDebugLog { + logrus.Debugln("[listen] queue waiting") + } + return + } + if config.ShowDebugLog { + logrus.Debugln("[listen] queue waiting, unlock index", index) + } + hasntfinished[index].Unlock() + return + } + if config.ShowDebugLog { + logrus.Debugln("[listen] index", index, "dispatch", len(packet.Pointer().Body()), "bytes packet") + } + if index >= 0 { + defer hasntfinished[index].Unlock() + m.dispatch(packet, addr, index) + return + } + m.dispatch(packet, addr, index) +} + +func (m *Me) dispatch(packet *orbyte.Item[head.Packet], addr p2p.EndPoint, index int) { + defer runtime.KeepAlive(packet) + if config.ShowDebugLog { defer logrus.Debugln("[listen] dispatched, unlock index", index) logrus.Debugln("[listen] start dispatching index", index) } - r := packet.Len() - packet.BodyLen() + pp := packet.Pointer() + r := pp.Len() - pp.BodyLen() if r > 0 { - logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "len", packet.BodyLen(), "is smaller than it declared len", packet.Len(), ", drop it") - packet.Put() + logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "len", pp.BodyLen(), "is smaller than it declared len", pp.Len(), ", drop it") return } - p, ok := m.IsInPeer(packet.Src.String()) + p, ok := m.IsInPeer(pp.Src.String()) if config.ShowDebugLog { - logrus.Debugln("[listen] @", index, "recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) + logrus.Debugln("[listen] @", index, "recv from endpoint", addr, "src", pp.Src, "dst", pp.Dst) } if !ok { - logrus.Warnln("[listen] @", index, "packet from", packet.Src, "to", packet.Dst, "is refused") - packet.Put() + logrus.Warnln("[listen] @", index, "packet from", pp.Src, "to", pp.Dst, "is refused") return } if helper.IsNilInterface(p.endpoint) || !p.endpoint.Euqal(addr) { @@ -185,25 +172,23 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish now := time.Now() atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.lastalive)), unsafe.Pointer(&now)) switch { - case p.IsToMe(packet.Dst): - if !p.Accept(packet.Src) { - logrus.Warnln("[listen] @", index, "refused packet from", packet.Src.String()+":"+strconv.Itoa(int(packet.SrcPort))) - packet.Put() + case p.IsToMe(pp.Dst): + if !p.Accept(pp.Src) { + logrus.Warnln("[listen] @", index, "refused packet from", pp.Src.String()+":"+strconv.Itoa(int(pp.SrcPort))) return } - addt := packet.AdditionalData() + addt := pp.AdditionalData() var err error - data, err := p.Decode(packet.CipherIndex(), addt, packet.Body()) + data, err := p.Decode(pp.CipherIndex(), addt, pp.Body()) if err != nil { if config.ShowDebugLog { - logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", packet.CipherIndex(), "addt:", addt, "err:", err) + logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", pp.CipherIndex(), "addt:", addt, "err:", err) } - packet.Put() return } - packet.SetBody(data, true) + pp.SetBody(data.Trans().Bytes()) if p.usezstd { - dec, _ := zstd.NewReader(bytes.NewReader(packet.Body())) + dec, _ := zstd.NewReader(bytes.NewReader(pp.Body())) var err error w := helper.SelectWriter() _, err = io.Copy(w, dec) @@ -212,25 +197,27 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish if config.ShowDebugLog { logrus.Debugln("[listen] @", index, "drop invalid zstd packet:", err) } - packet.Put() return } - packet.SetBody(w.Bytes(), true) + if config.ShowDebugLog { + logrus.Debugln("[listen] @", index, "zstd decoded len:", w.Len()) + } + pp.SetBody(w.TransBytes()) } - if !packet.IsVaildHash() { + if !pp.IsVaildHash() { if config.ShowDebugLog { logrus.Debugln("[listen] @", index, "drop invalid hash packet") } - packet.Put() return } - switch packet.Proto { + switch pp.Proto { case head.ProtoHello: switch { - case len(packet.Body()) == 0: + case len(pp.Body()) == 0: logrus.Warnln("[listen] @", index, "recv old hello packet, do nothing") - case packet.Body()[0] == byte(head.HelloPing): - n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), []byte{byte(head.HelloPong)}), false) + case pp.Body()[0] == byte(head.HelloPing): + n, err := p.WritePacket(head.NewPacketPartial( + head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), pbuf.ParseBytes(byte(head.HelloPong))), false) if err == nil { logrus.Infoln("[listen] @", index, "recv hello, send", n, "bytes hello ack packet") } else { @@ -239,57 +226,49 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish default: logrus.Infoln("[listen] @", index, "recv hello ack packet, do nothing") } - packet.Put() case head.ProtoNotify: - logrus.Infoln("[listen] @", index, "recv notify from", packet.Src) - go p.onNotify(packet.Body()) - packet.Put() + logrus.Infoln("[listen] @", index, "recv notify from", pp.Src) + p.onNotify(pp.Body()) case head.ProtoQuery: - logrus.Infoln("[listen] @", index, "recv query from", packet.Src) - go p.onQuery(packet.Body()) - packet.Put() + logrus.Infoln("[listen] @", index, "recv query from", pp.Src) + p.onQuery(pp.Body()) case head.ProtoData: if p.pipe != nil { - p.pipe <- packet.CopyWithBody() + p.pipe <- packet.Copy() if config.ShowDebugLog { logrus.Debugln("[listen] @", index, "deliver to pipe of", p.peerip) } } else { - _, err := m.nic.Write(packet.Body()) + _, err := m.nic.Write(pp.Body()) if err != nil { - logrus.Errorln("[listen] @", index, "deliver", packet.BodyLen(), "bytes data to nic err:", err) + logrus.Errorln("[listen] @", index, "deliver", pp.BodyLen(), "bytes data to nic err:", err) } else if config.ShowDebugLog { - logrus.Debugln("[listen] @", index, "deliver", packet.BodyLen(), "bytes data to nic") + logrus.Debugln("[listen] @", index, "deliver", pp.BodyLen(), "bytes data to nic") } } - packet.Put() default: - logrus.Warnln("[listen] @", index, "recv unknown proto:", packet.Proto) - packet.Put() + logrus.Warnln("[listen] @", index, "recv unknown proto:", pp.Proto) } - case p.Accept(packet.Dst): + case p.Accept(pp.Dst): if !p.allowtrans { - logrus.Warnln("[listen] @", index, "refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) - packet.Put() + logrus.Warnln("[listen] @", index, "refused to trans packet to", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort))) return } // 转发 - lnk := m.router.NextHop(packet.Dst.String()) + lnk := m.router.NextHop(pp.Dst.String()) if lnk == nil { logrus.Warnln("[listen] @", index, "transfer drop packet: nil nexthop") - packet.Put() return } - n, err := lnk.WriteAndPut(packet, true) + n, err := lnk.WritePacket(packet, true) if err == nil { if config.ShowDebugLog { - logrus.Debugln("[listen] @", index, "trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + logrus.Debugln("[listen] @", index, "trans", n, "bytes packet to", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort))) } } else { - logrus.Errorln("[listen] @", index, "trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) + logrus.Errorln("[listen] @", index, "trans packet to", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort)), "err:", err) } default: - logrus.Warnln("[listen] @", index, "packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") - packet.Put() + logrus.Warnln("[listen] @", index, "packet dst", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort)), "is not in peers") } } diff --git a/gold/link/me.go b/gold/link/me.go index 1bf0c82..a89664f 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -10,6 +10,8 @@ import ( "time" "github.com/FloatTech/ttl" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" "github.com/fumiama/water/waterutil" "github.com/sirupsen/logrus" @@ -46,13 +48,19 @@ type Me struct { // 本机路由表 router *Router // 本机未接收完全分片池 - recving *ttl.Cache[uint64, *head.Packet] + recving *ttl.Cache[uint64, *orbyte.Item[head.Packet]] // 抗重放攻击记录池 - recved *ttl.Cache[uint64, bool] + recved *ttl.Cache[uint64, struct{}] // 本机上层配置 srcport, dstport, mtu, speedloop uint16 // 报头掩码 mask uint64 + // 本机总接收字节数 + recvtotlcnt uint64 + // 上一次触发循环计数时间 + recvlooptime int64 + // 本机总接收数据包计数 + recvloopcnt uintptr // 是否进行 base16384 编码 base14 bool // 本机网络端点初始化配置 @@ -122,12 +130,13 @@ func NewMe(cfg *MyConfig) (m Me) { ) } m.mask = cfg.Mask + m.recvlooptime = time.Now().UnixMilli() m.base14 = cfg.Base14 var buf [8]byte binary.BigEndian.PutUint64(buf[:], m.mask) logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:])) - m.recving = ttl.NewCache[uint64, *head.Packet](time.Second * 30) - m.recved = ttl.NewCache[uint64, bool](time.Second * 30) + m.recving = ttl.NewCache[uint64, *orbyte.Item[head.Packet]](time.Second * 10) + m.recved = ttl.NewCache[uint64, struct{}](time.Minute) return } @@ -154,6 +163,7 @@ func (m *Me) Restart() error { } m.me = ip m.subnet = *cidr + m.recvlooptime = time.Now().UnixMilli() m.conn, err = m.listen() return err } @@ -280,11 +290,10 @@ func (m *Me) sendAllSameDst(packet []byte) (n int) { logrus.Warnln("[me] drop packet to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), ": nil nexthop") return } - pcp := helper.MakeBytes(len(packet)) - copy(pcp, packet) - go func(packet []byte) { - defer helper.PutBytes(packet) - _, err := lnk.WriteAndPut(head.NewPacket(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false) + pcp := pbuf.NewBytes(len(packet)) + copy(pcp.Bytes(), packet) + go func(packet pbuf.Bytes) { + _, err := lnk.WritePacket(head.NewPacketPartial(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false) if err != nil { logrus.Warnln("[me] write to peer", lnk.peerip, "err:", err) } diff --git a/gold/link/nat.go b/gold/link/nat.go index 85901b0..1178adf 100644 --- a/gold/link/nat.go +++ b/gold/link/nat.go @@ -1,6 +1,7 @@ package link import ( + "bytes" "encoding/json" "sync/atomic" "time" @@ -12,6 +13,8 @@ import ( "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/p2p" "github.com/fumiama/WireGold/helper" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" ) // 保持 NAT @@ -35,7 +38,7 @@ func (l *Link) keepAlive(dur int64) { logrus.Infoln("[nat] re-connect me succeeded") } } - n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, []byte{byte(head.HelloPing)}), false) + n, err := l.WritePacket(head.NewPacketPartial(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, pbuf.ParseBytes(byte(head.HelloPing))), false) if err == nil { logrus.Infoln("[nat] send", n, "bytes keep alive packet") } else { @@ -131,12 +134,14 @@ func (l *Link) onQuery(packet []byte) { logrus.Infoln("[nat] query wrap", len(notify), "notify") w := helper.SelectWriter() _ = json.NewEncoder(w).Encode(¬ify) - _, err = l.WriteAndPut(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false) + _, err = l.WritePacket(head.NewPacketPartial( + head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, + pbuf.BufferItemToBytes((*orbyte.Item[bytes.Buffer])(w).Trans()), + ), false) if err != nil { logrus.Errorln("[nat] notify peer", l, "err:", err) return } - helper.PutWriter(w) } } @@ -152,7 +157,10 @@ func (l *Link) sendquery(tick time.Duration, peers ...string) { t := time.NewTicker(tick) for range t.C { logrus.Infoln("[nat] query send query to", l.peerip) - _, err = l.WriteAndPut(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false) + _, err = l.WritePacket(head.NewPacketPartial( + head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, + pbuf.ParseBytes(data...), + ), false) if err != nil { logrus.Errorln("[nat] query write err:", err) } diff --git a/gold/link/peer.go b/gold/link/peer.go index 7dbb2d1..864adaf 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -7,6 +7,7 @@ import ( "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/p2p" curve "github.com/fumiama/go-x25519" + "github.com/fumiama/orbyte" "github.com/sirupsen/logrus" "golang.org/x/crypto/chacha20poly1305" ) @@ -49,7 +50,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) { } if !cfg.NoPipe { - l.pipe = make(chan *head.Packet, 32) + l.pipe = make(chan *orbyte.Item[head.Packet], 65536) } var k, p []byte if cfg.PubicKey != nil { diff --git a/gold/link/recv.go b/gold/link/recv.go index 77c495c..2f6be93 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -9,129 +9,123 @@ import ( "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/gold/head" base14 "github.com/fumiama/go-base16384" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" "github.com/sirupsen/logrus" ) // Read 从 peer 收包 -func (l *Link) Read() *head.Packet { +func (l *Link) Read() *orbyte.Item[head.Packet] { return <-l.pipe } -func (m *Me) wait(data []byte) *head.Packet { - if len(data) < head.PacketHeadLen { // not a valid packet +func (m *Me) wait(data pbuf.Bytes) *orbyte.Item[head.Packet] { + if data.Len() < head.PacketHeadLen { // not a valid packet if config.ShowDebugLog { - logrus.Debugln("[recv] invalid data len", len(data)) + logrus.Debugln("[recv] invalid data len", data.Len()) } return nil } bound := 64 endl := "..." - if len(data) < bound { - bound = len(data) + if data.Len() < bound { + bound = data.Len() endl = "." } if config.ShowDebugLog { - logrus.Debugln("[recv] data bytes, len", len(data), "val", hex.EncodeToString(data[:bound]), endl) + logrus.Debugln("[recv] data bytes, len", data.Len(), "val", hex.EncodeToString(data.Bytes()[:bound]), endl) } if m.base14 { - data = base14.Decode(data) - if len(data) < bound { - bound = len(data) + data = pbuf.ParseBytes(base14.Decode(data.Bytes())...) + if data.Len() < bound { + bound = data.Len() endl = "." } if config.ShowDebugLog { - logrus.Debugln("[recv] data b14ed, len", len(data), "val", hex.EncodeToString(data[:bound]), endl) + logrus.Debugln("[recv] data b14ed, len", data.Len(), "val", hex.EncodeToString(data.Bytes()[:bound]), endl) } - if len(data) < head.PacketHeadLen { // not a valid packet + if data.Len() < head.PacketHeadLen { // not a valid packet if config.ShowDebugLog { - logrus.Debugln("[recv] invalid data len", len(data)) + logrus.Debugln("[recv] invalid data len", data.Len()) } return nil } } - seq, data := m.xordec(data) - if len(data) < bound { - bound = len(data) + seq, dat := m.xordec(data.Trans().Bytes()) + if len(dat) < bound { + bound = len(dat) endl = "." } if config.ShowDebugLog { - logrus.Debugln("[recv] data xored, len", len(data), "val", hex.EncodeToString(data[:bound]), endl) + logrus.Debugln("[recv] data xored, len", len(dat), "val", hex.EncodeToString(dat[:bound]), endl) } - if len(data) < head.PacketHeadLen { // not a valid packet + header, err := head.ParsePacketHeader(dat) + if err != nil { // not a valid packet if config.ShowDebugLog { - logrus.Debugln("[recv] invalid data len", len(data)) + logrus.Debugln("[recv] invalid packet header:", err) } return nil } - flags := head.Flags(data) - if !flags.IsValid() { + if !header.Pointer().Flags.IsValid() { if config.ShowDebugLog { - logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) + logrus.Debugln("[recv] drop invalid flags packet:", header.Pointer().Flags) } return nil } - crc := head.CRC64(data) + crc := header.Pointer().CRC64() crclog := crc crc ^= (uint64(seq) << 16) if config.ShowDebugLog { logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc) } - if m.recved.Get(crc) { + if _, got := m.recved.GetOrSet(crc, struct{}{}); got { if config.ShowDebugLog { logrus.Debugln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) } return nil } - m.recved.Set(crc, true) if config.ShowDebugLog { - logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) + logrus.Debugln( + "[recv]", strconv.FormatUint(crc, 16), + len(dat), "bytes data with flag", header.Pointer().Flags, + "offset", header.Pointer().Flags.Offset(), + ) } - if flags.IsSingle() || flags.NoFrag() { - h := head.SelectPacket() - _, err := h.Unmarshal(data) - if err != nil { - logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unmarshal err:", err) + if header.Pointer().Flags.IsSingle() || header.Pointer().Flags.NoFrag() { + ok := header.Pointer().ParseData(dat) + if !ok { + logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unexpected !ok") return nil } - return h + if config.ShowDebugLog { + logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), len(dat), "bytes full data waited") + } + return header } crchash := crc64.New(crc64.MakeTable(crc64.ISO)) - _, _ = crchash.Write(head.Hash(data)) + _, _ = crchash.Write(head.Hash(data.Bytes())) var buf [4]byte binary.LittleEndian.PutUint32(buf[:], seq) _, _ = crchash.Write(buf[:]) hsh := crchash.Sum64() - h := m.recving.Get(hsh) - if h != nil { - if config.ShowDebugLog { - logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "get another frag part of", strconv.FormatUint(hsh, 16)) - } - ok, err := h.Unmarshal(data) - if err == nil { - if ok { - m.recving.Delete(hsh) - if config.ShowDebugLog { - logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "all parts of", strconv.FormatUint(hsh, 16), "has reached") - } - return h - } - } else { - h.Put() - logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unmarshal err:", err) - } - return nil + h, got := m.recving.GetOrSet(hsh, header) + if got && h == header { + panic("unexpected multi-put found") } if config.ShowDebugLog { - logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "get new frag part of", strconv.FormatUint(hsh, 16)) + logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "get frag part of", strconv.FormatUint(hsh, 16), "isnew:", !got) } - h = head.SelectPacket() - _, err := h.Unmarshal(data) - if err != nil { - h.Put() - logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unmarshal err:", err) + ok := h.Pointer().ParseData(dat) + if !ok { + if config.ShowDebugLog { + logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "wait other frag parts of", strconv.FormatUint(hsh, 16), "isnew:", !got) + } return nil } - m.recving.Set(hsh, h) - return nil + m.recving.Delete(hsh) + if config.ShowDebugLog { + logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "all parts of", strconv.FormatUint(hsh, 16), "has reached") + } + return h } diff --git a/gold/link/send.go b/gold/link/send.go index 0cf183f..986b445 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "math/rand" + "runtime" "github.com/klauspost/compress/zstd" "github.com/sirupsen/logrus" @@ -17,6 +18,8 @@ import ( "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/helper" base14 "github.com/fumiama/go-base16384" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" ) var ( @@ -24,15 +27,19 @@ var ( ErrTTL = errors.New("ttl exceeded") ) -// WriteAndPut 向 peer 发包并将包放回缓存池 -func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { - defer p.Put() - teatype := l.randkeyidx() - sndcnt := uint16(l.incgetsndcnt()) +func randseq(i uint16) uint32 { var buf [4]byte _, _ = crand.Read(buf[:2]) - binary.BigEndian.PutUint16(buf[2:4], sndcnt) - seq := binary.BigEndian.Uint32(buf[:]) + binary.BigEndian.PutUint16(buf[2:4], i) + return binary.BigEndian.Uint32(buf[:]) +} + +// WritePacket 向 peer 发包 +func (l *Link) WritePacket(p *orbyte.Item[head.Packet], istransfer bool) (n int, err error) { + pp := p.Pointer() + teatype := l.randkeyidx() + sndcnt := uint16(l.incgetsndcnt()) + seq := randseq(sndcnt) mtu := l.mtu if l.mturandomrange > 0 { mtu -= uint16(rand.Intn(int(l.mturandomrange))) @@ -41,48 +48,48 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { logrus.Debugln("[send] mtu:", mtu, ", addt:", sndcnt&0x07ff, ", key index:", teatype) } if !istransfer { - l.encrypt(p, sndcnt, teatype) + l.encrypt(pp, sndcnt, teatype) } delta := (int(mtu) - head.PacketHeadLen) & 0x0000fff8 if delta <= 0 { logrus.Warnln("[send] reset invalid data frag len", delta, "to 8") delta = 8 } - remlen := p.BodyLen() + remlen := pp.BodyLen() if remlen <= delta { return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false, seq) } - if istransfer && p.Flags.DontFrag() && remlen > delta { + if istransfer && pp.Flags.DontFrag() && remlen > delta { return 0, ErrDropBigDontFragPkt } - ttl := p.TTL + ttl := pp.TTL totl := uint32(remlen) pos := 0 - packet := p.Copy() + packet := head.ParsePacket(pp.ShallowCopy()) for remlen > delta { remlen -= delta if config.ShowDebugLog { logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", remlen) } - packet.CropBody(pos, pos+delta) + packet.Pointer().CropBody(pos, pos+delta) cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true, seq) n += cnt if err != nil { return n, err } - packet.TTL = ttl + packet.Pointer().TTL = ttl pos += delta } - packet.Put() if remlen > 0 { if config.ShowDebugLog { logrus.Debugln("[send] last frag [", pos, "~", pos+remlen, "]") } - p.CropBody(pos, pos+remlen) + pp.CropBody(pos, pos+remlen) cnt := 0 cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false, seq) n += cnt } + runtime.KeepAlive(p) return n, err } @@ -94,24 +101,27 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) { data := p.Body() if l.usezstd { w := helper.SelectWriter() - defer helper.PutWriter(w) enc, _ := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedFastest)) _, _ = io.Copy(enc, bytes.NewReader(data)) enc.Close() - data = w.Bytes() + data = w.TransBytes() if config.ShowDebugLog { logrus.Debugln("[send] data len after zstd:", len(data)) } } - p.SetBody(l.Encode(teatype, sndcnt&0x07ff, data), true) + p.SetBody(l.Encode(teatype, sndcnt&0x07ff, data).Trans().Bytes()) if config.ShowDebugLog { logrus.Debugln("[send] data len after xchacha20:", p.BodyLen(), "addt:", sndcnt) } } // write 向 peer 发包 -func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) { - if p.DecreaseAndGetTTL() <= 0 { +func (l *Link) write( + p *orbyte.Item[head.Packet], teatype uint8, additional uint16, + datasz uint32, offset uint16, istransfer, + hasmore bool, seq uint32, +) (int, error) { + if p.Pointer().DecreaseAndGetTTL() <= 0 { return 0, ErrTTL } if l.doublepacket { @@ -121,26 +131,28 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui } // write 向 peer 发一个包 -func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) { +func (l *Link) writeonce( + p *orbyte.Item[head.Packet], teatype uint8, additional uint16, + datasz uint32, offset uint16, + istransfer, hasmore bool, seq uint32, +) (int, error) { peerep := l.endpoint if helper.IsNilInterface(peerep) { - return 0, errors.New("nil endpoint of " + p.Dst.String()) + return 0, errors.New("nil endpoint of " + p.Pointer().Dst.String()) } - var d []byte - var cl func() + var d pbuf.Bytes // TODO: now all packet allow frag, adapt to DF if istransfer { - d, cl = p.Marshal(nil, 0, 0, 0, offset, false, hasmore) + d = p.Pointer().MarshalWith(nil, 0, 0, 0, offset, false, hasmore) } else { - d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore) + d = p.Pointer().MarshalWith(l.me.me, teatype, additional, datasz, offset, false, hasmore) } - defer cl() bound := 64 endl := "..." - if len(d) < bound { - bound = len(d) + if d.Len() < bound { + bound = d.Len() endl = "." } conn := l.me.conn @@ -148,17 +160,15 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas return 0, io.ErrClosedPipe } if config.ShowDebugLog { - - logrus.Debugln("[send] write", len(d), "bytes data from ep", conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.CRC64())) - logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) + logrus.Debugln("[send] write", d.Len(), "bytes data from ep", conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.Pointer().CRC64())) + logrus.Debugln("[send] data bytes", hex.EncodeToString(d.Bytes()[:bound]), endl) } - d = l.me.xorenc(d, seq) + d = l.me.xorenc(d.Bytes(), seq) if l.me.base14 { - d = base14.Encode(d) + d = pbuf.ParseBytes(base14.Encode(d.Bytes())...) } if config.ShowDebugLog { - logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) + logrus.Debugln("[send] data xored", hex.EncodeToString(d.Bytes()[:bound]), endl) } - defer helper.PutBytes(d) - return conn.WriteToPeer(d, peerep) + return conn.WriteToPeer(d.Trans().Bytes(), peerep) } diff --git a/gold/p2p/tcp/pdu.go b/gold/p2p/tcp/pdu.go index 61b0450..d827d5f 100644 --- a/gold/p2p/tcp/pdu.go +++ b/gold/p2p/tcp/pdu.go @@ -1,14 +1,18 @@ package tcp import ( + "bytes" "encoding/binary" "errors" "io" "net" + "runtime" "time" "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/helper" + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" "github.com/sirupsen/logrus" ) @@ -33,17 +37,19 @@ var ( type packet struct { typ packetType len uint16 - dat []byte + dat pbuf.Bytes io.ReaderFrom io.WriterTo } func (p *packet) pack() (net.Buffers, func()) { - d, cl := helper.OpenWriterF(func(w *helper.Writer) { + d := helper.NewWriterF(func(w *helper.Writer) { w.WriteByte(byte(p.typ)) w.WriteUInt16(p.len) }) - return net.Buffers{magicbuf, d, p.dat}, cl + return net.Buffers{magicbuf, d.Bytes(), p.dat.Bytes()}, func() { + runtime.KeepAlive(d) + } } func (p *packet) Read(_ []byte) (int, error) { @@ -81,7 +87,7 @@ func (p *packet) ReadFrom(r io.Reader) (n int64, err error) { if err != nil { return } - p.dat = w.Bytes() + p.dat = pbuf.BufferItemToBytes((*orbyte.Item[bytes.Buffer])(w).Trans()) return } diff --git a/gold/p2p/tcp/tcp.go b/gold/p2p/tcp/tcp.go index 764e7fe..83fb4f9 100644 --- a/gold/p2p/tcp/tcp.go +++ b/gold/p2p/tcp/tcp.go @@ -14,7 +14,7 @@ import ( "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/gold/p2p" - "github.com/fumiama/WireGold/helper" + "github.com/fumiama/orbyte/pbuf" ) type EndPoint struct { @@ -377,9 +377,8 @@ func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) { if p.pckt.typ == packetTypeNormal { break } - defer helper.PutBytes(p.pckt.dat) } - n := copy(b, p.pckt.dat) + n := copy(b, p.pckt.dat.Bytes()) return n, p.addr, nil } @@ -453,7 +452,7 @@ RECONNECT: cnt, err := io.Copy(tcpconn, &packet{ typ: packetTypeNormal, len: uint16(len(b)), - dat: b, + dat: pbuf.ParseBytes(b...), }) if err != nil { if subc == nil { diff --git a/helper/pool.go b/helper/pool.go index 7261bfe..0130511 100644 --- a/helper/pool.go +++ b/helper/pool.go @@ -1,45 +1,10 @@ package helper import ( - "bytes" - "sync" + "github.com/fumiama/orbyte/pbuf" ) -// https://github.com/Mrs4s/MiraiGo/blob/master/binary/pool.go - -var bufferPool = sync.Pool{ - New: func() interface{} { - return new(Writer) - }, -} - -func MakeBytes(sz int) []byte { - w := SelectWriter() - b := w.Bytes() - if cap(b) >= sz { - return b[:sz] - } - return make([]byte, sz) -} - -func PutBytes(b []byte) { - PutWriter((*Writer)(bytes.NewBuffer(b))) -} - // SelectWriter 从池中取出一个 Writer func SelectWriter() *Writer { - // 因为 bufferPool 定义有 New 函数 - // 所以 bufferPool.Get() 永不为 nil - // 不用判空 - return bufferPool.Get().(*Writer) -} - -// PutWriter 将 Writer 放回池中 -func PutWriter(w *Writer) { - // See https://golang.org/issue/23199 - const maxSize = 1 << 16 - if (*bytes.Buffer)(w).Cap() < maxSize { // 对于大Buffer直接丢弃 - w.Reset() - bufferPool.Put(w) - } + return (*Writer)(pbuf.NewBuffer(nil)) } diff --git a/helper/writer.go b/helper/writer.go index 9c6a123..f1607dd 100644 --- a/helper/writer.go +++ b/helper/writer.go @@ -6,52 +6,22 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "io" - "unsafe" + + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" ) // Writer 写入 -type Writer bytes.Buffer +type Writer orbyte.Item[bytes.Buffer] -func NewWriterF(f func(writer *Writer)) []byte { +func NewWriterF(f func(writer *Writer)) pbuf.Bytes { w := SelectWriter() f(w) - b := append([]byte(nil), w.Bytes()...) - w.put() - return b -} - -// OpenWriterF must call func cl to close -func OpenWriterF(f func(*Writer)) (b []byte, cl func()) { - w := SelectWriter() - f(w) - return w.Bytes(), w.put -} - -func (w *Writer) FillUInt16() (pos int) { - pos = w.Len() - (*bytes.Buffer)(w).Write([]byte{0, 0}) - return -} - -func (w *Writer) WriteUInt16At(pos int, v uint16) { - newdata := (*bytes.Buffer)(w).Bytes()[pos:] - binary.LittleEndian.PutUint16(newdata, v) -} - -func (w *Writer) FillUInt32() (pos int) { - pos = w.Len() - (*bytes.Buffer)(w).Write([]byte{0, 0, 0, 0}) - return -} - -func (w *Writer) WriteUInt32At(pos int, v uint32) { - newdata := (*bytes.Buffer)(w).Bytes()[pos:] - binary.LittleEndian.PutUint32(newdata, v) + return pbuf.BufferItemToBytes((*orbyte.Item[bytes.Buffer])(w).Trans()) } func (w *Writer) Write(b []byte) (n int, err error) { - return (*bytes.Buffer)(w).Write(b) + return (*orbyte.Item[bytes.Buffer])(w).Pointer().Write(b) } func (w *Writer) WriteHex(h string) { @@ -60,7 +30,7 @@ func (w *Writer) WriteHex(h string) { } func (w *Writer) WriteByte(b byte) error { - return (*bytes.Buffer)(w).WriteByte(b) + return (*orbyte.Item[bytes.Buffer])(w).Pointer().WriteByte(b) } func (w *Writer) WriteUInt16(v uint16) { @@ -83,95 +53,25 @@ func (w *Writer) WriteUInt64(v uint64) { func (w *Writer) WriteString(v string) { //w.WriteUInt32(uint32(len(v) + 4)) - (*bytes.Buffer)(w).WriteString(v) -} - -func (w *Writer) WriteStringShort(v string) { - w.WriteUInt16(uint16(len(v))) - (*bytes.Buffer)(w).WriteString(v) -} - -func (w *Writer) WriteBool(b bool) { - if b { - w.WriteByte(0x01) - } else { - w.WriteByte(0x00) - } -} - -func (w *Writer) WriteBytesShort(data []byte) { - w.WriteUInt16(uint16(len(data))) - w.Write(data) + (*orbyte.Item[bytes.Buffer])(w).Pointer().WriteString(v) } func (w *Writer) Len() int { - return (*bytes.Buffer)(w).Len() + return (*orbyte.Item[bytes.Buffer])(w).Pointer().Len() } -func (w *Writer) Bytes() []byte { - return (*bytes.Buffer)(w).Bytes() +func (w *Writer) UnsafeBytes() []byte { + return (*orbyte.Item[bytes.Buffer])(w).Pointer().Bytes() +} + +func (w *Writer) TransBytes() []byte { + return (*orbyte.Item[bytes.Buffer])(w).Trans().Pointer().Bytes() } func (w *Writer) Reset() { - (*bytes.Buffer)(w).Reset() + (*orbyte.Item[bytes.Buffer])(w).Pointer().Reset() } func (w *Writer) Grow(n int) { - (*bytes.Buffer)(w).Grow(n) -} - -func (w *Writer) Skip(n int) (int, error) { - b := (*buffer)(unsafe.Pointer(w)) - b.lastRead = opInvalid - if len(b.buf) <= b.off { - // Buffer is empty, reset to recover space. - w.Reset() - if n == 0 { - return 0, nil - } - return 0, io.EOF - } - n = minnum(n, len(b.buf[b.off:])) - b.off += n - if n > 0 { - b.lastRead = opRead - } - return n, nil -} - -func (w *Writer) put() { - PutWriter(w) -} - -// A Buffer is a variable-sized buffer of bytes with Read and Write methods. -// The zero value for Buffer is an empty buffer ready to use. -type buffer struct { - buf []byte // contents are the bytes buf[off : len(buf)] - off int // read at &buf[off], write at &buf[len(buf)] - lastRead readOp // last read operation, so that Unread* can work correctly. -} - -// The readOp constants describe the last action performed on -// the buffer, so that UnreadRune and UnreadByte can check for -// invalid usage. opReadRuneX constants are chosen such that -// converted to int they correspond to the rune size that was read. -type readOp int8 - -// Don't use iota for these, as the values need to correspond with the -// names and comments, which is easier to see when being explicit. -const ( - opRead readOp = -1 // Any other read operation. - opInvalid readOp = 0 // Non-read operation. - opReadRune1 readOp = 1 // Read rune of size 1. - opReadRune2 readOp = 2 // Read rune of size 2. - opReadRune3 readOp = 3 // Read rune of size 3. - opReadRune4 readOp = 4 // Read rune of size 4. -) - -// minnum 返回两数最小值,该函数将被内联 -func minnum[T int | int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64](a, b T) T { - if a > b { - return b - } - return a + (*orbyte.Item[bytes.Buffer])(w).Pointer().Grow(n) } diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 726c2d1..db84d16 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -12,6 +12,8 @@ import ( _ "github.com/fumiama/WireGold/gold/p2p/tcp" // support tcp connection _ "github.com/fumiama/WireGold/gold/p2p/udp" // support udp connection _ "github.com/fumiama/WireGold/gold/p2p/udplite" // support udplite connection + "github.com/fumiama/orbyte" + "github.com/fumiama/orbyte/pbuf" "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/gold/head" @@ -21,7 +23,7 @@ import ( type Tunnel struct { l *link.Link in chan []byte - out chan *head.Packet + out chan *orbyte.Item[head.Packet] outcache []byte peerip net.IP src uint16 @@ -33,7 +35,7 @@ func Create(me *link.Me, peer string) (s Tunnel, err error) { s.l, err = me.Connect(peer) if err == nil { s.in = make(chan []byte, 4) - s.out = make(chan *head.Packet, 4) + s.out = make(chan *orbyte.Item[head.Packet], 4) s.peerip = net.ParseIP(peer) } else { logrus.Errorln("[tunnel] create err:", err) @@ -73,12 +75,11 @@ func (s *Tunnel) Read(p []byte) (int, error) { if pkt == nil { return 0, io.EOF } - defer pkt.Put() - if pkt.BodyLen() < 4 { - logrus.Warnln("[tunnel] unexpected packet data len", pkt.BodyLen(), "content", hex.EncodeToString(pkt.Body())) + if pkt.Pointer().BodyLen() < 4 { + logrus.Warnln("[tunnel] unexpected packet data len", pkt.Pointer().BodyLen(), "content", hex.EncodeToString(pkt.Pointer().Body())) return 0, io.EOF } - d = pkt.Body()[4:] + d = pkt.Pointer().Body()[4:] } if d != nil { if len(p) >= len(d) { @@ -125,8 +126,8 @@ func (s *Tunnel) handleWrite() { binary.LittleEndian.PutUint32(buf[:4], seq) seq++ copy(buf[4:], b[:s.mtu-4]) - _, err := s.l.WriteAndPut( - head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf), false, + _, err := s.l.WritePacket( + head.NewPacketPartial(head.ProtoData, s.src, s.peerip, s.dest, pbuf.ParseBytes(buf...)), false, ) if err != nil { logrus.Errorln("[tunnel] seq", seq-1, "write err:", err) @@ -140,8 +141,8 @@ func (s *Tunnel) handleWrite() { binary.LittleEndian.PutUint32(buf[:4], seq) seq++ copy(buf[4:], b) - _, err := s.l.WriteAndPut( - head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf[:len(b)+4]), false, + _, err := s.l.WritePacket( + head.NewPacketPartial(head.ProtoData, s.src, s.peerip, s.dest, pbuf.ParseBytes(buf[:len(b)+4]...)), false, ) if err != nil { logrus.Errorln("[tunnel] seq", seq-1, "write err:", err) @@ -155,7 +156,7 @@ func (s *Tunnel) handleWrite() { func (s *Tunnel) handleRead() { seq := uint32(0) - seqmap := make(map[uint32]*head.Packet) + seqmap := make(map[uint32]*orbyte.Item[head.Packet]) for { if p, ok := seqmap[seq]; ok { if config.ShowDebugLog { @@ -173,14 +174,15 @@ func (s *Tunnel) handleRead() { } end := 64 endl := "..." - if p.BodyLen() < 64 { - end = p.BodyLen() + pp := p.Pointer() + if pp.BodyLen() < 64 { + end = pp.BodyLen() endl = "." } if config.ShowDebugLog { - logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Body()[:end]), endl) + logrus.Debugln("[tunnel] read recv", hex.EncodeToString(pp.Body()[:end]), endl) } - recvseq := binary.LittleEndian.Uint32(p.Body()[:4]) + recvseq := binary.LittleEndian.Uint32(pp.Body()[:4]) if recvseq == seq { if config.ShowDebugLog { logrus.Debugln("[tunnel] dispatch seq", seq) diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 11ad3aa..e48586d 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -12,6 +12,7 @@ import ( "time" curve "github.com/fumiama/go-x25519" + "github.com/fumiama/orbyte" "github.com/sirupsen/logrus" "github.com/fumiama/WireGold/gold/link" @@ -196,7 +197,9 @@ func testTunnel(t *testing.T, nw string, isplain, isbase14 bool, pshk *[32]byte, } sendb = make([]byte, 4096) - rand.Read(sendb) + for i := 0; i < 4096; i++ { + sendb[i] = byte(i) + } go tunnme.Write(sendb) buf = make([]byte, 4096) _, err = io.ReadFull(&tunnpeer, buf) @@ -213,7 +216,9 @@ func testTunnel(t *testing.T, nw string, isplain, isbase14 bool, pshk *[32]byte, time.Sleep(time.Second) for i := 0; i < 32; i++ { sendb := make([]byte, 65535) - rand.Read(sendb) + for j := 0; j < 65535; j++ { + sendb[j] = byte(i + j) + } n, _ := tunnme.Write(sendb) sendbufs <- sendb logrus.Debugln("loop", i, "write", n, "bytes") @@ -234,7 +239,9 @@ func testTunnel(t *testing.T, nw string, isplain, isbase14 bool, pshk *[32]byte, i++ } - rand.Read(sendb) + for i := 0; i < 4096; i++ { + sendb[i] = ^byte(i) + } tunnme.Write(sendb) rd := bytes.NewBuffer(nil) @@ -420,7 +427,9 @@ type logFormat struct { // Format implements logrus.Formatter func (f logFormat) Format(entry *logrus.Entry) ([]byte, error) { - buf := helper.SelectWriter() // this writer will not be put back + // this writer will not be put back + + buf := (*orbyte.Item[bytes.Buffer])(helper.SelectWriter()).Trans().Pointer() buf.WriteByte('[') if f.enableColor {