diff --git a/gold/head/packet.go b/gold/head/packet.go index 3d7d1fd..0699413 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -70,10 +70,14 @@ type Packet struct { Hash [32]byte // crc64 包头字段的 checksum 值,可以认为在一定时间内唯一 crc64 uint64 - // Data 承载的数据 - Data []byte + // data 承载的数据 + data []byte + // Data 当前的偏移 + a, b int // 记录还有多少字节未到达 rembytes int + // 是否经由 helper.MakeBytes 创建 Data + buffered bool } // NewPacket 生成一个新包 @@ -85,7 +89,8 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b p.SrcPort = srcPort p.DstPort = dstPort p.Dst = dst - p.Data = data + p.data = data + p.b = len(data) return } @@ -101,16 +106,19 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { return } - sz := p.idxdatsz & 0x0000ffff - if sz == 0 && len(p.Data) == 0 { + sz := p.Len() + if sz == 0 && len(p.data) == 0 { p.idxdatsz = binary.LittleEndian.Uint32(data[:4]) - sz = p.idxdatsz & 0x0000ffff - if int(sz)+52 == len(data) { - p.Data = data[52:] + sz = p.Len() + if sz+52 == len(data) { + p.data = data[52:] + p.b = len(p.data) p.rembytes = 0 } else { - p.Data = make([]byte, sz) - p.rembytes = int(sz) + p.data = helper.MakeBytes(sz) + p.buffered = true + p.b = sz + p.rembytes = sz } pt := binary.LittleEndian.Uint16(data[4:6]) p.Proto = uint8(pt) @@ -131,7 +139,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { } if p.rembytes > 0 { - p.rembytes -= copy(p.Data[flags.Offset():], data[60:]) + p.rembytes -= copy(p.data[flags.Offset():], data[60:]) logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes) } @@ -171,14 +179,14 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui w.Write(p.Dst.To4()) w.Write(p.Hash[:]) w.WriteUInt64(crc64.Checksum(w.Bytes(), crc64.MakeTable(crc64.ISO))) - w.Write(p.Data) + w.Write(p.Body()) }) } // FillHash 生成 p.Data 的 Hash func (p *Packet) FillHash() { h := blake2b.New256() - _, err := h.Write(p.Data) + _, err := h.Write(p.Body()) if err != nil { logrus.Error("[packet] err when fill hash:", err) return @@ -189,7 +197,7 @@ func (p *Packet) FillHash() { // IsVaildHash 验证 packet 合法性 func (p *Packet) IsVaildHash() bool { h := blake2b.New256() - _, err := h.Write(p.Data) + _, err := h.Write(p.Body()) if err != nil { logrus.Error("[packet] err when check hash:", err) return false @@ -219,3 +227,47 @@ func (p *Packet) Len() int { func (p *Packet) Put() { PutPacket(p) } + +// Body returns data +func (p *Packet) Body() []byte { + return p.data[p.a:p.b] +} + +func (p *Packet) BodyLen() int { + return p.b - p.a +} + +func (p *Packet) SetBody(b []byte, buffered bool) { + p.a = 0 + p.b = len(b) + if len(b) <= cap(p.data) { + p.data = p.data[:len(b)] + copy(p.data, b) + if buffered { + helper.PutBytes(b) + } + return + } + if p.buffered { + helper.PutBytes(p.data) + } + p.data = b + p.buffered = buffered +} + +func (p *Packet) CropBody(a, b int) { + if b > len(p.data) { + b = len(p.data) + } + if a < 0 || b < 0 || a > b { + return + } + p.a, p.b = a, b +} + +func (p *Packet) Copy() *Packet { + newp := SelectPacket() + *newp = *p + newp.buffered = false + return newp +} diff --git a/gold/head/pool.go b/gold/head/pool.go index 98763cb..4635136 100644 --- a/gold/head/pool.go +++ b/gold/head/pool.go @@ -1,6 +1,10 @@ package head -import "sync" +import ( + "sync" + + "github.com/fumiama/WireGold/helper" +) var packetPool = sync.Pool{ New: func() interface{} { @@ -16,6 +20,12 @@ func SelectPacket() *Packet { // PutPacket 将 Packet 放回池中 func PutPacket(p *Packet) { p.idxdatsz = 0 - p.Data = nil + if p.buffered { + helper.PutBytes(p.data) + p.buffered = false + } + p.a, p.b = 0, 0 + p.data = nil + p.rembytes = 0 packetPool.Put(p) } diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 810a85c..a393397 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -8,6 +8,7 @@ import ( "math/bits" mrand "math/rand" + "github.com/fumiama/WireGold/helper" "github.com/sirupsen/logrus" ) @@ -51,13 +52,13 @@ func expandkeyunit(v1, v2 byte) (v uint16) { return } -// Encode 使用 xchacha20poly1305 和密钥序列加密 +// Encode by aead and put b into pool func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) { if len(b) == 0 || teatype >= 32 { return } if l.keys[0] == nil { - eb = make([]byte, len(b)) + eb = helper.MakeBytes(len(b)) copy(eb, b) return } @@ -70,13 +71,14 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) { return } -// Decode 使用 xchacha20poly1305 和密钥序列解密 +// Decode by aead and put b into pool func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) { if len(b) == 0 || teatype >= 32 { return } if l.keys[0] == nil { - db = b + db = helper.MakeBytes(len(b)) + copy(db, b) return } aead := l.keys[teatype] @@ -86,11 +88,10 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, er return decode(aead, additional, b) } -// encode 使用 xchacha20poly1305 加密 func encode(aead cipher.AEAD, additional uint16, b []byte) []byte { nsz := aead.NonceSize() // Accocate capacity for all the stuffs. - buf := make([]byte, 2+nsz+len(b)+aead.Overhead()) + buf := helper.MakeBytes(2 + nsz + len(b) + aead.Overhead()) binary.LittleEndian.PutUint16(buf[:2], additional) nonce := buf[2 : 2+nsz] // Select a random nonce @@ -103,7 +104,6 @@ func encode(aead cipher.AEAD, additional uint16, b []byte) []byte { return nonce[:nsz+len(eb)] } -// decode 使用 xchacha20poly1305 解密 func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { nsz := aead.NonceSize() if len(b) < nsz { @@ -117,7 +117,7 @@ func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { // Decrypt the message and check it wasn't tampered with. var buf [2]byte binary.LittleEndian.PutUint16(buf[:], additional) - return aead.Open(nil, nonce, ciphertext, buf[:]) + return aead.Open(helper.SelectWriter().Bytes(), nonce, ciphertext, buf[:]) } // xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data diff --git a/gold/link/listen.go b/gold/link/listen.go index ec60a71..9b00c01 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -15,8 +15,11 @@ import ( "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/p2p" + "github.com/fumiama/WireGold/helper" ) +const lstnbufgragsz = 65536 + // 监听本机 endpoint func (m *Me) listen() (conn p2p.Conn, err error) { conn, err = m.ep.Listen() @@ -34,7 +37,7 @@ func (m *Me) listen() (conn p2p.Conn, err error) { n = 64 // 只用最多 64 核 } logrus.Infoln("[listen] use cpu num:", n) - listenbuff := make([]byte, 65536*n) + listenbuff := make([]byte, lstnbufgragsz*n) hasntfinished := make([]sync.Mutex, n) for i := 0; err == nil; i++ { i %= n @@ -46,7 +49,7 @@ func (m *Me) listen() (conn p2p.Conn, err error) { } } logrus.Debugln("[listen] lock index", i) - lbf := listenbuff[i*65536 : (i+1)*65536] + lbf := listenbuff[i*lstnbufgragsz : (i+1)*lstnbufgragsz] n, addr, err := conn.ReadFromPeer(lbf) if m.loop == nil || errors.Is(err, net.ErrClosed) { logrus.Warnln("[listen] quit listening") @@ -72,9 +75,9 @@ func (m *Me) listen() (conn p2p.Conn, err error) { recvtotlcnt = 0 recvlooptime = now } - packet := m.wait(lbf[:n]) + packet := m.wait(lbf[:n:lstnbufgragsz]) if packet == nil { - logrus.Debugln("[listen] unlock index", i) + logrus.Debugln("[listen] waiting, unlock index", i) hasntfinished[i].Unlock() i-- continue @@ -87,10 +90,11 @@ func (m *Me) listen() (conn p2p.Conn, err error) { func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish func()) { defer finish() - defer logrus.Debugln("[listen] unlock index", index) - r := packet.Len() - len(packet.Data) + defer logrus.Debugln("[listen] dispatched, unlock index", index) + logrus.Debugln("[listen] start dispatching index", index) + r := packet.Len() - packet.BodyLen() if r > 0 { - logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it") + logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "len", packet.BodyLen(), "is smaller than it declared len", packet.Len(), ", drop it") packet.Put() return } @@ -114,22 +118,25 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish } addt := packet.AdditionalData() var err error - packet.Data, err = p.Decode(packet.CipherIndex(), addt, packet.Data) + data, err := p.Decode(packet.CipherIndex(), addt, packet.Body()) if err != nil { logrus.Debugln("[listen] @", index, "drop invalid packet", ", key idx:", packet.CipherIndex(), "addt:", addt, "err:", err) packet.Put() return } + packet.SetBody(data, true) if p.usezstd { - dec, _ := zstd.NewReader(bytes.NewReader(packet.Data)) + dec, _ := zstd.NewReader(bytes.NewReader(packet.Body())) var err error - packet.Data, err = io.ReadAll(dec) + w := helper.SelectWriter() + _, err = io.Copy(w, dec) dec.Close() if err != nil { logrus.Debugln("[listen] @", index, "drop invalid zstd packet:", err) packet.Put() return } + packet.SetBody(w.Bytes(), true) } if !packet.IsVaildHash() { logrus.Debugln("[listen] @", index, "drop invalid hash packet") @@ -154,22 +161,22 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish packet.Put() case head.ProtoNotify: logrus.Infoln("[listen] @", index, "recv notify from", packet.Src) - go p.onNotify(packet.Data) + go p.onNotify(packet.Body()) packet.Put() case head.ProtoQuery: logrus.Infoln("[listen] @", index, "recv query from", packet.Src) - go p.onQuery(packet.Data) + go p.onQuery(packet.Body()) packet.Put() case head.ProtoData: if p.pipe != nil { p.pipe <- packet logrus.Debugln("[listen] @", index, "deliver to pipe of", p.peerip) } else { - _, err := m.nic.Write(packet.Data) + _, err := m.nic.Write(packet.Body()) if err != nil { - logrus.Errorln("[listen] @", index, "deliver", len(packet.Data), "bytes data to nic err:", err) + logrus.Errorln("[listen] @", index, "deliver", packet.BodyLen(), "bytes data to nic err:", err) } else { - logrus.Debugln("[listen] @", index, "deliver", len(packet.Data), "bytes data to nic") + logrus.Debugln("[listen] @", index, "deliver", packet.BodyLen(), "bytes data to nic") } packet.Put() } diff --git a/gold/link/send.go b/gold/link/send.go index df512e9..5167afa 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -32,32 +32,33 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { logrus.Warnln("[send] reset invalid data frag len", delta, "to 8") delta = 8 } - if len(p.Data) <= delta { - return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false) + remlen := p.BodyLen() + if remlen <= delta { + return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false) } - if istransfer && p.Flags.DontFrag() && len(p.Data) > delta { + if istransfer && p.Flags.DontFrag() && remlen > delta { return 0, errors.New("drop don't fragmnet big trans packet") } - data := p.Data ttl := p.TTL - totl := uint32(len(data)) + totl := uint32(remlen) pos := 0 - packet := head.SelectPacket() - *packet = *p - for ; int(totl)-pos > delta; pos += delta { - logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", int(totl)-pos-delta) - packet.Data = data[:delta] + packet := p.Copy() + for remlen > delta { + remlen -= delta + logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", remlen) + packet.CropBody(pos, pos+delta) cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true) n += cnt if err != nil { return n, err } - data = data[delta:] packet.TTL = ttl + pos += delta } packet.Put() - if len(data) > 0 { - p.Data = data + if remlen > 0 { + logrus.Debugln("[send] last frag [", pos, "~", pos+remlen, "]") + p.CropBody(pos, pos+remlen) cnt := 0 cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false) n += cnt @@ -67,18 +68,19 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) { p.FillHash() - logrus.Debugln("[send] data len before encrypt:", len(p.Data)) + logrus.Debugln("[send] data len before encrypt:", p.BodyLen()) + data := p.Body() if l.usezstd { w := helper.SelectWriter() defer helper.PutWriter(w) enc, _ := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedFastest)) - _, _ = io.Copy(enc, bytes.NewReader(p.Data)) + _, _ = io.Copy(enc, bytes.NewReader(data)) enc.Close() - p.Data = w.Bytes() - logrus.Debugln("[send] data len after zstd:", len(p.Data)) + data = w.Bytes() + logrus.Debugln("[send] data len after zstd:", len(data)) } - p.Data = l.Encode(teatype, sndcnt&0x07ff, p.Data) - logrus.Debugln("[send] data len after xchacha20:", len(p.Data), "addt:", sndcnt) + p.SetBody(l.Encode(teatype, sndcnt&0x07ff, data), true) + logrus.Debugln("[send] data len after xchacha20:", p.BodyLen(), "addt:", sndcnt) } // write 向 peer 发一个包 diff --git a/helper/pool.go b/helper/pool.go index c2f4558..7261bfe 100644 --- a/helper/pool.go +++ b/helper/pool.go @@ -13,6 +13,19 @@ var bufferPool = sync.Pool{ }, } +func MakeBytes(sz int) []byte { + w := SelectWriter() + b := w.Bytes() + if cap(b) >= sz { + return b[:sz] + } + return make([]byte, sz) +} + +func PutBytes(b []byte) { + PutWriter((*Writer)(bytes.NewBuffer(b))) +} + // SelectWriter 从池中取出一个 Writer func SelectWriter() *Writer { // 因为 bufferPool 定义有 New 函数 diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 4670d31..7c40bfb 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -70,11 +70,11 @@ func (s *Tunnel) Read(p []byte) (int, error) { return 0, io.EOF } defer pkt.Put() - if len(pkt.Data) < 4 { - logrus.Warnln("[tunnel] unexpected packet data len", len(pkt.Data), "content", pkt.Data) + if pkt.BodyLen() < 4 { + logrus.Warnln("[tunnel] unexpected packet data len", pkt.BodyLen(), "content", hex.EncodeToString(pkt.Body())) return 0, io.EOF } - d = pkt.Data[4:] + d = pkt.Body()[4:] } if d != nil { if len(p) >= len(d) { @@ -111,7 +111,7 @@ func (s *Tunnel) handleWrite() { } logrus.Debugln("[tunnel] writing", len(b), "bytes...") for len(b) > int(s.mtu)-4 { - logrus.Infoln("[tunnel] seq", seq, "split buffer") + logrus.Debugln("[tunnel] seq", seq, "split buffer") binary.LittleEndian.PutUint32(buf[:4], seq) seq++ copy(buf[4:], b[:s.mtu-4]) @@ -157,12 +157,12 @@ func (s *Tunnel) handleRead() { } end := 64 endl := "..." - if len(p.Data) < 64 { - end = len(p.Data) + if p.BodyLen() < 64 { + end = p.BodyLen() endl = "." } - logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Data[:end]), endl) - recvseq := binary.LittleEndian.Uint32(p.Data[:4]) + logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Body()[:end]), endl) + recvseq := binary.LittleEndian.Uint32(p.Body()[:4]) if recvseq == seq { logrus.Debugln("[tunnel] dispatch seq", seq) seq++ diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 780f780..89434d1 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -169,8 +169,7 @@ type logFormat struct { // Format implements logrus.Formatter func (f logFormat) Format(entry *logrus.Entry) ([]byte, error) { - buf := helper.SelectWriter() - defer helper.PutWriter(buf) + buf := helper.SelectWriter() // this writer will not be put back buf.WriteByte('[') if f.enableColor { @@ -184,9 +183,7 @@ func (f logFormat) Format(entry *logrus.Entry) ([]byte, error) { buf.WriteString(entry.Message) buf.WriteString("\n") - ret := make([]byte, len(buf.Bytes())) - copy(ret, buf.Bytes()) // copy buffer - return ret, nil + return buf.Bytes(), nil } const (