1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-23 03:50:32 +08:00

add 分片

This commit is contained in:
fumiama
2021-12-31 12:34:54 +08:00
parent 7a30bfb1e6
commit d108bb81b4
11 changed files with 336 additions and 147 deletions

View File

@@ -12,9 +12,11 @@ import (
// Packet 是发送和接收的最小单位
type Packet struct {
// Ver 协议版本
Ver uint16
// DataSZ len(Data)
// 不得超过 65507-head 字节
DataSZ uint32
DataSZ uint16
// Proto 详见 head
Proto uint8
// TTL is time to live
@@ -23,9 +25,11 @@ type Packet struct {
SrcPort uint16
// DstPort 目的端口
DstPort uint16
// Src 源 ip
// Flags 高3位为标志(xDM)低13位为分片偏移
Flags uint16
// Src 源 ip (ipv4)
Src net.IP
// Dst 目的 ip
// Dst 目的 ip (ipv4)
Dst net.IP
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要
// 生成时 Hash 全 0
@@ -33,12 +37,15 @@ type Packet struct {
Hash [32]byte
// Data 承载的数据
Data []byte
// 记录还有多少字节未到达
rembytes uint16
}
// NewPacket 生成一个新包
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet {
logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
return &Packet{
Ver: 1,
Proto: proto,
TTL: 16,
SrcPort: srcPort,
@@ -49,53 +56,69 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
}
// Unmarshal 将 data 的数据解码到自身
func (p *Packet) Unmarshal(data []byte) error {
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
if len(data) < 12 {
return errors.New("data len < 12")
err = errors.New("data len < 12")
return
}
p.DataSZ = binary.LittleEndian.Uint32(data[:4])
pt := binary.LittleEndian.Uint16(data[4:6])
p.Proto = uint8(pt)
p.TTL = uint8(pt >> 8)
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
sdl := binary.LittleEndian.Uint16(data[10:12])
srclen := uint8(sdl)
dstlen := uint8(sdl >> 8)
if len(data) < int(12+srclen+dstlen) {
return errors.New("data src or dst len mismatch")
if p.DataSZ == 0 && len(p.Data) == 0 {
p.Ver = binary.LittleEndian.Uint16(data[:2])
if p.Ver != 1 {
err = errors.New("unknown protocol version")
return
}
p.DataSZ = binary.LittleEndian.Uint16(data[2:4])
p.Data = make([]byte, p.DataSZ)
pt := binary.LittleEndian.Uint16(data[4:6])
p.Proto = uint8(pt)
p.TTL = uint8(pt >> 8)
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
p.rembytes = p.DataSZ
}
if srclen > 0 {
p.Src = make(net.IP, srclen)
copy(p.Src, data[12:12+srclen])
}
if dstlen > 0 {
p.Dst = make(net.IP, dstlen)
copy(p.Dst, data[12+srclen:12+srclen+dstlen])
}
copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32])
p.Data = data[12+srclen+dstlen+32:]
return nil
p.Flags = binary.LittleEndian.Uint16(data[10:12])
p.Src = make(net.IP, 4)
copy(p.Src, data[12:16])
p.Dst = make(net.IP, 4)
copy(p.Dst, data[16:20])
copy(p.Hash[:], data[20:52])
p.rembytes -= uint16(copy(p.Data[p.Flags<<3:], data[52:]))
complete = p.rembytes == 0
return
}
// Marshal 将自身数据编码为 []byte
func (p *Packet) Marshal(src net.IP) []byte {
// offset 必须为 8 的倍数,表示偏移的 8 位
func (p *Packet) Marshal(src net.IP, offset uint16, dontfrag, hasmore bool) []byte {
p.TTL--
if p.TTL == 0 {
return nil
}
p.DataSZ = uint32(len(p.Data))
p.DataSZ = uint16(len(p.Data))
if src != nil {
p.Src = src
offset >>= 3
if dontfrag {
offset |= 0x4000
}
if hasmore {
offset |= 0x2000
}
p.Flags = offset
}
packet := make([]byte, 52+len(p.Data))
binary.LittleEndian.PutUint32(packet[:4], p.DataSZ)
binary.LittleEndian.PutUint16(packet[:2], p.Ver)
binary.LittleEndian.PutUint16(packet[2:4], p.DataSZ)
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
binary.LittleEndian.PutUint16(packet[10:12], 0x0404)
binary.LittleEndian.PutUint16(packet[10:12], p.Flags)
copy(packet[12:16], p.Src.To4())
copy(packet[16:20], p.Dst.To4())
copy(packet[20:52], p.Hash[:])

View File

@@ -75,20 +75,20 @@ func (l *Link) Read() *head.Packet {
// Write 向 peer 发包
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
if len(p.Data) <= (32768 - 64) {
return l.write(p, istransfer)
if len(p.Data) <= int(l.me.mtu) {
return l.write(p, 0, istransfer, false)
}
data := p.Data
offset := 0
for len(data) > (32768 - 64) {
for len(data) > int(l.me.mtu) {
packet := *p
packet.Data = data[offset*(32768-64) : (offset+1)*(32768-64)]
i, err := l.write(&packet, istransfer)
packet.Data = data[offset*int(l.me.mtu) : (offset+1)*int(l.me.mtu)]
i, err := l.write(&packet, uint16(offset), istransfer, true)
n += i
if err != nil {
return n, err
}
data = data[(offset+1)*(32768-64):]
data = data[(offset+1)*int(l.me.mtu):]
}
return n, nil
}
@@ -107,14 +107,17 @@ func (l *Link) String() (n string) {
}
// write 向 peer 发一个包
func (l *Link) write(p *head.Packet, istransfer bool) (n int, err error) {
func (l *Link) write(p *head.Packet, offset uint16, istransfer, hasmore bool) (n int, err error) {
var d []byte
if istransfer {
d = p.Marshal(nil)
if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) {
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
}
d = p.Marshal(nil, 0, false, false)
} else {
p.FillHash()
p.Data = l.Encode(p.Data)
d = p.Marshal(l.me.me)
d = p.Marshal(l.me.me, offset, false, hasmore)
}
if d == nil {
return 0, errors.New("[link] ttl exceeded")

View File

@@ -20,9 +20,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
n, addr, err := conn.ReadFromUDP(lbf)
if err == nil {
lbf = lbf[:n]
packet := head.Packet{}
err = packet.Unmarshal(lbf)
if err == nil {
packet := m.wait(lbf)
if packet != nil {
r := int(packet.DataSZ) - len(packet.Data)
if r > 0 {
remain, err := readAll(conn, r)
@@ -60,16 +59,16 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
}
case head.ProtoNotify:
logrus.Infoln("[link] recv notify")
p.onNotify(&packet)
p.onNotify(packet)
case head.ProtoQuery:
logrus.Infoln("[link] recv query")
p.onQuery(&packet)
p.onQuery(packet)
case head.ProtoData:
if p.pipe != nil {
p.pipe <- &packet
p.pipe <- packet
logrus.Infoln("[link] deliver to pipe of", p.peerip)
} else {
m.pipe <- &packet
m.pipe <- packet
logrus.Infoln("[link] deliver to pipe of me")
}
default:
@@ -81,7 +80,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
} else if p.Accept(packet.Dst) {
if p.allowtrans {
// 转发
n, err = p.Write(&packet, true)
n, err = p.Write(packet, true)
if err == nil {
logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
} else {

View File

@@ -32,10 +32,17 @@ type Me struct {
pipe chan *head.Packet
// 本机路由表
router *Router
// 本机未接收完全分片池
recving map[[32]byte]*head.Packet
recvmu sync.Mutex
// 超时定时器
clock map[*head.Packet]uint8
// 本机上层配置
srcport, dstport, mtu uint16
}
// NewMe 设置本机参数
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool) (m Me) {
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool, srcport, dstport, mtu uint16) (m Me) {
m.privKey = *privateKey
var err error
m.myend, err = net.ResolveUDPAddr("udp", myEndpoint)
@@ -62,5 +69,20 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipei
}
m.router.SetDefault(nil)
m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink)
m.srcport = srcport
m.dstport = dstport
m.mtu = mtu
return
}
func (m *Me) SrcPort() uint16 {
return m.srcport
}
func (m *Me) DstPort() uint16 {
return m.dstport
}
func (m *Me) MTU() uint16 {
return m.mtu
}

77
gold/link/recv.go Normal file
View File

@@ -0,0 +1,77 @@
package link
import (
"encoding/binary"
"time"
"unsafe"
"github.com/fumiama/WireGold/gold/head"
"github.com/sirupsen/logrus"
)
func (m *Me) initrecvpool() {
if m.recving == nil {
m.recving = make(map[[32]byte]*head.Packet, 128)
}
// 超时定时器
m.clock = make(map[*head.Packet]uint8, 128)
var delhs []*head.Packet
t := time.NewTicker(time.Second)
for range t.C {
m.recvmu.Lock()
for k, v := range m.clock {
if v > 10 { // 10s
delete(m.recving, k.Hash)
delhs = append(delhs, k)
} else {
m.clock[k]++
}
}
for _, k := range delhs {
delete(m.clock, k)
logrus.Warnln("[recv] drop timeout packet from", k.Src)
}
delhs = delhs[:0]
m.recvmu.Unlock()
}
}
func (m *Me) wait(data []byte) *head.Packet {
flags := binary.LittleEndian.Uint16(data[10:12])
if flags == 0 || flags == 0x4000 {
h := &head.Packet{}
_, err := h.Unmarshal(data)
if err != nil {
logrus.Errorln("[recv] unmarshal err:", err)
return nil
}
return h
}
m.recvmu.Lock()
defer m.recvmu.Unlock()
hashd := data[20:52]
hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd)))
h, ok := m.recving[hsh]
if ok {
ok, err := h.Unmarshal(data)
if err == nil {
if ok {
return h
}
m.clock[h] = 0
} else {
logrus.Errorln("[recv] unmarshal err:", err)
}
return nil
}
h = &head.Packet{}
_, err := h.Unmarshal(data)
if err != nil {
logrus.Errorln("[recv] unmarshal err:", err)
return nil
}
m.recving[hsh] = h
m.clock[h] = 0
return nil
}