mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-04 23:40:26 +08:00
optimize(link): pack zstd en/decoding
This commit is contained in:
2
go.mod
2
go.mod
@@ -8,7 +8,7 @@ require (
|
||||
github.com/fumiama/blake2b-simd v0.0.0-20220412110131-4481822068bb
|
||||
github.com/fumiama/go-base16384 v1.7.0
|
||||
github.com/fumiama/go-x25519 v1.0.0
|
||||
github.com/fumiama/orbyte v0.0.0-20250225122817-8c60967c655e
|
||||
github.com/fumiama/orbyte v0.0.0-20250225143058-717b07486e38
|
||||
github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac
|
||||
github.com/klauspost/compress v1.17.9
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
|
||||
4
go.sum
4
go.sum
@@ -11,8 +11,8 @@ github.com/fumiama/go-base16384 v1.7.0 h1:6fep7XPQWxRlh4Hu+KsdH+6+YdUp+w6CwRXtMW
|
||||
github.com/fumiama/go-base16384 v1.7.0/go.mod h1:OEn+947GV5gsbTAnyuUW/SrfxJYUdYupSIQXOuGOcXM=
|
||||
github.com/fumiama/go-x25519 v1.0.0 h1:hiGg9EhseVmGCc8T1jECVkj8Keu/aJ1ZK05RM8Vuavo=
|
||||
github.com/fumiama/go-x25519 v1.0.0/go.mod h1:8VOhfyGZzw4IUs4nCjQFqW9cA3V/QpSCtP3fo2dLNg4=
|
||||
github.com/fumiama/orbyte v0.0.0-20250225122817-8c60967c655e h1:fSshWA2ixEA97OQZZ1gj7xTb1lKxeWjFpOqvtjk/NHw=
|
||||
github.com/fumiama/orbyte v0.0.0-20250225122817-8c60967c655e/go.mod h1:qkUllQ1+gTx5sGrmKvIsqUgsnOO21Hiq847YHJRifbk=
|
||||
github.com/fumiama/orbyte v0.0.0-20250225143058-717b07486e38 h1:BZ4Hl4hKwdhbf3IzXLLJvAm8qoIm5+yudCLV245tN/0=
|
||||
github.com/fumiama/orbyte v0.0.0-20250225143058-717b07486e38/go.mod h1:qkUllQ1+gTx5sGrmKvIsqUgsnOO21Hiq847YHJRifbk=
|
||||
github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac h1:A/5A0rODsg+EQHH61Ew5mMUtDpRXaSNqHhPvW+fN4C4=
|
||||
github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac/go.mod h1:BBnNY9PwK+UUn4trAU+H0qsMEypm7+3Bj1bVFuJItlo=
|
||||
github.com/fumiama/wintun v0.0.0-20211229152851-8bc97c8034c0 h1:WfrSFlIlCAtg6Rt2IGna0HhJYSDE45YVHiYqO4wwsEw=
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"math/bits"
|
||||
mrand "math/rand"
|
||||
"runtime"
|
||||
|
||||
"github.com/fumiama/orbyte/pbuf"
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -53,13 +52,13 @@ func expandkeyunit(v1, v2 byte) (v uint16) {
|
||||
return
|
||||
}
|
||||
|
||||
// Encode by aead and put b into pool
|
||||
func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb pbuf.Bytes) {
|
||||
// encode by aead and put b into pool
|
||||
func (l *Link) encode(teatype uint8, additional uint16, b []byte) (eb pbuf.Bytes) {
|
||||
if len(b) == 0 || teatype >= 32 {
|
||||
return
|
||||
}
|
||||
if l.keys[0] == nil {
|
||||
return pbuf.ParseBytes(b...)
|
||||
return pbuf.ParseBytes(b...).Copy()
|
||||
}
|
||||
aead := l.keys[teatype]
|
||||
if aead == nil {
|
||||
@@ -70,13 +69,13 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb pbuf.Bytes
|
||||
return
|
||||
}
|
||||
|
||||
// Decode by aead and put b into pool
|
||||
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db pbuf.Bytes, err error) {
|
||||
// decode by aead and put b into pool
|
||||
func (l *Link) decode(teatype uint8, additional uint16, b []byte) (db pbuf.Bytes, err error) {
|
||||
if len(b) == 0 || teatype >= 32 {
|
||||
return
|
||||
}
|
||||
if l.keys[0] == nil {
|
||||
return pbuf.ParseBytes(b...), nil
|
||||
return pbuf.ParseBytes(b...).Copy(), nil
|
||||
}
|
||||
aead := l.keys[teatype]
|
||||
if aead == nil {
|
||||
@@ -142,7 +141,6 @@ func (m *Me) xorenc(data []byte, seq uint32) pbuf.Bytes {
|
||||
}
|
||||
p := batchsz * 8
|
||||
copy(newdat.Bytes()[8+p:], data[p:])
|
||||
runtime.KeepAlive(data)
|
||||
newdat.Bytes()[newdat.Len()-1] = byte(remain)
|
||||
sum ^= binary.LittleEndian.Uint64(newdat.Bytes()[8+p:])
|
||||
binary.LittleEndian.PutUint64(newdat.Bytes()[8+p:], sum)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
@@ -32,7 +31,6 @@ func TestXOR(t *testing.T) {
|
||||
if !bytes.Equal(dec, r2.Bytes()) {
|
||||
t.Fatal("unexpected xor at", i, "except", hex.EncodeToString(r2.Bytes()), "got", hex.EncodeToString(dec))
|
||||
}
|
||||
runtime.KeepAlive(dec)
|
||||
if seq != uint32(i) {
|
||||
t.Fatal("unexpected xor at", i, "seq", seq)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -12,7 +10,6 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/fumiama/WireGold/config"
|
||||
@@ -179,7 +176,7 @@ func (m *Me) dispatch(packet *orbyte.Item[head.Packet], addr p2p.EndPoint, index
|
||||
}
|
||||
addt := pp.AdditionalData()
|
||||
var err error
|
||||
data, err := p.Decode(pp.CipherIndex(), addt, pp.Body())
|
||||
data, err := p.decode(pp.CipherIndex(), addt, pp.Body())
|
||||
if err != nil {
|
||||
if config.ShowDebugLog {
|
||||
logrus.Debugln("[listen] @", index, "drop invalid packet key idx:", pp.CipherIndex(), "addt:", addt, "err:", err)
|
||||
@@ -188,21 +185,14 @@ func (m *Me) dispatch(packet *orbyte.Item[head.Packet], addr p2p.EndPoint, index
|
||||
}
|
||||
pp.SetBody(data.Trans().Bytes())
|
||||
if p.usezstd {
|
||||
dec, _ := zstd.NewReader(bytes.NewReader(pp.Body()))
|
||||
var err error
|
||||
w := helper.SelectWriter()
|
||||
_, err = io.Copy(w, dec)
|
||||
dec.Close()
|
||||
dat, err := decodezstd(pp.Body())
|
||||
if err != nil {
|
||||
if config.ShowDebugLog {
|
||||
logrus.Debugln("[listen] @", index, "drop invalid zstd packet:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if config.ShowDebugLog {
|
||||
logrus.Debugln("[listen] @", index, "zstd decoded len:", w.Len())
|
||||
}
|
||||
pp.SetBody(w.TransBytes().Bytes())
|
||||
pp.SetBody(dat.Trans().Bytes())
|
||||
}
|
||||
if !pp.IsVaildHash() {
|
||||
if config.ShowDebugLog {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
@@ -11,7 +10,6 @@ import (
|
||||
"math/rand"
|
||||
"runtime"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/fumiama/WireGold/config"
|
||||
@@ -100,16 +98,12 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) {
|
||||
}
|
||||
data := p.Body()
|
||||
if l.usezstd {
|
||||
w := helper.SelectWriter()
|
||||
enc, _ := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
_, _ = io.Copy(enc, bytes.NewReader(data))
|
||||
enc.Close()
|
||||
data = w.TransBytes().Bytes()
|
||||
data = encodezstd(data).Trans().Bytes()
|
||||
if config.ShowDebugLog {
|
||||
logrus.Debugln("[send] data len after zstd:", len(data))
|
||||
}
|
||||
}
|
||||
p.SetBody(l.Encode(teatype, sndcnt&0x07ff, data).Trans().Bytes())
|
||||
p.SetBody(l.encode(teatype, sndcnt&0x07ff, data).Trans().Bytes())
|
||||
if config.ShowDebugLog {
|
||||
logrus.Debugln("[send] data len after xchacha20:", p.BodyLen(), "addt:", sndcnt)
|
||||
}
|
||||
|
||||
41
gold/link/zstd.go
Normal file
41
gold/link/zstd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/fumiama/orbyte/pbuf"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
func encodezstd(data []byte) pbuf.Bytes {
|
||||
w := helper.SelectWriter()
|
||||
enc, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = io.Copy(enc, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = enc.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return w.TransBytes()
|
||||
}
|
||||
|
||||
func decodezstd(data []byte) (pbuf.Bytes, error) {
|
||||
dec, err := zstd.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return pbuf.Bytes{}, err
|
||||
}
|
||||
w := helper.SelectWriter()
|
||||
_, err = io.Copy(w, dec)
|
||||
dec.Close()
|
||||
if err != nil {
|
||||
return pbuf.Bytes{}, err
|
||||
}
|
||||
return w.TransBytes(), nil
|
||||
}
|
||||
@@ -360,6 +360,9 @@ func benchmarkTunnel(b *testing.B, sz int, nw string, isplain, isbase14 bool, ps
|
||||
}
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
time.Sleep(time.Second) // wait packets all received
|
||||
|
||||
tunnme.Stop()
|
||||
tunnpeer.Stop()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user