mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-21 02:40:24 +08:00
增加抗重放攻击
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"hash/crc64"
|
"hash/crc64"
|
||||||
"net"
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/helper"
|
"github.com/fumiama/WireGold/helper"
|
||||||
blake2b "github.com/fumiama/blake2b-simd"
|
blake2b "github.com/fumiama/blake2b-simd"
|
||||||
@@ -15,7 +16,8 @@ import (
|
|||||||
// Packet 是发送和接收的最小单位
|
// Packet 是发送和接收的最小单位
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
// TeaTypeDataSZ len(Data)
|
// TeaTypeDataSZ len(Data)
|
||||||
// 高 8 位指定加密所用 tea key
|
// 高 4 位指定加密所用 tea key
|
||||||
|
// 高 4-16 位是随机值
|
||||||
// 不得超过 65507-head 字节
|
// 不得超过 65507-head 字节
|
||||||
TeaTypeDataSZ uint32
|
TeaTypeDataSZ uint32
|
||||||
// Proto 详见 head
|
// Proto 详见 head
|
||||||
@@ -36,10 +38,12 @@ type Packet struct {
|
|||||||
// 生成时 Hash 全 0
|
// 生成时 Hash 全 0
|
||||||
// https://github.com/fumiama/blake2b-simd
|
// https://github.com/fumiama/blake2b-simd
|
||||||
Hash [32]byte
|
Hash [32]byte
|
||||||
|
// CRC64 包头字段的 checksum 值,可以认为在一定时间内唯一
|
||||||
|
CRC64 uint64
|
||||||
// Data 承载的数据
|
// Data 承载的数据
|
||||||
Data []byte
|
Data []byte
|
||||||
// 记录还有多少字节未到达
|
// 记录还有多少字节未到达
|
||||||
rembytes uint32
|
rembytes int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPacket 生成一个新包
|
// NewPacket 生成一个新包
|
||||||
@@ -66,16 +70,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sz := p.TeaTypeDataSZ & 0x00ffffff
|
sz := p.TeaTypeDataSZ & 0x0000ffff
|
||||||
if sz == 0 && len(p.Data) == 0 {
|
if sz == 0 && len(p.Data) == 0 {
|
||||||
p.TeaTypeDataSZ = binary.LittleEndian.Uint32(data[:4])
|
p.TeaTypeDataSZ = binary.LittleEndian.Uint32(data[:4])
|
||||||
sz = p.TeaTypeDataSZ & 0x00ffffff
|
sz = p.TeaTypeDataSZ & 0x0000ffff
|
||||||
if int(sz)+52 == len(data) {
|
if int(sz)+52 == len(data) {
|
||||||
p.Data = data[52:]
|
p.Data = data[52:]
|
||||||
p.rembytes = 0
|
p.rembytes = 0
|
||||||
} else {
|
} else {
|
||||||
p.Data = make([]byte, sz)
|
p.Data = make([]byte, sz)
|
||||||
p.rembytes = sz
|
p.rembytes = int(sz)
|
||||||
}
|
}
|
||||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||||
p.Proto = uint8(pt)
|
p.Proto = uint8(pt)
|
||||||
@@ -93,10 +97,11 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
|||||||
p.Dst = make(net.IP, 4)
|
p.Dst = make(net.IP, 4)
|
||||||
copy(p.Dst, data[16:20])
|
copy(p.Dst, data[16:20])
|
||||||
copy(p.Hash[:], data[20:52])
|
copy(p.Hash[:], data[20:52])
|
||||||
|
p.CRC64 = binary.LittleEndian.Uint64(data[52:60])
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.rembytes > 0 {
|
if p.rembytes > 0 {
|
||||||
p.rembytes -= uint32(copy(p.Data[flags<<3:], data[60:]))
|
p.rembytes -= copy(p.Data[flags<<3:], data[60:])
|
||||||
}
|
}
|
||||||
|
|
||||||
complete = p.rembytes == 0
|
complete = p.rembytes == 0
|
||||||
@@ -104,6 +109,8 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var counter uint32
|
||||||
|
|
||||||
// Marshal 将自身数据编码为 []byte
|
// Marshal 将自身数据编码为 []byte
|
||||||
// offset 必须为 8 的倍数,表示偏移的 8 位
|
// offset 必须为 8 的倍数,表示偏移的 8 位
|
||||||
func (p *Packet) Marshal(src net.IP, teatype uint8, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) {
|
func (p *Packet) Marshal(src net.IP, teatype uint8, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) {
|
||||||
@@ -113,7 +120,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, datasz uint32, offset uint16
|
|||||||
}
|
}
|
||||||
|
|
||||||
if src != nil {
|
if src != nil {
|
||||||
p.TeaTypeDataSZ = uint32(teatype)<<24 | datasz
|
p.TeaTypeDataSZ = uint32(teatype)<<28 | (atomic.AddUint32(&counter, 1)<<16)&0x0fff0000 | datasz
|
||||||
p.Src = src
|
p.Src = src
|
||||||
offset &= 0x1fff
|
offset &= 0x1fff
|
||||||
if dontfrag {
|
if dontfrag {
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) {
|
|||||||
if packet == nil {
|
if packet == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
sz := packet.TeaTypeDataSZ & 0x00ffffff
|
sz := packet.TeaTypeDataSZ & 0x0000ffff
|
||||||
r := int(sz) - len(packet.Data)
|
r := int(sz) - len(packet.Data)
|
||||||
if r > 0 {
|
if r > 0 {
|
||||||
logrus.Warnln("[link] packet from endpoint", addr, "is smaller than it declared: drop it")
|
logrus.Warnln("[link] packet from endpoint", addr, "is smaller than it declared: drop it")
|
||||||
@@ -61,7 +61,7 @@ func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) {
|
|||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
case p.IsToMe(packet.Dst):
|
case p.IsToMe(packet.Dst):
|
||||||
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>24), packet.Data)
|
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data)
|
||||||
if !packet.IsVaildHash() {
|
if !packet.IsVaildHash() {
|
||||||
logrus.Debugln("[link] drop invalid packet")
|
logrus.Debugln("[link] drop invalid packet")
|
||||||
packet.Put()
|
packet.Put()
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ type Me struct {
|
|||||||
writer *helper.Writer
|
writer *helper.Writer
|
||||||
// 本机未接收完全分片池
|
// 本机未接收完全分片池
|
||||||
recving *ttl.Cache[[32]byte, *head.Packet]
|
recving *ttl.Cache[[32]byte, *head.Packet]
|
||||||
|
// 抗重放攻击记录池
|
||||||
|
recved *ttl.Cache[uint64, uint8]
|
||||||
// 本机上层配置
|
// 本机上层配置
|
||||||
srcport, dstport, mtu uint16
|
srcport, dstport, mtu uint16
|
||||||
}
|
}
|
||||||
@@ -96,7 +98,8 @@ func NewMe(cfg *MyConfig) (m Me) {
|
|||||||
if m.writer == nil {
|
if m.writer == nil {
|
||||||
m.writer = helper.SelectWriter()
|
m.writer = helper.SelectWriter()
|
||||||
}
|
}
|
||||||
m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 128)
|
m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 30)
|
||||||
|
m.recved = ttl.NewCache[uint64, uint8](time.Second * 30)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ func (m *Me) wait(data []byte) *head.Packet {
|
|||||||
if flags&0x8000 == 0x8000 { // not a valid packet
|
if flags&0x8000 == 0x8000 { // not a valid packet
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
crc := binary.LittleEndian.Uint64(data[52:60])
|
||||||
|
if m.recved.Get(crc) != 0 { // 是重放攻击
|
||||||
|
return nil
|
||||||
|
}
|
||||||
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12]))
|
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12]))
|
||||||
if flags == 0 || flags == 0x4000 {
|
if flags == 0 || flags == 0x4000 {
|
||||||
h := head.SelectPacket()
|
h := head.SelectPacket()
|
||||||
@@ -30,6 +34,7 @@ func (m *Me) wait(data []byte) *head.Packet {
|
|||||||
logrus.Errorln("[recv] unmarshal err:", err)
|
logrus.Errorln("[recv] unmarshal err:", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
m.recved.Set(crc, 1)
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,6 +47,7 @@ func (m *Me) wait(data []byte) *head.Packet {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
if ok {
|
if ok {
|
||||||
m.recving.Delete(hsh)
|
m.recving.Delete(hsh)
|
||||||
|
m.recved.Set(crc, 1)
|
||||||
logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "is reached")
|
logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "is reached")
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user