From ecff2220747447f1f165cf1130b0b4198ae8ccc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:54:20 +0800 Subject: [PATCH] fix: wrong drop of same crc packet --- gold/link/crypto.go | 13 ++++++++----- gold/link/crypto_test.go | 5 +++-- gold/link/me.go | 4 ++-- gold/link/recv.go | 24 +++++++++++++++--------- gold/link/send.go | 21 +++++++++++++-------- 5 files changed, 41 insertions(+), 26 deletions(-) diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 1912240..5228aec 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -121,12 +121,13 @@ func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { } // xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data -func (m *Me) xorenc(data []byte) []byte { +func (m *Me) xorenc(data []byte, seq uint32) []byte { batchsz := len(data) / 8 remain := len(data) % 8 sum := m.mask newdat := helper.MakeBytes(len(data) + 8) - _, _ = rand.Read(newdat[:8]) + binary.LittleEndian.PutUint32(newdat[:4], seq) + _, _ = rand.Read(newdat[4:8]) if remain > 0 { var buf [8]byte p := batchsz * 8 @@ -141,13 +142,15 @@ func (m *Me) xorenc(data []byte) []byte { sum ^= binary.LittleEndian.Uint64(data[a:b]) binary.LittleEndian.PutUint64(newdat[a+8:b+8], sum) } + sum ^= binary.LittleEndian.Uint64(newdat[:8]) + binary.LittleEndian.PutUint64(newdat[:8], sum) return newdat } // xordec 按 8 字节, 以初始 m.mask 循环异或解码 data -func (m *Me) xordec(data []byte) []byte { +func (m *Me) xordec(data []byte) (uint32, []byte) { if len(data) <= 8 { - return nil + return 0, nil } batchsz := len(data) / 8 remain := len(data) % 8 @@ -178,5 +181,5 @@ func (m *Me) xordec(data []byte) []byte { } else { binary.LittleEndian.PutUint64(data[len(data)-8:], next^m.mask) } - return data[8:] + return binary.LittleEndian.Uint32(data[:4]), data[8:] } diff --git a/gold/link/crypto_test.go b/gold/link/crypto_test.go index 4287f0d..cbb2b51 100644 --- a/gold/link/crypto_test.go +++ b/gold/link/crypto_test.go @@ -27,8 +27,9 @@ func TestXOR(t *testing.T) { if err != nil { t.Fatal(err) } - if !bytes.Equal(m.xordec(m.xorenc(r1.Bytes())), r2.Bytes()) { - t.Fatal("unexpected xor at", i) + seq, dec := m.xordec(m.xorenc(r1.Bytes(), uint32(i))) + if !bytes.Equal(dec, r2.Bytes()) || seq != uint32(i) { + t.Fatal("unexpected xor at", i, "seq", seq) } } } diff --git a/gold/link/me.go b/gold/link/me.go index 775db23..c14455c 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -42,7 +42,7 @@ type Me struct { // 本机路由表 router *Router // 本机未接收完全分片池 - recving *ttl.Cache[[32]byte, *head.Packet] + recving *ttl.Cache[uint64, *head.Packet] // 抗重放攻击记录池 recved *ttl.Cache[uint64, bool] // 本机上层配置 @@ -106,7 +106,7 @@ func NewMe(cfg *MyConfig) (m Me) { var buf [8]byte binary.BigEndian.PutUint64(buf[:], m.mask) logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:])) - m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 30) + m.recving = ttl.NewCache[uint64, *head.Packet](time.Second * 30) m.recved = ttl.NewCache[uint64, bool](time.Second * 30) return } diff --git a/gold/link/recv.go b/gold/link/recv.go index 52ee00e..0231491 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -3,8 +3,8 @@ package link import ( "encoding/binary" "encoding/hex" + "hash/crc64" "strconv" - "unsafe" "github.com/fumiama/WireGold/gold/head" "github.com/sirupsen/logrus" @@ -26,7 +26,7 @@ func (m *Me) wait(data []byte) *head.Packet { endl = "." } logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl) - data = m.xordec(data) + seq, data := m.xordec(data) logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl) flags := head.Flags(data) if !flags.IsValid() { @@ -34,8 +34,10 @@ func (m *Me) wait(data []byte) *head.Packet { return nil } crc := binary.LittleEndian.Uint64(data[52:head.PacketHeadLen]) - logrus.Debugf("[recv] packet crc %016x", crc) - if m.recved.Get(crc) { // 是重放攻击 + crclog := crc + crc ^= (uint64(seq) << 16) + logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc) + if m.recved.Get(crc) { logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) return nil } @@ -51,17 +53,21 @@ func (m *Me) wait(data []byte) *head.Packet { return h } - hashd := data[20:52] - hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd))) + crchash := crc64.New(crc64.MakeTable(crc64.ISO)) + _, _ = crchash.Write(data[20:52]) + var buf [4]byte + binary.LittleEndian.PutUint32(buf[:], seq) + _, _ = crchash.Write(buf[:]) + hsh := crchash.Sum64() h := m.recving.Get(hsh) if h != nil { - logrus.Debugln("[recv] get another frag part of", hex.EncodeToString(hashd)) + logrus.Debugln("[recv] get another frag part of", strconv.FormatUint(hsh, 16)) ok, err := h.Unmarshal(data) if err == nil { if ok { m.recving.Delete(hsh) m.recved.Set(crc, true) - logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "has reached") + logrus.Debugln("[recv] all parts of", strconv.FormatUint(hsh, 16), "has reached") return h } } else { @@ -70,7 +76,7 @@ func (m *Me) wait(data []byte) *head.Packet { } return nil } - logrus.Debugln("[recv] get new frag part of", hex.EncodeToString(hashd)) + logrus.Debugln("[recv] get new frag part of", strconv.FormatUint(hsh, 16)) h = head.SelectPacket() _, err := h.Unmarshal(data) if err != nil { diff --git a/gold/link/send.go b/gold/link/send.go index 16596e5..f5387f2 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -2,6 +2,8 @@ package link import ( "bytes" + crand "crypto/rand" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -25,6 +27,9 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { defer p.Put() teatype := l.randkeyidx() sndcnt := uint16(l.incgetsndcnt()) + var buf [4]byte + _, _ = crand.Read(buf[:]) + seq := binary.BigEndian.Uint32(buf[:]) mtu := l.mtu if l.mturandomrange > 0 { mtu -= uint16(rand.Intn(int(l.mturandomrange))) @@ -40,7 +45,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { } remlen := p.BodyLen() if remlen <= delta { - return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false) + return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false, seq) } if istransfer && p.Flags.DontFrag() && remlen > delta { return 0, ErrDropBigDontFragPkt @@ -53,7 +58,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { remlen -= delta logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", remlen) packet.CropBody(pos, pos+delta) - cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true) + cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true, seq) n += cnt if err != nil { return n, err @@ -66,7 +71,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { logrus.Debugln("[send] last frag [", pos, "~", pos+remlen, "]") p.CropBody(pos, pos+remlen) cnt := 0 - cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false) + cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false, seq) n += cnt } return n, err @@ -90,7 +95,7 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) { } // write 向 peer 发包 -func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (int, error) { +func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) { if p.DecreaseAndGetTTL() <= 0 { return 0, ErrTTL } @@ -98,14 +103,14 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui cpp := p.Copy() _ = time.AfterFunc(time.Millisecond*(10+time.Duration(rand.Intn(40))), func() { defer cpp.Put() - _, _ = l.writeonce(cpp, teatype, additional, datasz, offset, istransfer, hasmore) + _, _ = l.writeonce(cpp, teatype, additional, datasz, offset, istransfer, hasmore, seq) }) } - return l.writeonce(p, teatype, additional, datasz, offset, istransfer, hasmore) + return l.writeonce(p, teatype, additional, datasz, offset, istransfer, hasmore, seq) } // write 向 peer 发一个包 -func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (int, error) { +func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) { peerep := l.endpoint if peerep == nil { return 0, errors.New("nil endpoint of " + p.Dst.String()) @@ -129,7 +134,7 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas } logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.conn.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset)) logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) - d = l.me.xorenc(d) + d = l.me.xorenc(d, seq) logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) defer helper.PutBytes(d) return l.me.conn.WriteToPeer(d, peerep)