From 8163c38884f6aa2bc7cf6c4809e7d42516a85edb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 4 Aug 2023 16:07:35 +0800 Subject: [PATCH] fix: xor --- gold/link/crypto.go | 49 ++++++++++++++++++++++++++++++++++------ gold/link/crypto_test.go | 16 +++++++++---- gold/link/me.go | 4 ++++ gold/link/recv.go | 4 ++-- gold/link/send.go | 6 ++--- 5 files changed, 62 insertions(+), 17 deletions(-) diff --git a/gold/link/crypto.go b/gold/link/crypto.go index d31005d..9cb6843 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -66,23 +66,58 @@ func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) { return } -// xor 按 8 字节, 以初始 m.mask 循环异或 data -func (m *Me) xor(data []byte) []byte { +// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data +func (m *Me) xorenc(data []byte) []byte { batchsz := len(data) / 8 remain := len(data) % 8 sum := m.mask - for i := 0; i < batchsz; i++ { + if remain > 0 { + var buf [8]byte + p := batchsz * 8 + copy(buf[:], data[p:]) + sum ^= binary.LittleEndian.Uint64(buf[:]) + binary.LittleEndian.PutUint64(buf[:], sum) + copy(data[p:], buf[:]) + } + for i := batchsz - 1; i >= 0; i-- { a := i * 8 b := (i + 1) * 8 sum ^= binary.LittleEndian.Uint64(data[a:b]) binary.LittleEndian.PutUint64(data[a:b], sum) } + return data +} + +// xordec 按 8 字节, 以初始 m.mask 循环异或解码 data +func (m *Me) xordec(data []byte) []byte { + batchsz := len(data) / 8 + remain := len(data) % 8 + this := uint64(0) + next := uint64(0) + if len(data) >= 8 { + next = binary.LittleEndian.Uint64(data[:8]) + } + for i := 0; i < batchsz-1; i++ { + a := i * 8 + b := (i + 1) * 8 + this = next + next = binary.LittleEndian.Uint64(data[a+8 : b+8]) + binary.LittleEndian.PutUint64(data[a:b], this^next) + } if remain > 0 { var buf [8]byte - copy(buf[:], data[remain:]) - sum ^= binary.LittleEndian.Uint64(buf[:]) - binary.LittleEndian.PutUint64(buf[:], sum) - copy(data[remain:], buf[:]) + a := (batchsz - 1) * 8 + b := batchsz * 8 + copy(buf[:], data[b:]) + this = next + next = binary.LittleEndian.Uint64(buf[:]) | (m.mask & (uint64(0xffffffff_ffffffff) << (uint64(remain) * 8))) + if batchsz > 0 { + binary.LittleEndian.PutUint64(data[a:b], this^next) + } + binary.LittleEndian.PutUint64(buf[:], next^m.mask) + copy(data[b:], buf[:]) + } else { + binary.LittleEndian.PutUint64(data[len(data)-8:], next^m.mask) } return data } diff --git a/gold/link/crypto_test.go b/gold/link/crypto_test.go index 2fbc4b3..8b3e58a 100644 --- a/gold/link/crypto_test.go +++ b/gold/link/crypto_test.go @@ -3,6 +3,7 @@ package link import ( "bytes" "crypto/rand" + "io" "testing" ) @@ -10,15 +11,20 @@ func TestXOR(t *testing.T) { m := Me{ mask: 0x12345678_90abcdef, } - buf := make([]byte, 65535) - for i := 1; i < 65536; i++ { + buf := make([]byte, 4096) + buf2 := make([]byte, 4096) + for i := 1; i < 4096; i++ { data := buf[:i] - _, err := rand.Read(data) + orgdata := buf2[:i] + r1 := bytes.NewBuffer(data[:0]) + r2 := bytes.NewBuffer(orgdata[:0]) + w := io.MultiWriter(r1, r2) + _, err := io.CopyN(w, rand.Reader, int64(i)) if err != nil { t.Fatal(err) } - if !bytes.Equal(m.xor(m.xor(data)), data) { - t.Fatal("unexpected xor at ", i) + if !bytes.Equal(m.xordec(m.xorenc(r1.Bytes())), r2.Bytes()) { + t.Fatal("unexpected xor at", i) } } } diff --git a/gold/link/me.go b/gold/link/me.go index 5fe9d09..b18bfb2 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -2,6 +2,7 @@ package link import ( "encoding/binary" + "encoding/hex" "io" "net" "strconv" @@ -99,6 +100,9 @@ func NewMe(cfg *MyConfig) (m Me) { m.dstport = cfg.DstPort m.mtu = cfg.MTU & 0xfff8 m.mask = cfg.Mask + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], m.mask) + logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:])) if m.writer == nil { m.writer = helper.SelectWriter() } diff --git a/gold/link/recv.go b/gold/link/recv.go index c889720..b8e28a7 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -18,14 +18,14 @@ func (m *Me) wait(data []byte) *head.Packet { if len(data) < 60 { // not a valid packet return nil } - bound := 256 + bound := 64 endl := "..." if len(data) < bound { bound = len(data) endl = "." } logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl) - data = m.xor(data) + data = m.xordec(data) logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl) flags := binary.LittleEndian.Uint16(data[10:12]) if flags&0x8000 == 0x8000 { // not a valid packet diff --git a/gold/link/send.go b/gold/link/send.go index 07e192d..ae91c74 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -99,7 +99,7 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional, mtu uint16, data if peerep == nil { return 0, errors.New("[send] nil endpoint of " + p.Dst.String()) } - bound := 256 + bound := 64 endl := "..." if len(d) < bound { bound = len(d) @@ -107,9 +107,9 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional, mtu uint16, data } logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.myep.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset)) logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) - d = l.me.xor(d) - n, err = l.me.myep.WriteToUDP(d, peerep) + d = l.me.xorenc(d) logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) + n, err = l.me.myep.WriteToUDP(d, peerep) cl() } return