1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-04 23:40:26 +08:00

fix(listen): allow nil body packets

This commit is contained in:
源文雨
2024-07-12 22:32:11 +09:00
parent da6ffcc283
commit 9336ab61e8
3 changed files with 23 additions and 12 deletions

View File

@@ -4,10 +4,15 @@ import (
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"math/bits"
mrand "math/rand"
)
var (
ErrCipherTextTooShort = errors.New("ciphertext too short")
)
func (l *Link) randkeyidx() uint8 {
if l.keys[1] == nil {
return 0
@@ -63,7 +68,7 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) {
}
// Decode 使用 xchacha20poly1305 和密钥序列解密
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) {
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) {
if b == nil || teatype >= 32 {
return
}
@@ -75,8 +80,7 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) {
if aead == nil {
return
}
db = decode(aead, additional, b)
return
return decode(aead, additional, b)
}
// encode 使用 xchacha20poly1305 加密
@@ -96,18 +100,17 @@ func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) {
}
// decode 使用 xchacha20poly1305 解密
func decode(aead cipher.AEAD, additional uint16, b []byte) (db []byte) {
func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) {
nsz := aead.NonceSize()
if len(b) < nsz { // ciphertext too short
return
if len(b) < nsz {
return nil, ErrCipherTextTooShort
}
// Split nonce and ciphertext.
nonce, ciphertext := b[:nsz], b[nsz:]
// Decrypt the message and check it wasn't tampered with.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
db, _ = aead.Open(nil, nonce, ciphertext, buf[:])
return
return aead.Open(nil, nonce, ciphertext, buf[:])
}
// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data

View File

@@ -45,7 +45,11 @@ func TestXChacha20(t *testing.T) {
}
data := []byte("12345678")
for i := uint64(0); i < 100000; i++ {
if !bytes.Equal(decode(aead, uint16(i), encode(aead, uint16(i), data)), data) {
db, err := decode(aead, uint16(i), encode(aead, uint16(i), data))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(db, data) {
t.Fatal("unexpected preshared at", i, "addt", uint16(i))
}
}

View File

@@ -61,6 +61,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
logrus.Errorln("[listen] reconnect udp err:", err)
return
}
logrus.Debugln("[listen] unlock index", i)
hasntfinished[i].Unlock()
i--
continue
@@ -75,6 +76,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
}
packet := m.wait(lbf[:n])
if packet == nil {
logrus.Debugln("[listen] unlock index", i)
hasntfinished[i].Unlock()
i--
continue
@@ -87,6 +89,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
defer finish()
defer logrus.Debugln("[listen] unlock index", index)
sz := packet.TeaTypeDataSZ & 0x0000ffff
r := int(sz) - len(packet.Data)
if r > 0 {
@@ -108,9 +111,10 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
switch {
case p.IsToMe(packet.Dst):
addt := packet.AdditionalData()
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
if packet.Data == nil {
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt)
var err error
packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
if err != nil {
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
packet.Put()
return
}