1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-22 19:40:35 +08:00

fix: wrong drop of same crc packet

This commit is contained in:
源文雨
2024-07-31 20:54:20 +08:00
parent a4275beced
commit ecff222074
5 changed files with 41 additions and 26 deletions

View File

@@ -121,12 +121,13 @@ func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) {
} }
// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data // xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data
func (m *Me) xorenc(data []byte) []byte { func (m *Me) xorenc(data []byte, seq uint32) []byte {
batchsz := len(data) / 8 batchsz := len(data) / 8
remain := len(data) % 8 remain := len(data) % 8
sum := m.mask sum := m.mask
newdat := helper.MakeBytes(len(data) + 8) newdat := helper.MakeBytes(len(data) + 8)
_, _ = rand.Read(newdat[:8]) binary.LittleEndian.PutUint32(newdat[:4], seq)
_, _ = rand.Read(newdat[4:8])
if remain > 0 { if remain > 0 {
var buf [8]byte var buf [8]byte
p := batchsz * 8 p := batchsz * 8
@@ -141,13 +142,15 @@ func (m *Me) xorenc(data []byte) []byte {
sum ^= binary.LittleEndian.Uint64(data[a:b]) sum ^= binary.LittleEndian.Uint64(data[a:b])
binary.LittleEndian.PutUint64(newdat[a+8:b+8], sum) binary.LittleEndian.PutUint64(newdat[a+8:b+8], sum)
} }
sum ^= binary.LittleEndian.Uint64(newdat[:8])
binary.LittleEndian.PutUint64(newdat[:8], sum)
return newdat return newdat
} }
// xordec 按 8 字节, 以初始 m.mask 循环异或解码 data // xordec 按 8 字节, 以初始 m.mask 循环异或解码 data
func (m *Me) xordec(data []byte) []byte { func (m *Me) xordec(data []byte) (uint32, []byte) {
if len(data) <= 8 { if len(data) <= 8 {
return nil return 0, nil
} }
batchsz := len(data) / 8 batchsz := len(data) / 8
remain := len(data) % 8 remain := len(data) % 8
@@ -178,5 +181,5 @@ func (m *Me) xordec(data []byte) []byte {
} else { } else {
binary.LittleEndian.PutUint64(data[len(data)-8:], next^m.mask) binary.LittleEndian.PutUint64(data[len(data)-8:], next^m.mask)
} }
return data[8:] return binary.LittleEndian.Uint32(data[:4]), data[8:]
} }

View File

@@ -27,8 +27,9 @@ func TestXOR(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(m.xordec(m.xorenc(r1.Bytes())), r2.Bytes()) { seq, dec := m.xordec(m.xorenc(r1.Bytes(), uint32(i)))
t.Fatal("unexpected xor at", i) if !bytes.Equal(dec, r2.Bytes()) || seq != uint32(i) {
t.Fatal("unexpected xor at", i, "seq", seq)
} }
} }
} }

View File

@@ -42,7 +42,7 @@ type Me struct {
// 本机路由表 // 本机路由表
router *Router router *Router
// 本机未接收完全分片池 // 本机未接收完全分片池
recving *ttl.Cache[[32]byte, *head.Packet] recving *ttl.Cache[uint64, *head.Packet]
// 抗重放攻击记录池 // 抗重放攻击记录池
recved *ttl.Cache[uint64, bool] recved *ttl.Cache[uint64, bool]
// 本机上层配置 // 本机上层配置
@@ -106,7 +106,7 @@ func NewMe(cfg *MyConfig) (m Me) {
var buf [8]byte var buf [8]byte
binary.BigEndian.PutUint64(buf[:], m.mask) binary.BigEndian.PutUint64(buf[:], m.mask)
logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:])) logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:]))
m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 30) m.recving = ttl.NewCache[uint64, *head.Packet](time.Second * 30)
m.recved = ttl.NewCache[uint64, bool](time.Second * 30) m.recved = ttl.NewCache[uint64, bool](time.Second * 30)
return return
} }

View File

@@ -3,8 +3,8 @@ package link
import ( import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"hash/crc64"
"strconv" "strconv"
"unsafe"
"github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/head"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -26,7 +26,7 @@ func (m *Me) wait(data []byte) *head.Packet {
endl = "." endl = "."
} }
logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl) logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl)
data = m.xordec(data) seq, data := m.xordec(data)
logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl) logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl)
flags := head.Flags(data) flags := head.Flags(data)
if !flags.IsValid() { if !flags.IsValid() {
@@ -34,8 +34,10 @@ func (m *Me) wait(data []byte) *head.Packet {
return nil return nil
} }
crc := binary.LittleEndian.Uint64(data[52:head.PacketHeadLen]) crc := binary.LittleEndian.Uint64(data[52:head.PacketHeadLen])
logrus.Debugf("[recv] packet crc %016x", crc) crclog := crc
if m.recved.Get(crc) { // 是重放攻击 crc ^= (uint64(seq) << 16)
logrus.Debugf("[recv] packet crc %016x, seq %08x, xored crc %016x", crclog, seq, crc)
if m.recved.Get(crc) {
logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16)) logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
return nil return nil
} }
@@ -51,17 +53,21 @@ func (m *Me) wait(data []byte) *head.Packet {
return h return h
} }
hashd := data[20:52] crchash := crc64.New(crc64.MakeTable(crc64.ISO))
hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd))) _, _ = crchash.Write(data[20:52])
var buf [4]byte
binary.LittleEndian.PutUint32(buf[:], seq)
_, _ = crchash.Write(buf[:])
hsh := crchash.Sum64()
h := m.recving.Get(hsh) h := m.recving.Get(hsh)
if h != nil { if h != nil {
logrus.Debugln("[recv] get another frag part of", hex.EncodeToString(hashd)) logrus.Debugln("[recv] get another frag part of", strconv.FormatUint(hsh, 16))
ok, err := h.Unmarshal(data) ok, err := h.Unmarshal(data)
if err == nil { if err == nil {
if ok { if ok {
m.recving.Delete(hsh) m.recving.Delete(hsh)
m.recved.Set(crc, true) m.recved.Set(crc, true)
logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "has reached") logrus.Debugln("[recv] all parts of", strconv.FormatUint(hsh, 16), "has reached")
return h return h
} }
} else { } else {
@@ -70,7 +76,7 @@ func (m *Me) wait(data []byte) *head.Packet {
} }
return nil return nil
} }
logrus.Debugln("[recv] get new frag part of", hex.EncodeToString(hashd)) logrus.Debugln("[recv] get new frag part of", strconv.FormatUint(hsh, 16))
h = head.SelectPacket() h = head.SelectPacket()
_, err := h.Unmarshal(data) _, err := h.Unmarshal(data)
if err != nil { if err != nil {

View File

@@ -2,6 +2,8 @@ package link
import ( import (
"bytes" "bytes"
crand "crypto/rand"
"encoding/binary"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
@@ -25,6 +27,9 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
defer p.Put() defer p.Put()
teatype := l.randkeyidx() teatype := l.randkeyidx()
sndcnt := uint16(l.incgetsndcnt()) sndcnt := uint16(l.incgetsndcnt())
var buf [4]byte
_, _ = crand.Read(buf[:])
seq := binary.BigEndian.Uint32(buf[:])
mtu := l.mtu mtu := l.mtu
if l.mturandomrange > 0 { if l.mturandomrange > 0 {
mtu -= uint16(rand.Intn(int(l.mturandomrange))) mtu -= uint16(rand.Intn(int(l.mturandomrange)))
@@ -40,7 +45,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
} }
remlen := p.BodyLen() remlen := p.BodyLen()
if remlen <= delta { if remlen <= delta {
return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false) return l.write(p, teatype, sndcnt, uint32(remlen), 0, istransfer, false, seq)
} }
if istransfer && p.Flags.DontFrag() && remlen > delta { if istransfer && p.Flags.DontFrag() && remlen > delta {
return 0, ErrDropBigDontFragPkt return 0, ErrDropBigDontFragPkt
@@ -53,7 +58,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
remlen -= delta remlen -= delta
logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", remlen) logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", remlen)
packet.CropBody(pos, pos+delta) packet.CropBody(pos, pos+delta)
cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true) cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true, seq)
n += cnt n += cnt
if err != nil { if err != nil {
return n, err return n, err
@@ -66,7 +71,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
logrus.Debugln("[send] last frag [", pos, "~", pos+remlen, "]") logrus.Debugln("[send] last frag [", pos, "~", pos+remlen, "]")
p.CropBody(pos, pos+remlen) p.CropBody(pos, pos+remlen)
cnt := 0 cnt := 0
cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false) cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false, seq)
n += cnt n += cnt
} }
return n, err return n, err
@@ -90,7 +95,7 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) {
} }
// write 向 peer 发包 // write 向 peer 发包
func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (int, error) { func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) {
if p.DecreaseAndGetTTL() <= 0 { if p.DecreaseAndGetTTL() <= 0 {
return 0, ErrTTL return 0, ErrTTL
} }
@@ -98,14 +103,14 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
cpp := p.Copy() cpp := p.Copy()
_ = time.AfterFunc(time.Millisecond*(10+time.Duration(rand.Intn(40))), func() { _ = time.AfterFunc(time.Millisecond*(10+time.Duration(rand.Intn(40))), func() {
defer cpp.Put() defer cpp.Put()
_, _ = l.writeonce(cpp, teatype, additional, datasz, offset, istransfer, hasmore) _, _ = l.writeonce(cpp, teatype, additional, datasz, offset, istransfer, hasmore, seq)
}) })
} }
return l.writeonce(p, teatype, additional, datasz, offset, istransfer, hasmore) return l.writeonce(p, teatype, additional, datasz, offset, istransfer, hasmore, seq)
} }
// write 向 peer 发一个包 // write 向 peer 发一个包
func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (int, error) { func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool, seq uint32) (int, error) {
peerep := l.endpoint peerep := l.endpoint
if peerep == nil { if peerep == nil {
return 0, errors.New("nil endpoint of " + p.Dst.String()) return 0, errors.New("nil endpoint of " + p.Dst.String())
@@ -129,7 +134,7 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas
} }
logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.conn.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset)) logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.conn.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset))
logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl)
d = l.me.xorenc(d) d = l.me.xorenc(d, seq)
logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl)
defer helper.PutBytes(d) defer helper.PutBytes(d)
return l.me.conn.WriteToPeer(d, peerep) return l.me.conn.WriteToPeer(d, peerep)