1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-13 05:31:08 +08:00

optimize(all): drop lstnq & impl. orbyte

This commit is contained in:
源文雨
2025-02-25 19:38:16 +09:00
parent 4b60801a0f
commit 9f36504635
22 changed files with 501 additions and 573 deletions

View File

@@ -1,9 +1,16 @@
package head
import "encoding/binary"
import (
"encoding/binary"
"fmt"
)
type PacketFlags uint16
func (pf PacketFlags) String() string {
return fmt.Sprintf("%04x", uint16(pf))
}
func (pf PacketFlags) IsValid() bool {
return pf&0x8000 == 0
}

View File

@@ -5,8 +5,11 @@ import (
"encoding/hex"
"errors"
"net"
"sync/atomic"
blake2b "github.com/fumiama/blake2b-simd"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
"github.com/sirupsen/logrus"
"github.com/fumiama/WireGold/config"
@@ -37,6 +40,8 @@ type Packet struct {
DstPort uint16
// Flags 高3位为标志(xDM)低13位为分片偏移
Flags PacketFlags
// 记录还有多少字节未到达
rembytes int32
// Src 源 ip (ipv4)
Src net.IP
// Dst 目的 ip (ipv4)
@@ -48,82 +53,113 @@ type Packet struct {
// crc64 包头字段的 checksum 值,可以认为在一定时间内唯一 (现已更改算法为 md5 但名字未变)
crc64 uint64
// data 承载的数据
data []byte
data pbuf.Bytes
// Data 当前的偏移
a, b int
// 记录还有多少字节未到达
rembytes int
// 是否经由 helper.MakeBytes 创建 Data
buffered bool
}
// NewPacket 生成一个新包
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) (p *Packet) {
p = SelectPacket()
p.Proto = proto
p.TTL = 16
p.SrcPort = srcPort
p.DstPort = dstPort
p.Dst = dst
p.data = data
p.b = len(data)
// NewPacketPartial 从一些预设参数生成一个新包
func NewPacketPartial(
proto uint8, srcPort uint16,
dst net.IP, dstPort uint16,
data pbuf.Bytes,
) (p *orbyte.Item[Packet]) {
p = selectPacket()
pp := p.Pointer()
pp.Proto = proto
pp.TTL = 16
pp.SrcPort = srcPort
pp.DstPort = dstPort
pp.Dst = dst
pp.data = data
pp.b = data.Len()
return
}
// Unmarshal 将 data 的数据解码到自身
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
func ParsePacket(p Packet) *orbyte.Item[Packet] {
return packetPool.Parse(nil, p)
}
func ParsePacketHeader(data []byte) (p *orbyte.Item[Packet], err error) {
if len(data) < 60 {
err = ErrDataLenLT60
return
}
p.crc64 = CRC64(data)
if CalcCRC64(data) != p.crc64 {
p = selectPacket()
pp := p.Pointer()
pp.crc64 = CRC64(data)
if CalcCRC64(data) != pp.crc64 {
err = ErrBadCRCChecksum
return
}
pp.idxdatsz = binary.LittleEndian.Uint32(data[:4])
sz := pp.Len()
if config.ShowDebugLog {
logrus.Debugln("[packet] header data len", sz, "read data len", len(data))
}
pt := binary.LittleEndian.Uint16(data[4:6])
pp.Proto = uint8(pt)
pp.TTL = uint8(pt >> 8)
pp.SrcPort = binary.LittleEndian.Uint16(data[6:8])
pp.DstPort = binary.LittleEndian.Uint16(data[8:10])
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
pp.Flags = flags
pp.Src = make(net.IP, 4)
copy(pp.Src, data[12:16])
pp.Dst = make(net.IP, 4)
copy(pp.Dst, data[16:20])
copy(pp.Hash[:], data[20:52])
switch {
case sz+PacketHeadLen == len(data):
pp.b = sz
pp.rembytes = -1
case pp.rembytes == 0:
pp.data = pbuf.NewBytes(sz)
pp.b = sz
pp.rembytes = int32(sz)
}
return
}
// ParseData 将 data 的数据解码到自身
//
// 必须先调用 ParsePacketHeader
func (p *Packet) ParseData(data []byte) (complete bool) {
sz := p.Len()
if sz == 0 && len(p.data) == 0 {
p.idxdatsz = binary.LittleEndian.Uint32(data[:4])
sz = p.Len()
if sz+52 == len(data) {
p.data = data[52:]
p.b = len(p.data)
p.rembytes = 0
} else {
p.data = helper.MakeBytes(sz)
p.buffered = true
p.b = sz
p.rembytes = sz
}
pt := binary.LittleEndian.Uint16(data[4:6])
p.Proto = uint8(pt)
p.TTL = uint8(pt >> 8)
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
if sz+PacketHeadLen == len(data) {
p.data = pbuf.ParseBytes(data[PacketHeadLen:]...)
return true
}
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
if config.ShowDebugLog {
logrus.Debugln("[packet] parse data flags", flags, "off", flags.Offset())
}
if flags.ZeroOffset() {
p.Flags = flags
p.Src = make(net.IP, 4)
copy(p.Src, data[12:16])
p.Dst = make(net.IP, 4)
copy(p.Dst, data[16:20])
copy(p.Hash[:], data[20:52])
}
if p.rembytes > 0 {
p.rembytes -= copy(p.data[flags.Offset():], data[PacketHeadLen:])
if config.ShowDebugLog {
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
logrus.Debugln("[packet] parse data set zero offset flags", flags)
}
}
complete = p.rembytes == 0
rembytes := atomic.LoadInt32(&p.rembytes)
if rembytes > 0 {
n := int32(copy(p.data.Bytes()[flags.Offset():], data[PacketHeadLen:]))
newrem := rembytes - n
for !atomic.CompareAndSwapInt32(&p.rembytes, rembytes, newrem) {
rembytes = atomic.LoadInt32(&p.rembytes)
newrem = rembytes - n
}
if config.ShowDebugLog {
logrus.Debugln("[packet] copied frag", hex.EncodeToString(data[20:52]), "rembytes:", p.rembytes)
}
}
return
return p.rembytes <= 0
}
// DecreaseAndGetTTL TTL 自减后返回
@@ -132,9 +168,13 @@ func (p *Packet) DecreaseAndGetTTL() uint8 {
return p.TTL
}
// Marshal 将自身数据编码为 []byte
// MarshalWith 补全剩余参数, 将自身数据编码为 []byte
// offset 必须为 8 的倍数,表示偏移的 8 位
func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) {
func (p *Packet) MarshalWith(
src net.IP, teatype uint8, additional uint16,
datasz uint32, offset uint16,
dontfrag, hasmore bool,
) pbuf.Bytes {
if src != nil {
p.Src = src
p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
@@ -148,8 +188,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
offset |= 0x2000
}
p.Flags = PacketFlags(offset)
return helper.OpenWriterF(func(w *helper.Writer) {
return helper.NewWriterF(func(w *helper.Writer) {
w.WriteUInt32(p.idxdatsz)
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
w.WriteUInt16(p.SrcPort)
@@ -158,7 +197,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
w.Write(p.Src.To4())
w.Write(p.Dst.To4())
w.Write(p.Hash[:])
p.crc64 = CalcCRC64(w.Bytes())
p.crc64 = CalcCRC64(w.UnsafeBytes())
w.WriteUInt64(p.crc64)
w.Write(p.Body())
})
@@ -189,6 +228,7 @@ func (p *Packet) IsVaildHash() bool {
var sum [32]byte
_ = h.Sum(sum[:0])
if config.ShowDebugLog {
logrus.Debugln("[packet] sum data len:", len(p.Body()))
logrus.Debugln("[packet] sum calulated:", hex.EncodeToString(sum[:]))
logrus.Debugln("[packet] sum in packet:", hex.EncodeToString(p.Hash[:]))
}
@@ -210,45 +250,28 @@ func (p *Packet) Len() int {
return int(p.idxdatsz & 0xffff)
}
// Put 将自己放回池中
func (p *Packet) Put() {
PutPacket(p)
}
func (p *Packet) CRC64() uint64 {
return p.crc64
}
// Body returns data
func (p *Packet) Body() []byte {
return p.data[p.a:p.b]
return p.data.Bytes()[p.a:p.b]
}
func (p *Packet) BodyLen() int {
return p.b - p.a
}
func (p *Packet) SetBody(b []byte, buffered bool) {
func (p *Packet) SetBody(b []byte) {
p.a = 0
p.b = len(b)
if len(b) <= cap(p.data) {
p.data = p.data[:len(b)]
copy(p.data, b)
if buffered {
helper.PutBytes(b)
}
return
}
if p.buffered {
helper.PutBytes(p.data)
}
p.data = b
p.buffered = buffered
p.data = pbuf.ParseBytes(b...)
}
func (p *Packet) CropBody(a, b int) {
if b > len(p.data) {
b = len(p.data)
if b > p.data.Len() {
b = p.data.Len()
}
if a < 0 || b < 0 || a > b {
return
@@ -256,16 +279,8 @@ func (p *Packet) CropBody(a, b int) {
p.a, p.b = a, b
}
func (p *Packet) Copy() *Packet {
newp := SelectPacket()
*newp = *p
newp.buffered = false
return newp
}
func (p *Packet) CopyWithBody() *Packet {
newp := p.Copy()
newp.data = helper.MakeBytes(len(p.data))
copy(newp.data, p.data)
func (p *Packet) ShallowCopy() (newp Packet) {
newp = *p
newp.data = p.data.Ref()
return newp
}

View File

@@ -2,14 +2,20 @@ package head
import (
crand "crypto/rand"
"encoding/hex"
"math/rand"
"net"
"testing"
"github.com/fumiama/orbyte/pbuf"
)
func TestMarshalUnmarshal(t *testing.T) {
data := make([]byte, 4096)
_, err := crand.Read(data)
data := pbuf.NewBytes(4096)
n, err := crand.Read(data.Bytes())
if n != 4096 {
t.Fatal("unexpected")
}
if err != nil {
t.Fatal(err)
}
@@ -28,40 +34,40 @@ func TestMarshalUnmarshal(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p := NewPacket(proto, srcPort, dst, dstPort, data)
p.FillHash()
d, cl := p.Marshal(src, teatype, uint16(i), uint32(len(data)), 0, true, false)
p = SelectPacket()
ok, err := p.Unmarshal(d)
cl()
p := NewPacketPartial(proto, srcPort, dst, dstPort, data.SliceTo(i))
p.Pointer().FillHash()
d := p.Pointer().MarshalWith(src, teatype, uint16(i), uint32(i), 0, true, false)
t.Log("data:", hex.EncodeToString(d.Bytes()))
p, err := ParsePacketHeader(d.Bytes())
if err != nil {
t.Fatal("index", i, err)
}
ok := p.Pointer().ParseData(d.Bytes())
if !ok {
t.Fatal("index", i)
}
if err != nil {
t.Fatal(err)
if !p.Pointer().IsVaildHash() {
t.Fatal("index", i, "expect body", hex.EncodeToString(data.SliceTo(i).Bytes()), "got", hex.EncodeToString(p.Pointer().Body()))
}
if !p.IsVaildHash() {
if p.Pointer().Proto != proto {
t.Fatal("index", i)
}
if p.Proto != proto {
if p.Pointer().CipherIndex() != teatype {
t.Fatal("index", i, "expect", teatype, "got", p.Pointer().CipherIndex())
}
if p.Pointer().SrcPort != srcPort {
t.Fatal("index", i)
}
if p.CipherIndex() != teatype {
t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex())
}
if p.SrcPort != srcPort {
if p.Pointer().DstPort != dstPort {
t.Fatal("index", i)
}
if p.DstPort != dstPort {
if !p.Pointer().Src.Equal(src) {
t.Fatal("index", i)
}
if !p.Src.Equal(src) {
if !p.Pointer().Dst.Equal(dst) {
t.Fatal("index", i)
}
if !p.Dst.Equal(dst) {
t.Fatal("index", i)
}
if p.AdditionalData() != uint16(i) {
if p.Pointer().AdditionalData() != uint16(i) {
t.Fatal("index", i)
}
}

View File

@@ -1,31 +1,37 @@
package head
import (
"sync"
"github.com/fumiama/WireGold/helper"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
)
var packetPool = sync.Pool{
New: func() interface{} {
return new(Packet)
},
type packetPooler struct {
orbyte.Pooler[Packet]
}
// SelectPacket 从池中取出一个 Packet
func SelectPacket() *Packet {
return packetPool.Get().(*Packet)
func (packetPooler) New(_ any, pooled Packet) Packet {
return pooled
}
// PutPacket 将 Packet 放回池中
func PutPacket(p *Packet) {
func (packetPooler) Parse(obj any, _ Packet) Packet {
return obj.(Packet)
}
func (packetPooler) Reset(p *Packet) {
p.idxdatsz = 0
if p.buffered {
helper.PutBytes(p.data)
p.buffered = false
}
p.data = pbuf.Bytes{}
p.a, p.b = 0, 0
p.data = nil
p.rembytes = 0
packetPool.Put(p)
}
func (packetPooler) Copy(dst, src *Packet) {
*dst = *src
dst.data = src.data.Copy()
}
var packetPool = orbyte.NewPool[Packet](packetPooler{})
// selectPacket 从池中取出一个 Packet
func selectPacket() *orbyte.Item[Packet] {
return packetPool.New(nil)
}