From 9336ab61e84b172484ab67957e90599adba20ecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 12 Jul 2024 22:32:11 +0900 Subject: [PATCH] fix(listen): allow nil body packets --- gold/link/crypto.go | 19 +++++++++++-------- gold/link/crypto_test.go | 6 +++++- gold/link/listen.go | 10 +++++++--- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 303ff40..2dcbe5e 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -4,10 +4,15 @@ import ( "crypto/cipher" "crypto/rand" "encoding/binary" + "errors" "math/bits" mrand "math/rand" ) +var ( + ErrCipherTextTooShort = errors.New("ciphertext too short") +) + func (l *Link) randkeyidx() uint8 { if l.keys[1] == nil { return 0 @@ -63,7 +68,7 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) { } // Decode 使用 xchacha20poly1305 和密钥序列解密 -func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) { +func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) { if b == nil || teatype >= 32 { return } @@ -75,8 +80,7 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) { if aead == nil { return } - db = decode(aead, additional, b) - return + return decode(aead, additional, b) } // encode 使用 xchacha20poly1305 加密 @@ -96,18 +100,17 @@ func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) { } // decode 使用 xchacha20poly1305 解密 -func decode(aead cipher.AEAD, additional uint16, b []byte) (db []byte) { +func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { nsz := aead.NonceSize() - if len(b) < nsz { // ciphertext too short - return + if len(b) < nsz { + return nil, ErrCipherTextTooShort } // Split nonce and ciphertext. nonce, ciphertext := b[:nsz], b[nsz:] // Decrypt the message and check it wasn't tampered with. var buf [2]byte binary.LittleEndian.PutUint16(buf[:], additional) - db, _ = aead.Open(nil, nonce, ciphertext, buf[:]) - return + return aead.Open(nil, nonce, ciphertext, buf[:]) } // xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data diff --git a/gold/link/crypto_test.go b/gold/link/crypto_test.go index 0cb46a5..c3b45a8 100644 --- a/gold/link/crypto_test.go +++ b/gold/link/crypto_test.go @@ -45,7 +45,11 @@ func TestXChacha20(t *testing.T) { } data := []byte("12345678") for i := uint64(0); i < 100000; i++ { - if !bytes.Equal(decode(aead, uint16(i), encode(aead, uint16(i), data)), data) { + db, err := decode(aead, uint16(i), encode(aead, uint16(i), data)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(db, data) { t.Fatal("unexpected preshared at", i, "addt", uint16(i)) } } diff --git a/gold/link/listen.go b/gold/link/listen.go index 42f3c5a..e8933eb 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -61,6 +61,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) { logrus.Errorln("[listen] reconnect udp err:", err) return } + logrus.Debugln("[listen] unlock index", i) hasntfinished[i].Unlock() i-- continue @@ -75,6 +76,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) { } packet := m.wait(lbf[:n]) if packet == nil { + logrus.Debugln("[listen] unlock index", i) hasntfinished[i].Unlock() i-- continue @@ -87,6 +89,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) { func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) { defer finish() + defer logrus.Debugln("[listen] unlock index", index) sz := packet.TeaTypeDataSZ & 0x0000ffff r := int(sz) - len(packet.Data) if r > 0 { @@ -108,9 +111,10 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin switch { case p.IsToMe(packet.Dst): addt := packet.AdditionalData() - packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data) - if packet.Data == nil { - logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt) + var err error + packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data) + if err != nil { + logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err) packet.Put() return }