1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-23 03:50:32 +08:00

optimize(all): drop lstnq & impl. orbyte

This commit is contained in:
源文雨
2025-02-25 19:38:16 +09:00
parent 4b60801a0f
commit 9f36504635
22 changed files with 501 additions and 573 deletions

View File

@@ -7,8 +7,9 @@ import (
"errors"
"math/bits"
mrand "math/rand"
"runtime"
"github.com/fumiama/WireGold/helper"
"github.com/fumiama/orbyte/pbuf"
"github.com/sirupsen/logrus"
)
@@ -53,14 +54,12 @@ func expandkeyunit(v1, v2 byte) (v uint16) {
}
// Encode by aead and put b into pool
func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) {
func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb pbuf.Bytes) {
if len(b) == 0 || teatype >= 32 {
return
}
if l.keys[0] == nil {
eb = helper.MakeBytes(len(b))
copy(eb, b)
return
return pbuf.ParseBytes(b...)
}
aead := l.keys[teatype]
if aead == nil {
@@ -72,14 +71,12 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) {
}
// Decode by aead and put b into pool
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) {
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db pbuf.Bytes, err error) {
if len(b) == 0 || teatype >= 32 {
return
}
if l.keys[0] == nil {
db = helper.MakeBytes(len(b))
copy(db, b)
return
return pbuf.ParseBytes(b...), nil
}
aead := l.keys[teatype]
if aead == nil {
@@ -88,59 +85,67 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, er
return decode(aead, additional, b)
}
func encode(aead cipher.AEAD, additional uint16, b []byte) []byte {
func encode(aead cipher.AEAD, additional uint16, b []byte) pbuf.Bytes {
nsz := aead.NonceSize()
// Accocate capacity for all the stuffs.
buf := helper.MakeBytes(2 + nsz + len(b) + aead.Overhead())
binary.LittleEndian.PutUint16(buf[:2], additional)
nonce := buf[2 : 2+nsz]
buf := pbuf.NewBytes(2 + nsz + len(b) + aead.Overhead())
binary.LittleEndian.PutUint16(buf.Bytes()[:2], additional)
nonce := buf.Bytes()[2 : 2+nsz]
// Select a random nonce
_, err := rand.Read(nonce)
if err != nil {
panic(err)
}
// Encrypt the message and append the ciphertext to the nonce.
eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf[:2])
return nonce[:nsz+len(eb)]
eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf.Bytes()[:2])
return buf.Trans().Slice(2, 2+nsz+len(eb))
}
func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) {
func decode(aead cipher.AEAD, additional uint16, b []byte) (pbuf.Bytes, error) {
nsz := aead.NonceSize()
if len(b) < nsz {
return nil, ErrCipherTextTooShort
return pbuf.Bytes{}, ErrCipherTextTooShort
}
// Split nonce and ciphertext.
nonce, ciphertext := b[:nsz], b[nsz:]
if len(ciphertext) == 0 {
return nil, nil
return pbuf.Bytes{}, nil
}
// Decrypt the message and check it wasn't tampered with.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
return aead.Open(helper.SelectWriter().Bytes(), nonce, ciphertext, buf[:])
data, err := aead.Open(
pbuf.NewBytes(4096).Trans().Bytes()[:0],
nonce, ciphertext, buf[:],
)
if err != nil {
return pbuf.Bytes{}, nil
}
return pbuf.ParseBytes(data...), nil
}
// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data
func (m *Me) xorenc(data []byte, seq uint32) []byte {
func (m *Me) xorenc(data []byte, seq uint32) pbuf.Bytes {
batchsz := len(data) / 8
remain := len(data) % 8
sum := m.mask
newdat := helper.MakeBytes(8 + batchsz*8 + 8) // seqrand dat tail
binary.LittleEndian.PutUint32(newdat[:4], seq)
_, _ = rand.Read(newdat[4:8]) // seqrand
sum ^= binary.LittleEndian.Uint64(newdat[:8]) // init from seqrand
binary.LittleEndian.PutUint64(newdat[:8], sum)
newdat := pbuf.NewBytes(8 + batchsz*8 + 8) // seqrand dat tail
binary.LittleEndian.PutUint32(newdat.Bytes()[:4], seq)
_, _ = rand.Read(newdat.Bytes()[4:8]) // seqrand
sum ^= binary.LittleEndian.Uint64(newdat.Bytes()[:8]) // init from seqrand
binary.LittleEndian.PutUint64(newdat.Bytes()[:8], sum)
for i := 0; i < batchsz; i++ { // range on batch data
a := i * 8
b := (i + 1) * 8
sum ^= binary.LittleEndian.Uint64(data[a:b])
binary.LittleEndian.PutUint64(newdat[a+8:b+8], sum)
binary.LittleEndian.PutUint64(newdat.Bytes()[a+8:b+8], sum)
}
p := batchsz * 8
copy(newdat[8+p:], data[p:])
newdat[len(newdat)-1] = byte(remain)
sum ^= binary.LittleEndian.Uint64(newdat[8+p:])
binary.LittleEndian.PutUint64(newdat[8+p:], sum)
copy(newdat.Bytes()[8+p:], data[p:])
runtime.KeepAlive(data)
newdat.Bytes()[newdat.Len()-1] = byte(remain)
sum ^= binary.LittleEndian.Uint64(newdat.Bytes()[8+p:])
binary.LittleEndian.PutUint64(newdat.Bytes()[8+p:], sum)
return newdat
}
@@ -163,5 +168,6 @@ func (m *Me) xordec(data []byte) (uint32, []byte) {
if remain >= 8 {
return 0, nil
}
return binary.LittleEndian.Uint32(data[:4]), data[8 : len(data)-8+int(remain)]
return binary.LittleEndian.Uint32(data[:4]),
data[8 : len(data)-8+int(remain)]
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary"
"encoding/hex"
"io"
"runtime"
"testing"
"golang.org/x/crypto/chacha20poly1305"
@@ -27,10 +28,11 @@ func TestXOR(t *testing.T) {
if err != nil {
t.Fatal(err)
}
seq, dec := m.xordec(m.xorenc(r1.Bytes(), uint32(i)))
seq, dec := m.xordec(m.xorenc(r1.Bytes(), uint32(i)).Trans().Bytes())
if !bytes.Equal(dec, r2.Bytes()) {
t.Fatal("unexpected xor at", i, "except", hex.EncodeToString(r2.Bytes()), "got", hex.EncodeToString(dec))
}
runtime.KeepAlive(dec)
if seq != uint32(i) {
t.Fatal("unexpected xor at", i, "seq", seq)
}
@@ -53,11 +55,11 @@ func TestXChacha20(t *testing.T) {
t.Fatal(err)
}
for i := 0; i < 4096; i++ {
db, err := decode(aead, uint16(i), encode(aead, uint16(i), data[:i]))
db, err := decode(aead, uint16(i), encode(aead, uint16(i), data[:i]).Trans().Bytes())
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(db, data[:i]) {
if !bytes.Equal(db.Bytes(), data[:i]) {
t.Fatal("unexpected preshared at idx(len)", i, "addt", uint16(i))
}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/fumiama/WireGold/gold/p2p"
"github.com/fumiama/WireGold/helper"
base14 "github.com/fumiama/go-base16384"
"github.com/fumiama/orbyte"
)
var (
@@ -26,7 +27,7 @@ type Link struct {
// 收到的包的队列
// 没有下层 nic 时
// 包会分发到此
pipe chan *head.Packet
pipe chan *orbyte.Item[head.Packet]
// peer 的虚拟 ip
peerip net.IP
// peer 的公网 endpoint

View File

@@ -19,56 +19,12 @@ import (
"github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/gold/p2p"
"github.com/fumiama/WireGold/helper"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
)
const lstnbufgragsz = 65536
type lstnq struct {
index int
addr p2p.EndPoint
buf []byte
}
type listenqueue chan lstnq
func (q listenqueue) listen(m *Me, hasntfinished []sync.Mutex) {
recvtotlcnt := uint64(0)
recvloopcnt := uint16(0)
recvlooptime := time.Now().UnixMilli()
for lstn := range q {
recvtotlcnt += uint64(len(lstn.buf))
recvloopcnt++
if recvloopcnt%m.speedloop == 0 {
now := time.Now().UnixMilli()
logrus.Infof("[listen] queue recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime))
recvtotlcnt = 0
recvlooptime = now
}
packet := m.wait(lstn.buf[:len(lstn.buf):lstnbufgragsz])
if packet == nil {
if lstn.index < 0 {
if config.ShowDebugLog {
logrus.Debugln("[listen] queue waiting")
}
helper.PutBytes(lstn.buf)
continue
}
if config.ShowDebugLog {
logrus.Debugln("[listen] queue waiting, unlock index", lstn.index)
}
hasntfinished[lstn.index].Unlock()
continue
}
if lstn.index >= 0 {
go m.dispatch(packet, lstn.addr, lstn.index, hasntfinished[lstn.index].Unlock)
} else {
go m.dispatch(packet, lstn.addr, lstn.index, func() {
helper.PutBytes(lstn.buf)
})
}
}
}
// 监听本机 endpoint
func (m *Me) listen() (conn p2p.Conn, err error) {
conn, err = m.ep.Listen()
@@ -85,9 +41,6 @@ func (m *Me) listen() (conn p2p.Conn, err error) {
logrus.Infoln("[listen] use cpu num:", n)
listenbuf := make([]byte, lstnbufgragsz*n)
hasntfinished := make([]sync.Mutex, n)
q := make(listenqueue, n)
defer close(q)
go q.listen(m, hasntfinished)
for {
usenewbuf := false
i := uint(0)
@@ -105,13 +58,16 @@ func (m *Me) listen() (conn p2p.Conn, err error) {
if config.ShowDebugLog && !usenewbuf {
logrus.Debugln("[listen] lock index", i)
}
var lbf []byte
var lbf pbuf.Bytes
if usenewbuf {
lbf = helper.MakeBytes(lstnbufgragsz)
lbf = pbuf.NewBytes(lstnbufgragsz)
} else {
lbf = listenbuf[i*lstnbufgragsz : (i+1)*lstnbufgragsz]
if config.ShowDebugLog {
logrus.Debugln("[listen] take index", i, "slice", i*lstnbufgragsz, (i+1)*lstnbufgragsz, "cap", lstnbufgragsz)
}
lbf = pbuf.ParseBytes(listenbuf[i*lstnbufgragsz : (i+1)*lstnbufgragsz : (i+1)*lstnbufgragsz]...)
}
n, addr, err := conn.ReadFromPeer(lbf)
n, addr, err := conn.ReadFromPeer(lbf.Bytes())
if m.connections == nil || errors.Is(err, net.ErrClosed) {
logrus.Warnln("[listen] quit listening")
return
@@ -138,39 +94,70 @@ func (m *Me) listen() (conn p2p.Conn, err error) {
}
continue
}
lq := lstnq{
index: -1,
addr: addr,
buf: lbf[:n],
}
index := -1
if !usenewbuf {
lq.index = int(i)
index = int(i)
}
q <- lq
go m.waitordispatch(index, addr, lbf.Trans().SliceTo(n), hasntfinished)
}
}()
return
}
func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish func()) {
defer finish()
func (m *Me) waitordispatch(index int, addr p2p.EndPoint, buf pbuf.Bytes, hasntfinished []sync.Mutex) {
recvtotlcnt := atomic.AddUint64(&m.recvtotlcnt, uint64(buf.Len()))
recvloopcnt := atomic.AddUintptr(&m.recvloopcnt, 1)
recvlooptime := atomic.LoadInt64(&m.recvlooptime)
if recvloopcnt%uintptr(m.speedloop) == 0 {
now := time.Now().UnixMilli()
logrus.Infof("[listen] queue recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime))
atomic.StoreUint64(&m.recvtotlcnt, 0)
atomic.StoreInt64(&m.recvlooptime, now)
}
packet := m.wait(buf.Trans())
if packet == nil {
if index < 0 {
if config.ShowDebugLog {
logrus.Debugln("[listen] queue waiting")
}
return
}
if config.ShowDebugLog {
logrus.Debugln("[listen] queue waiting, unlock index", index)
}
hasntfinished[index].Unlock()
return
}
if config.ShowDebugLog {
logrus.Debugln("[listen] index", index, "dispatch", len(packet.Pointer().Body()), "bytes packet")
}
if index >= 0 {
defer hasntfinished[index].Unlock()
m.dispatch(packet, addr, index)
return
}
m.dispatch(packet, addr, index)
}
func (m *Me) dispatch(packet *orbyte.Item[head.Packet], addr p2p.EndPoint, index int) {
defer runtime.KeepAlive(packet)
if config.ShowDebugLog {
defer logrus.Debugln("[listen] dispatched, unlock index", index)
logrus.Debugln("[listen] start dispatching index", index)
}
r := packet.Len() - packet.BodyLen()
pp := packet.Pointer()
r := pp.Len() - pp.BodyLen()
if r > 0 {
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "len", packet.BodyLen(), "is smaller than it declared len", packet.Len(), ", drop it")
packet.Put()
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "len", pp.BodyLen(), "is smaller than it declared len", pp.Len(), ", drop it")
return
}
p, ok := m.IsInPeer(packet.Src.String())
p, ok := m.IsInPeer(pp.Src.String())
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
logrus.Debugln("[listen] @", index, "recv from endpoint", addr, "src", pp.Src, "dst", pp.Dst)
}
if !ok {
logrus.Warnln("[listen] @", index, "packet from", packet.Src, "to", packet.Dst, "is refused")
packet.Put()
logrus.Warnln("[listen] @", index, "packet from", pp.Src, "to", pp.Dst, "is refused")
return
}
if helper.IsNilInterface(p.endpoint) || !p.endpoint.Euqal(addr) {
@@ -185,25 +172,23 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish
now := time.Now()
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.lastalive)), unsafe.Pointer(&now))
switch {
case p.IsToMe(packet.Dst):
if !p.Accept(packet.Src) {
logrus.Warnln("[listen] @", index, "refused packet from", packet.Src.String()+":"+strconv.Itoa(int(packet.SrcPort)))
packet.Put()
case p.IsToMe(pp.Dst):
if !p.Accept(pp.Src) {
logrus.Warnln("[listen] @", index, "refused packet from", pp.Src.String()+":"+strconv.Itoa(int(pp.SrcPort)))
return
}
addt := packet.AdditionalData()
addt := pp.AdditionalData()
var err error
data, err := p.Decode(packet.CipherIndex(), addt, packet.Body())
data, err := p.Decode(pp.CipherIndex(), addt, pp.Body())
if err != nil {
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", packet.CipherIndex(), "addt:", addt, "err:", err)
logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", pp.CipherIndex(), "addt:", addt, "err:", err)
}
packet.Put()
return
}
packet.SetBody(data, true)
pp.SetBody(data.Trans().Bytes())
if p.usezstd {
dec, _ := zstd.NewReader(bytes.NewReader(packet.Body()))
dec, _ := zstd.NewReader(bytes.NewReader(pp.Body()))
var err error
w := helper.SelectWriter()
_, err = io.Copy(w, dec)
@@ -212,25 +197,27 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "drop invalid zstd packet:", err)
}
packet.Put()
return
}
packet.SetBody(w.Bytes(), true)
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "zstd decoded len:", w.Len())
}
pp.SetBody(w.TransBytes())
}
if !packet.IsVaildHash() {
if !pp.IsVaildHash() {
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "drop invalid hash packet")
}
packet.Put()
return
}
switch packet.Proto {
switch pp.Proto {
case head.ProtoHello:
switch {
case len(packet.Body()) == 0:
case len(pp.Body()) == 0:
logrus.Warnln("[listen] @", index, "recv old hello packet, do nothing")
case packet.Body()[0] == byte(head.HelloPing):
n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), []byte{byte(head.HelloPong)}), false)
case pp.Body()[0] == byte(head.HelloPing):
n, err := p.WritePacket(head.NewPacketPartial(
head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), pbuf.ParseBytes(byte(head.HelloPong))), false)
if err == nil {
logrus.Infoln("[listen] @", index, "recv hello, send", n, "bytes hello ack packet")
} else {
@@ -239,57 +226,49 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish
default:
logrus.Infoln("[listen] @", index, "recv hello ack packet, do nothing")
}
packet.Put()
case head.ProtoNotify:
logrus.Infoln("[listen] @", index, "recv notify from", packet.Src)
go p.onNotify(packet.Body())
packet.Put()
logrus.Infoln("[listen] @", index, "recv notify from", pp.Src)
p.onNotify(pp.Body())
case head.ProtoQuery:
logrus.Infoln("[listen] @", index, "recv query from", packet.Src)
go p.onQuery(packet.Body())
packet.Put()
logrus.Infoln("[listen] @", index, "recv query from", pp.Src)
p.onQuery(pp.Body())
case head.ProtoData:
if p.pipe != nil {
p.pipe <- packet.CopyWithBody()
p.pipe <- packet.Copy()
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "deliver to pipe of", p.peerip)
}
} else {
_, err := m.nic.Write(packet.Body())
_, err := m.nic.Write(pp.Body())
if err != nil {
logrus.Errorln("[listen] @", index, "deliver", packet.BodyLen(), "bytes data to nic err:", err)
logrus.Errorln("[listen] @", index, "deliver", pp.BodyLen(), "bytes data to nic err:", err)
} else if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "deliver", packet.BodyLen(), "bytes data to nic")
logrus.Debugln("[listen] @", index, "deliver", pp.BodyLen(), "bytes data to nic")
}
}
packet.Put()
default:
logrus.Warnln("[listen] @", index, "recv unknown proto:", packet.Proto)
packet.Put()
logrus.Warnln("[listen] @", index, "recv unknown proto:", pp.Proto)
}
case p.Accept(packet.Dst):
case p.Accept(pp.Dst):
if !p.allowtrans {
logrus.Warnln("[listen] @", index, "refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
packet.Put()
logrus.Warnln("[listen] @", index, "refused to trans packet to", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort)))
return
}
// 转发
lnk := m.router.NextHop(packet.Dst.String())
lnk := m.router.NextHop(pp.Dst.String())
if lnk == nil {
logrus.Warnln("[listen] @", index, "transfer drop packet: nil nexthop")
packet.Put()
return
}
n, err := lnk.WriteAndPut(packet, true)
n, err := lnk.WritePacket(packet, true)
if err == nil {
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
logrus.Debugln("[listen] @", index, "trans", n, "bytes packet to", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort)))
}
} else {
logrus.Errorln("[listen] @", index, "trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
logrus.Errorln("[listen] @", index, "trans packet to", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort)), "err:", err)
}
default:
logrus.Warnln("[listen] @", index, "packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
packet.Put()
logrus.Warnln("[listen] @", index, "packet dst", pp.Dst.String()+":"+strconv.Itoa(int(pp.DstPort)), "is not in peers")
}
}

View File

@@ -10,6 +10,8 @@ import (
"time"
"github.com/FloatTech/ttl"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
"github.com/fumiama/water/waterutil"
"github.com/sirupsen/logrus"
@@ -46,13 +48,19 @@ type Me struct {
// 本机路由表
router *Router
// 本机未接收完全分片池
recving *ttl.Cache[uint64, *head.Packet]
recving *ttl.Cache[uint64, *orbyte.Item[head.Packet]]
// 抗重放攻击记录池
recved *ttl.Cache[uint64, bool]
recved *ttl.Cache[uint64, struct{}]
// 本机上层配置
srcport, dstport, mtu, speedloop uint16
// 报头掩码
mask uint64
// 本机总接收字节数
recvtotlcnt uint64
// 上一次触发循环计数时间
recvlooptime int64
// 本机总接收数据包计数
recvloopcnt uintptr
// 是否进行 base16384 编码
base14 bool
// 本机网络端点初始化配置
@@ -122,12 +130,13 @@ func NewMe(cfg *MyConfig) (m Me) {
)
}
m.mask = cfg.Mask
m.recvlooptime = time.Now().UnixMilli()
m.base14 = cfg.Base14
var buf [8]byte
binary.BigEndian.PutUint64(buf[:], m.mask)
logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:]))
m.recving = ttl.NewCache[uint64, *head.Packet](time.Second * 30)
m.recved = ttl.NewCache[uint64, bool](time.Second * 30)
m.recving = ttl.NewCache[uint64, *orbyte.Item[head.Packet]](time.Second * 10)
m.recved = ttl.NewCache[uint64, struct{}](time.Minute)
return
}
@@ -154,6 +163,7 @@ func (m *Me) Restart() error {
}
m.me = ip
m.subnet = *cidr
m.recvlooptime = time.Now().UnixMilli()
m.conn, err = m.listen()
return err
}
@@ -280,11 +290,10 @@ func (m *Me) sendAllSameDst(packet []byte) (n int) {
logrus.Warnln("[me] drop packet to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), ": nil nexthop")
return
}
pcp := helper.MakeBytes(len(packet))
copy(pcp, packet)
go func(packet []byte) {
defer helper.PutBytes(packet)
_, err := lnk.WriteAndPut(head.NewPacket(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false)
pcp := pbuf.NewBytes(len(packet))
copy(pcp.Bytes(), packet)
go func(packet pbuf.Bytes) {
_, err := lnk.WritePacket(head.NewPacketPartial(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false)
if err != nil {
logrus.Warnln("[me] write to peer", lnk.peerip, "err:", err)
}

View File

@@ -1,6 +1,7 @@
package link
import (
"bytes"
"encoding/json"
"sync/atomic"
"time"
@@ -12,6 +13,8 @@ import (
"github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/gold/p2p"
"github.com/fumiama/WireGold/helper"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
)
// 保持 NAT
@@ -35,7 +38,7 @@ func (l *Link) keepAlive(dur int64) {
logrus.Infoln("[nat] re-connect me succeeded")
}
}
n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, []byte{byte(head.HelloPing)}), false)
n, err := l.WritePacket(head.NewPacketPartial(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, pbuf.ParseBytes(byte(head.HelloPing))), false)
if err == nil {
logrus.Infoln("[nat] send", n, "bytes keep alive packet")
} else {
@@ -131,12 +134,14 @@ func (l *Link) onQuery(packet []byte) {
logrus.Infoln("[nat] query wrap", len(notify), "notify")
w := helper.SelectWriter()
_ = json.NewEncoder(w).Encode(&notify)
_, err = l.WriteAndPut(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false)
_, err = l.WritePacket(head.NewPacketPartial(
head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport,
pbuf.BufferItemToBytes((*orbyte.Item[bytes.Buffer])(w).Trans()),
), false)
if err != nil {
logrus.Errorln("[nat] notify peer", l, "err:", err)
return
}
helper.PutWriter(w)
}
}
@@ -152,7 +157,10 @@ func (l *Link) sendquery(tick time.Duration, peers ...string) {
t := time.NewTicker(tick)
for range t.C {
logrus.Infoln("[nat] query send query to", l.peerip)
_, err = l.WriteAndPut(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false)
_, err = l.WritePacket(head.NewPacketPartial(
head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport,
pbuf.ParseBytes(data...),
), false)
if err != nil {
logrus.Errorln("[nat] query write err:", err)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/gold/p2p"
curve "github.com/fumiama/go-x25519"
"github.com/fumiama/orbyte"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/chacha20poly1305"
)
@@ -49,7 +50,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) {
}
if !cfg.NoPipe {
l.pipe = make(chan *head.Packet, 32)
l.pipe = make(chan *orbyte.Item[head.Packet], 65536)
}
var k, p []byte
if cfg.PubicKey != nil {

View File

@@ -9,129 +9,123 @@ import (
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/gold/head"
base14 "github.com/fumiama/go-base16384"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
"github.com/sirupsen/logrus"
)
// Read 从 peer 收包
func (l *Link) Read() *head.Packet {
func (l *Link) Read() *orbyte.Item[head.Packet] {
return <-l.pipe
}
func (m *Me) wait(data []byte) *head.Packet {
if len(data) < head.PacketHeadLen { // not a valid packet
func (m *Me) wait(data pbuf.Bytes) *orbyte.Item[head.Packet] {
if data.Len() < head.PacketHeadLen { // not a valid packet
if config.ShowDebugLog {
logrus.Debugln("[recv] invalid data len", len(data))
logrus.Debugln("[recv] invalid data len", data.Len())
}
return nil
}
bound := 64
endl := "..."
if len(data) < bound {
bound = len(data)
if data.Len() < bound {
bound = data.Len()
endl = "."
}
if config.ShowDebugLog {
logrus.Debugln("[recv] data bytes, len", len(data), "val", hex.EncodeToString(data[:bound]), endl)
logrus.Debugln("[recv] data bytes, len", data.Len(), "val", hex.EncodeToString(data.Bytes()[:bound]), endl)
}
if m.base14 {
data = base14.Decode(data)
if len(data) < bound {
bound = len(data)
data = pbuf.ParseBytes(base14.Decode(data.Bytes())...)
if data.Len() < bound {
bound = data.Len()
endl = "."
}
if config.ShowDebugLog {
logrus.Debugln("[recv] data b14ed, len", len(data), "val", hex.EncodeToString(data[:bound]), endl)
logrus.Debugln("[recv] data b14ed, len", data.Len(), "val", hex.EncodeToString(data.Bytes()[:bound]), endl)
}
if len(data) < head.PacketHeadLen { // not a valid packet
if data.Len() < head.PacketHeadLen { // not a valid packet
if config.ShowDebugLog {
logrus.Debugln("[recv] invalid data len", len(data))
logrus.Debugln("[recv] invalid data len", data.Len())
}
return nil
}
}
seq, data := m.xordec(data)
if len(data) < bound {
bound = len(data)
seq, dat := m.xordec(data.Trans().Bytes())
if len(dat) < bound {
bound = len(dat)
endl = "."
}
if config.ShowDebugLog {
logrus.Debugln("[recv] data xored, len", len(data), "val", hex.EncodeToString(data[:bound]), endl)
logrus.Debugln("[recv] data xored, len", len(dat), "val", hex.EncodeToString(dat[:bound]), endl)
}
if len(data) < head.PacketHeadLen { // not a valid packet
header, err := head.ParsePacketHeader(dat)
if err != nil { // not a valid packet
if config.ShowDebugLog {
logrus.Debugln("[recv] invalid data len", len(data))
logrus.Debugln("[recv] invalid packet header:", err)
}
return nil
}
flags := head.Flags(data)
if !flags.IsValid() {
if !header.Pointer().Flags.IsValid() {
if config.ShowDebugLog {
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
logrus.Debugln("[recv] drop invalid flags packet:", header.Pointer().Flags)
}
return nil
}
crc := head.CRC64(data)
crc := header.Pointer().CRC64()
crclog := crc
crc ^= (uint64(seq) << 16)
if config.ShowDebugLog {
logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc)
}
if m.recved.Get(crc) {
if _, got := m.recved.GetOrSet(crc, struct{}{}); got {
if config.ShowDebugLog {
logrus.Debugln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
}
return nil
}
m.recved.Set(crc, true)
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
logrus.Debugln(
"[recv]", strconv.FormatUint(crc, 16),
len(dat), "bytes data with flag", header.Pointer().Flags,
"offset", header.Pointer().Flags.Offset(),
)
}
if flags.IsSingle() || flags.NoFrag() {
h := head.SelectPacket()
_, err := h.Unmarshal(data)
if err != nil {
logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unmarshal err:", err)
if header.Pointer().Flags.IsSingle() || header.Pointer().Flags.NoFrag() {
ok := header.Pointer().ParseData(dat)
if !ok {
logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unexpected !ok")
return nil
}
return h
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), len(dat), "bytes full data waited")
}
return header
}
crchash := crc64.New(crc64.MakeTable(crc64.ISO))
_, _ = crchash.Write(head.Hash(data))
_, _ = crchash.Write(head.Hash(data.Bytes()))
var buf [4]byte
binary.LittleEndian.PutUint32(buf[:], seq)
_, _ = crchash.Write(buf[:])
hsh := crchash.Sum64()
h := m.recving.Get(hsh)
if h != nil {
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "get another frag part of", strconv.FormatUint(hsh, 16))
}
ok, err := h.Unmarshal(data)
if err == nil {
if ok {
m.recving.Delete(hsh)
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "all parts of", strconv.FormatUint(hsh, 16), "has reached")
}
return h
}
} else {
h.Put()
logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unmarshal err:", err)
}
return nil
h, got := m.recving.GetOrSet(hsh, header)
if got && h == header {
panic("unexpected multi-put found")
}
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "get new frag part of", strconv.FormatUint(hsh, 16))
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "get frag part of", strconv.FormatUint(hsh, 16), "isnew:", !got)
}
h = head.SelectPacket()
_, err := h.Unmarshal(data)
if err != nil {
h.Put()
logrus.Errorln("[recv]", strconv.FormatUint(crc, 16), "unmarshal err:", err)
ok := h.Pointer().ParseData(dat)
if !ok {
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "wait other frag parts of", strconv.FormatUint(hsh, 16), "isnew:", !got)
}
return nil
}
m.recving.Set(hsh, h)
return nil
m.recving.Delete(hsh)
if config.ShowDebugLog {
logrus.Debugln("[recv]", strconv.FormatUint(crc, 16), "all parts of", strconv.FormatUint(hsh, 16), "has reached")
}
return h
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"math/rand"
"runtime"
"github.com/klauspost/compress/zstd"
"github.com/sirupsen/logrus"
@@ -17,6 +18,8 @@ import (
"github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/helper"
base14 "github.com/fumiama/go-base16384"
"github.com/fumiama/orbyte"
"github.com/fumiama/orbyte/pbuf"
)
var (
@@ -24,15 +27,19 @@ var (
ErrTTL = errors.New("ttl exceeded")
)
// WriteAndPut 向 peer 发包并将包放回缓存池
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
defer p.Put()
teatype := l.randkeyidx()
sndcnt := uint16(l.incgetsndcnt())
func randseq(i uint16) uint32 {
var buf [4]byte
_, _ = crand.Read(buf[:2])
binary.BigEndian.PutUint16(buf[2:4], sndcnt)
seq := binary.BigEndian.Uint32(buf[:])
binary.BigEndian.PutUint16(buf[2:4], i)
return binary.BigEndian.Uint32(buf[:])
}
// WritePacket 向 peer 发包
func (l *Link) WritePacket(p *orbyte.Item[head.Packet], istransfer bool) (n int, err error) {
pp := p.Pointer()
teatype := l.randkeyidx()
sndcnt := uint16(l.incgetsndcnt())
seq := randseq(sndcnt)
mtu := l.mtu
if l.mturandomrange > 0 {
mtu -= uint16(rand.Intn(int(l.mturandomrange)))
@@ -41,48 +48,48 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
logrus.Debugln("[send] mtu:", mtu, ", addt:", sndcnt&0x07ff, ", key index:", teatype)
}
if !istransfer {
l.encrypt(p, sndcnt, teatype)
l.encrypt(pp, sndcnt, teatype)
}
delta := (int(mtu) - head.PacketHeadLen) & 0x0000fff8
if delta <= 0 {
logrus.Warnln("[send] reset invalid data frag len", delta, "to 8")
delta = 8
}
remlen := p.BodyLen()
remlen := pp.BodyLen()
if remlen <= delta {
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false, seq)
}
if istransfer && p.Flags.DontFrag() && remlen > delta {
if istransfer && pp.Flags.DontFrag() && remlen > delta {
return 0, ErrDropBigDontFragPkt
}
ttl := p.TTL
ttl := pp.TTL
totl := uint32(remlen)
pos := 0
packet := p.Copy()
packet := head.ParsePacket(pp.ShallowCopy())
for remlen > delta {
remlen -= delta
if config.ShowDebugLog {
logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", remlen)
}
packet.CropBody(pos, pos+delta)
packet.Pointer().CropBody(pos, pos+delta)
cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true, seq)
n += cnt
if err != nil {
return n, err
}
packet.TTL = ttl
packet.Pointer().TTL = ttl
pos += delta
}
packet.Put()
if remlen > 0 {
if config.ShowDebugLog {
logrus.Debugln("[send] last frag [", pos, "~", pos+remlen, "]")
}
p.CropBody(pos, pos+remlen)
pp.CropBody(pos, pos+remlen)
cnt := 0
cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false, seq)
n += cnt
}
runtime.KeepAlive(p)
return n, err
}
@@ -94,24 +101,27 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) {
data := p.Body()
if l.usezstd {
w := helper.SelectWriter()
defer helper.PutWriter(w)
enc, _ := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedFastest))
_, _ = io.Copy(enc, bytes.NewReader(data))
enc.Close()
data = w.Bytes()
data = w.TransBytes()
if config.ShowDebugLog {
logrus.Debugln("[send] data len after zstd:", len(data))
}
}
p.SetBody(l.Encode(teatype, sndcnt&0x07ff, data), true)
p.SetBody(l.Encode(teatype, sndcnt&0x07ff, data).Trans().Bytes())
if config.ShowDebugLog {
logrus.Debugln("[send] data len after xchacha20:", p.BodyLen(), "addt:", sndcnt)
}
}
// write 向 peer 发包
func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) {
if p.DecreaseAndGetTTL() <= 0 {
func (l *Link) write(
p *orbyte.Item[head.Packet], teatype uint8, additional uint16,
datasz uint32, offset uint16, istransfer,
hasmore bool, seq uint32,
) (int, error) {
if p.Pointer().DecreaseAndGetTTL() <= 0 {
return 0, ErrTTL
}
if l.doublepacket {
@@ -121,26 +131,28 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
}
// write 向 peer 发一个包
func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) {
func (l *Link) writeonce(
p *orbyte.Item[head.Packet], teatype uint8, additional uint16,
datasz uint32, offset uint16,
istransfer, hasmore bool, seq uint32,
) (int, error) {
peerep := l.endpoint
if helper.IsNilInterface(peerep) {
return 0, errors.New("nil endpoint of " + p.Dst.String())
return 0, errors.New("nil endpoint of " + p.Pointer().Dst.String())
}
var d []byte
var cl func()
var d pbuf.Bytes
// TODO: now all packet allow frag, adapt to DF
if istransfer {
d, cl = p.Marshal(nil, 0, 0, 0, offset, false, hasmore)
d = p.Pointer().MarshalWith(nil, 0, 0, 0, offset, false, hasmore)
} else {
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
d = p.Pointer().MarshalWith(l.me.me, teatype, additional, datasz, offset, false, hasmore)
}
defer cl()
bound := 64
endl := "..."
if len(d) < bound {
bound = len(d)
if d.Len() < bound {
bound = d.Len()
endl = "."
}
conn := l.me.conn
@@ -148,17 +160,15 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas
return 0, io.ErrClosedPipe
}
if config.ShowDebugLog {
logrus.Debugln("[send] write", len(d), "bytes data from ep", conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.CRC64()))
logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl)
logrus.Debugln("[send] write", d.Len(), "bytes data from ep", conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.Pointer().CRC64()))
logrus.Debugln("[send] data bytes", hex.EncodeToString(d.Bytes()[:bound]), endl)
}
d = l.me.xorenc(d, seq)
d = l.me.xorenc(d.Bytes(), seq)
if l.me.base14 {
d = base14.Encode(d)
d = pbuf.ParseBytes(base14.Encode(d.Bytes())...)
}
if config.ShowDebugLog {
logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl)
logrus.Debugln("[send] data xored", hex.EncodeToString(d.Bytes()[:bound]), endl)
}
defer helper.PutBytes(d)
return conn.WriteToPeer(d, peerep)
return conn.WriteToPeer(d.Trans().Bytes(), peerep)
}