1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-29 23:30:37 +08:00

fix(listen): allow nil body packets

This commit is contained in:
源文雨
2024-07-12 22:32:11 +09:00
parent da6ffcc283
commit 9336ab61e8
3 changed files with 23 additions and 12 deletions

View File

@@ -4,10 +4,15 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"errors"
"math/bits" "math/bits"
mrand "math/rand" mrand "math/rand"
) )
var (
ErrCipherTextTooShort = errors.New("ciphertext too short")
)
func (l *Link) randkeyidx() uint8 { func (l *Link) randkeyidx() uint8 {
if l.keys[1] == nil { if l.keys[1] == nil {
return 0 return 0
@@ -63,7 +68,7 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) {
} }
// Decode 使用 xchacha20poly1305 和密钥序列解密 // Decode 使用 xchacha20poly1305 和密钥序列解密
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) { func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) {
if b == nil || teatype >= 32 { if b == nil || teatype >= 32 {
return return
} }
@@ -75,8 +80,7 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) {
if aead == nil { if aead == nil {
return return
} }
db = decode(aead, additional, b) return decode(aead, additional, b)
return
} }
// encode 使用 xchacha20poly1305 加密 // encode 使用 xchacha20poly1305 加密
@@ -96,18 +100,17 @@ func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) {
} }
// decode 使用 xchacha20poly1305 解密 // decode 使用 xchacha20poly1305 解密
func decode(aead cipher.AEAD, additional uint16, b []byte) (db []byte) { func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) {
nsz := aead.NonceSize() nsz := aead.NonceSize()
if len(b) < nsz { // ciphertext too short if len(b) < nsz {
return return nil, ErrCipherTextTooShort
} }
// Split nonce and ciphertext. // Split nonce and ciphertext.
nonce, ciphertext := b[:nsz], b[nsz:] nonce, ciphertext := b[:nsz], b[nsz:]
// Decrypt the message and check it wasn't tampered with. // Decrypt the message and check it wasn't tampered with.
var buf [2]byte var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional) binary.LittleEndian.PutUint16(buf[:], additional)
db, _ = aead.Open(nil, nonce, ciphertext, buf[:]) return aead.Open(nil, nonce, ciphertext, buf[:])
return
} }
// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data // xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data

View File

@@ -45,7 +45,11 @@ func TestXChacha20(t *testing.T) {
} }
data := []byte("12345678") data := []byte("12345678")
for i := uint64(0); i < 100000; i++ { for i := uint64(0); i < 100000; i++ {
if !bytes.Equal(decode(aead, uint16(i), encode(aead, uint16(i), data)), data) { db, err := decode(aead, uint16(i), encode(aead, uint16(i), data))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(db, data) {
t.Fatal("unexpected preshared at", i, "addt", uint16(i)) t.Fatal("unexpected preshared at", i, "addt", uint16(i))
} }
} }

View File

@@ -61,6 +61,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
logrus.Errorln("[listen] reconnect udp err:", err) logrus.Errorln("[listen] reconnect udp err:", err)
return return
} }
logrus.Debugln("[listen] unlock index", i)
hasntfinished[i].Unlock() hasntfinished[i].Unlock()
i-- i--
continue continue
@@ -75,6 +76,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
} }
packet := m.wait(lbf[:n]) packet := m.wait(lbf[:n])
if packet == nil { if packet == nil {
logrus.Debugln("[listen] unlock index", i)
hasntfinished[i].Unlock() hasntfinished[i].Unlock()
i-- i--
continue continue
@@ -87,6 +89,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) { func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
defer finish() defer finish()
defer logrus.Debugln("[listen] unlock index", index)
sz := packet.TeaTypeDataSZ & 0x0000ffff sz := packet.TeaTypeDataSZ & 0x0000ffff
r := int(sz) - len(packet.Data) r := int(sz) - len(packet.Data)
if r > 0 { if r > 0 {
@@ -108,9 +111,10 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
switch { switch {
case p.IsToMe(packet.Dst): case p.IsToMe(packet.Dst):
addt := packet.AdditionalData() addt := packet.AdditionalData()
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data) var err error
if packet.Data == nil { packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt) if err != nil {
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
packet.Put() packet.Put()
return return
} }