mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-29 23:30:37 +08:00
fix(listen): allow nil body packets
This commit is contained in:
@@ -4,10 +4,15 @@ import (
|
|||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"math/bits"
|
"math/bits"
|
||||||
mrand "math/rand"
|
mrand "math/rand"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrCipherTextTooShort = errors.New("ciphertext too short")
|
||||||
|
)
|
||||||
|
|
||||||
func (l *Link) randkeyidx() uint8 {
|
func (l *Link) randkeyidx() uint8 {
|
||||||
if l.keys[1] == nil {
|
if l.keys[1] == nil {
|
||||||
return 0
|
return 0
|
||||||
@@ -63,7 +68,7 @@ func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Decode 使用 xchacha20poly1305 和密钥序列解密
|
// Decode 使用 xchacha20poly1305 和密钥序列解密
|
||||||
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) {
|
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, err error) {
|
||||||
if b == nil || teatype >= 32 {
|
if b == nil || teatype >= 32 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -75,8 +80,7 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) {
|
|||||||
if aead == nil {
|
if aead == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db = decode(aead, additional, b)
|
return decode(aead, additional, b)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// encode 使用 xchacha20poly1305 加密
|
// encode 使用 xchacha20poly1305 加密
|
||||||
@@ -96,18 +100,17 @@ func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// decode 使用 xchacha20poly1305 解密
|
// decode 使用 xchacha20poly1305 解密
|
||||||
func decode(aead cipher.AEAD, additional uint16, b []byte) (db []byte) {
|
func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) {
|
||||||
nsz := aead.NonceSize()
|
nsz := aead.NonceSize()
|
||||||
if len(b) < nsz { // ciphertext too short
|
if len(b) < nsz {
|
||||||
return
|
return nil, ErrCipherTextTooShort
|
||||||
}
|
}
|
||||||
// Split nonce and ciphertext.
|
// Split nonce and ciphertext.
|
||||||
nonce, ciphertext := b[:nsz], b[nsz:]
|
nonce, ciphertext := b[:nsz], b[nsz:]
|
||||||
// Decrypt the message and check it wasn't tampered with.
|
// Decrypt the message and check it wasn't tampered with.
|
||||||
var buf [2]byte
|
var buf [2]byte
|
||||||
binary.LittleEndian.PutUint16(buf[:], additional)
|
binary.LittleEndian.PutUint16(buf[:], additional)
|
||||||
db, _ = aead.Open(nil, nonce, ciphertext, buf[:])
|
return aead.Open(nil, nonce, ciphertext, buf[:])
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data
|
// xorenc 按 8 字节, 以初始 m.mask 循环异或编码 data
|
||||||
|
|||||||
@@ -45,7 +45,11 @@ func TestXChacha20(t *testing.T) {
|
|||||||
}
|
}
|
||||||
data := []byte("12345678")
|
data := []byte("12345678")
|
||||||
for i := uint64(0); i < 100000; i++ {
|
for i := uint64(0); i < 100000; i++ {
|
||||||
if !bytes.Equal(decode(aead, uint16(i), encode(aead, uint16(i), data)), data) {
|
db, err := decode(aead, uint16(i), encode(aead, uint16(i), data))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(db, data) {
|
||||||
t.Fatal("unexpected preshared at", i, "addt", uint16(i))
|
t.Fatal("unexpected preshared at", i, "addt", uint16(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
|
|||||||
logrus.Errorln("[listen] reconnect udp err:", err)
|
logrus.Errorln("[listen] reconnect udp err:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
logrus.Debugln("[listen] unlock index", i)
|
||||||
hasntfinished[i].Unlock()
|
hasntfinished[i].Unlock()
|
||||||
i--
|
i--
|
||||||
continue
|
continue
|
||||||
@@ -75,6 +76,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
|
|||||||
}
|
}
|
||||||
packet := m.wait(lbf[:n])
|
packet := m.wait(lbf[:n])
|
||||||
if packet == nil {
|
if packet == nil {
|
||||||
|
logrus.Debugln("[listen] unlock index", i)
|
||||||
hasntfinished[i].Unlock()
|
hasntfinished[i].Unlock()
|
||||||
i--
|
i--
|
||||||
continue
|
continue
|
||||||
@@ -87,6 +89,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
|
|||||||
|
|
||||||
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
|
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
|
||||||
defer finish()
|
defer finish()
|
||||||
|
defer logrus.Debugln("[listen] unlock index", index)
|
||||||
sz := packet.TeaTypeDataSZ & 0x0000ffff
|
sz := packet.TeaTypeDataSZ & 0x0000ffff
|
||||||
r := int(sz) - len(packet.Data)
|
r := int(sz) - len(packet.Data)
|
||||||
if r > 0 {
|
if r > 0 {
|
||||||
@@ -108,9 +111,10 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
|
|||||||
switch {
|
switch {
|
||||||
case p.IsToMe(packet.Dst):
|
case p.IsToMe(packet.Dst):
|
||||||
addt := packet.AdditionalData()
|
addt := packet.AdditionalData()
|
||||||
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
|
var err error
|
||||||
if packet.Data == nil {
|
packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
|
||||||
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt)
|
if err != nil {
|
||||||
|
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
|
||||||
packet.Put()
|
packet.Put()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user