mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-23 03:50:32 +08:00
add 分片
This commit is contained in:
@@ -12,9 +12,11 @@ import (
|
||||
|
||||
// Packet 是发送和接收的最小单位
|
||||
type Packet struct {
|
||||
// Ver 协议版本
|
||||
Ver uint16
|
||||
// DataSZ len(Data)
|
||||
// 不得超过 65507-head 字节
|
||||
DataSZ uint32
|
||||
DataSZ uint16
|
||||
// Proto 详见 head
|
||||
Proto uint8
|
||||
// TTL is time to live
|
||||
@@ -23,9 +25,11 @@ type Packet struct {
|
||||
SrcPort uint16
|
||||
// DstPort 目的端口
|
||||
DstPort uint16
|
||||
// Src 源 ip
|
||||
// Flags 高3位为标志(xDM),低13位为分片偏移
|
||||
Flags uint16
|
||||
// Src 源 ip (ipv4)
|
||||
Src net.IP
|
||||
// Dst 目的 ip
|
||||
// Dst 目的 ip (ipv4)
|
||||
Dst net.IP
|
||||
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要
|
||||
// 生成时 Hash 全 0
|
||||
@@ -33,12 +37,15 @@ type Packet struct {
|
||||
Hash [32]byte
|
||||
// Data 承载的数据
|
||||
Data []byte
|
||||
// 记录还有多少字节未到达
|
||||
rembytes uint16
|
||||
}
|
||||
|
||||
// NewPacket 生成一个新包
|
||||
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet {
|
||||
logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
|
||||
return &Packet{
|
||||
Ver: 1,
|
||||
Proto: proto,
|
||||
TTL: 16,
|
||||
SrcPort: srcPort,
|
||||
@@ -49,53 +56,69 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
|
||||
}
|
||||
|
||||
// Unmarshal 将 data 的数据解码到自身
|
||||
func (p *Packet) Unmarshal(data []byte) error {
|
||||
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
if len(data) < 12 {
|
||||
return errors.New("data len < 12")
|
||||
err = errors.New("data len < 12")
|
||||
return
|
||||
}
|
||||
p.DataSZ = binary.LittleEndian.Uint32(data[:4])
|
||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||
p.Proto = uint8(pt)
|
||||
p.TTL = uint8(pt >> 8)
|
||||
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||
sdl := binary.LittleEndian.Uint16(data[10:12])
|
||||
srclen := uint8(sdl)
|
||||
dstlen := uint8(sdl >> 8)
|
||||
if len(data) < int(12+srclen+dstlen) {
|
||||
return errors.New("data src or dst len mismatch")
|
||||
if p.DataSZ == 0 && len(p.Data) == 0 {
|
||||
p.Ver = binary.LittleEndian.Uint16(data[:2])
|
||||
if p.Ver != 1 {
|
||||
err = errors.New("unknown protocol version")
|
||||
return
|
||||
}
|
||||
p.DataSZ = binary.LittleEndian.Uint16(data[2:4])
|
||||
p.Data = make([]byte, p.DataSZ)
|
||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||
p.Proto = uint8(pt)
|
||||
p.TTL = uint8(pt >> 8)
|
||||
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||
p.rembytes = p.DataSZ
|
||||
}
|
||||
if srclen > 0 {
|
||||
p.Src = make(net.IP, srclen)
|
||||
copy(p.Src, data[12:12+srclen])
|
||||
}
|
||||
if dstlen > 0 {
|
||||
p.Dst = make(net.IP, dstlen)
|
||||
copy(p.Dst, data[12+srclen:12+srclen+dstlen])
|
||||
}
|
||||
copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32])
|
||||
p.Data = data[12+srclen+dstlen+32:]
|
||||
return nil
|
||||
|
||||
p.Flags = binary.LittleEndian.Uint16(data[10:12])
|
||||
|
||||
p.Src = make(net.IP, 4)
|
||||
copy(p.Src, data[12:16])
|
||||
p.Dst = make(net.IP, 4)
|
||||
copy(p.Dst, data[16:20])
|
||||
copy(p.Hash[:], data[20:52])
|
||||
p.rembytes -= uint16(copy(p.Data[p.Flags<<3:], data[52:]))
|
||||
|
||||
complete = p.rembytes == 0
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal 将自身数据编码为 []byte
|
||||
func (p *Packet) Marshal(src net.IP) []byte {
|
||||
// offset 必须为 8 的倍数,表示偏移的 8 位
|
||||
func (p *Packet) Marshal(src net.IP, offset uint16, dontfrag, hasmore bool) []byte {
|
||||
p.TTL--
|
||||
if p.TTL == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.DataSZ = uint32(len(p.Data))
|
||||
p.DataSZ = uint16(len(p.Data))
|
||||
if src != nil {
|
||||
p.Src = src
|
||||
offset >>= 3
|
||||
if dontfrag {
|
||||
offset |= 0x4000
|
||||
}
|
||||
if hasmore {
|
||||
offset |= 0x2000
|
||||
}
|
||||
p.Flags = offset
|
||||
}
|
||||
|
||||
packet := make([]byte, 52+len(p.Data))
|
||||
binary.LittleEndian.PutUint32(packet[:4], p.DataSZ)
|
||||
binary.LittleEndian.PutUint16(packet[:2], p.Ver)
|
||||
binary.LittleEndian.PutUint16(packet[2:4], p.DataSZ)
|
||||
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
|
||||
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
|
||||
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
|
||||
binary.LittleEndian.PutUint16(packet[10:12], 0x0404)
|
||||
binary.LittleEndian.PutUint16(packet[10:12], p.Flags)
|
||||
copy(packet[12:16], p.Src.To4())
|
||||
copy(packet[16:20], p.Dst.To4())
|
||||
copy(packet[20:52], p.Hash[:])
|
||||
|
||||
@@ -75,20 +75,20 @@ func (l *Link) Read() *head.Packet {
|
||||
|
||||
// Write 向 peer 发包
|
||||
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
if len(p.Data) <= (32768 - 64) {
|
||||
return l.write(p, istransfer)
|
||||
if len(p.Data) <= int(l.me.mtu) {
|
||||
return l.write(p, 0, istransfer, false)
|
||||
}
|
||||
data := p.Data
|
||||
offset := 0
|
||||
for len(data) > (32768 - 64) {
|
||||
for len(data) > int(l.me.mtu) {
|
||||
packet := *p
|
||||
packet.Data = data[offset*(32768-64) : (offset+1)*(32768-64)]
|
||||
i, err := l.write(&packet, istransfer)
|
||||
packet.Data = data[offset*int(l.me.mtu) : (offset+1)*int(l.me.mtu)]
|
||||
i, err := l.write(&packet, uint16(offset), istransfer, true)
|
||||
n += i
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
data = data[(offset+1)*(32768-64):]
|
||||
data = data[(offset+1)*int(l.me.mtu):]
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
@@ -107,14 +107,17 @@ func (l *Link) String() (n string) {
|
||||
}
|
||||
|
||||
// write 向 peer 发一个包
|
||||
func (l *Link) write(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
func (l *Link) write(p *head.Packet, offset uint16, istransfer, hasmore bool) (n int, err error) {
|
||||
var d []byte
|
||||
if istransfer {
|
||||
d = p.Marshal(nil)
|
||||
if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) {
|
||||
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
|
||||
}
|
||||
d = p.Marshal(nil, 0, false, false)
|
||||
} else {
|
||||
p.FillHash()
|
||||
p.Data = l.Encode(p.Data)
|
||||
d = p.Marshal(l.me.me)
|
||||
d = p.Marshal(l.me.me, offset, false, hasmore)
|
||||
}
|
||||
if d == nil {
|
||||
return 0, errors.New("[link] ttl exceeded")
|
||||
|
||||
@@ -20,9 +20,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
n, addr, err := conn.ReadFromUDP(lbf)
|
||||
if err == nil {
|
||||
lbf = lbf[:n]
|
||||
packet := head.Packet{}
|
||||
err = packet.Unmarshal(lbf)
|
||||
if err == nil {
|
||||
packet := m.wait(lbf)
|
||||
if packet != nil {
|
||||
r := int(packet.DataSZ) - len(packet.Data)
|
||||
if r > 0 {
|
||||
remain, err := readAll(conn, r)
|
||||
@@ -60,16 +59,16 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
}
|
||||
case head.ProtoNotify:
|
||||
logrus.Infoln("[link] recv notify")
|
||||
p.onNotify(&packet)
|
||||
p.onNotify(packet)
|
||||
case head.ProtoQuery:
|
||||
logrus.Infoln("[link] recv query")
|
||||
p.onQuery(&packet)
|
||||
p.onQuery(packet)
|
||||
case head.ProtoData:
|
||||
if p.pipe != nil {
|
||||
p.pipe <- &packet
|
||||
p.pipe <- packet
|
||||
logrus.Infoln("[link] deliver to pipe of", p.peerip)
|
||||
} else {
|
||||
m.pipe <- &packet
|
||||
m.pipe <- packet
|
||||
logrus.Infoln("[link] deliver to pipe of me")
|
||||
}
|
||||
default:
|
||||
@@ -81,7 +80,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
} else if p.Accept(packet.Dst) {
|
||||
if p.allowtrans {
|
||||
// 转发
|
||||
n, err = p.Write(&packet, true)
|
||||
n, err = p.Write(packet, true)
|
||||
if err == nil {
|
||||
logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
|
||||
} else {
|
||||
|
||||
@@ -32,10 +32,17 @@ type Me struct {
|
||||
pipe chan *head.Packet
|
||||
// 本机路由表
|
||||
router *Router
|
||||
// 本机未接收完全分片池
|
||||
recving map[[32]byte]*head.Packet
|
||||
recvmu sync.Mutex
|
||||
// 超时定时器
|
||||
clock map[*head.Packet]uint8
|
||||
// 本机上层配置
|
||||
srcport, dstport, mtu uint16
|
||||
}
|
||||
|
||||
// NewMe 设置本机参数
|
||||
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool) (m Me) {
|
||||
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool, srcport, dstport, mtu uint16) (m Me) {
|
||||
m.privKey = *privateKey
|
||||
var err error
|
||||
m.myend, err = net.ResolveUDPAddr("udp", myEndpoint)
|
||||
@@ -62,5 +69,20 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipei
|
||||
}
|
||||
m.router.SetDefault(nil)
|
||||
m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink)
|
||||
m.srcport = srcport
|
||||
m.dstport = dstport
|
||||
m.mtu = mtu
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Me) SrcPort() uint16 {
|
||||
return m.srcport
|
||||
}
|
||||
|
||||
func (m *Me) DstPort() uint16 {
|
||||
return m.dstport
|
||||
}
|
||||
|
||||
func (m *Me) MTU() uint16 {
|
||||
return m.mtu
|
||||
}
|
||||
|
||||
77
gold/link/recv.go
Normal file
77
gold/link/recv.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/head"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (m *Me) initrecvpool() {
|
||||
if m.recving == nil {
|
||||
m.recving = make(map[[32]byte]*head.Packet, 128)
|
||||
}
|
||||
// 超时定时器
|
||||
m.clock = make(map[*head.Packet]uint8, 128)
|
||||
var delhs []*head.Packet
|
||||
t := time.NewTicker(time.Second)
|
||||
for range t.C {
|
||||
m.recvmu.Lock()
|
||||
for k, v := range m.clock {
|
||||
if v > 10 { // 10s
|
||||
delete(m.recving, k.Hash)
|
||||
delhs = append(delhs, k)
|
||||
} else {
|
||||
m.clock[k]++
|
||||
}
|
||||
}
|
||||
for _, k := range delhs {
|
||||
delete(m.clock, k)
|
||||
logrus.Warnln("[recv] drop timeout packet from", k.Src)
|
||||
}
|
||||
delhs = delhs[:0]
|
||||
m.recvmu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Me) wait(data []byte) *head.Packet {
|
||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
||||
if flags == 0 || flags == 0x4000 {
|
||||
h := &head.Packet{}
|
||||
_, err := h.Unmarshal(data)
|
||||
if err != nil {
|
||||
logrus.Errorln("[recv] unmarshal err:", err)
|
||||
return nil
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
m.recvmu.Lock()
|
||||
defer m.recvmu.Unlock()
|
||||
hashd := data[20:52]
|
||||
hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd)))
|
||||
h, ok := m.recving[hsh]
|
||||
if ok {
|
||||
ok, err := h.Unmarshal(data)
|
||||
if err == nil {
|
||||
if ok {
|
||||
return h
|
||||
}
|
||||
m.clock[h] = 0
|
||||
} else {
|
||||
logrus.Errorln("[recv] unmarshal err:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h = &head.Packet{}
|
||||
_, err := h.Unmarshal(data)
|
||||
if err != nil {
|
||||
logrus.Errorln("[recv] unmarshal err:", err)
|
||||
return nil
|
||||
}
|
||||
m.recving[hsh] = h
|
||||
m.clock[h] = 0
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user