1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-12 12:50:28 +08:00

parallel listen

This commit is contained in:
源文雨
2022-04-13 11:34:14 +08:00
parent 68b2e8cac5
commit 7d8be16fd3
9 changed files with 165 additions and 144 deletions

View File

@@ -42,16 +42,16 @@ type Packet struct {
} }
// NewPacket 生成一个新包 // NewPacket 生成一个新包
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet { func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) (p *Packet) {
// logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data) // logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
return &Packet{ p = SelectPacket()
Proto: proto, p.Proto = proto
TTL: 16, p.TTL = 16
SrcPort: srcPort, p.SrcPort = srcPort
DstPort: dstPort, p.DstPort = dstPort
Dst: dst, p.Dst = dst
Data: data, p.Data = data
} return
} }
// Unmarshal 将 data 的数据解码到自身 // Unmarshal 将 data 的数据解码到自身
@@ -158,3 +158,8 @@ func (p *Packet) IsVaildHash() bool {
logrus.Debugln("[packet] sum in packet:", hex.EncodeToString(p.Hash[:])) logrus.Debugln("[packet] sum in packet:", hex.EncodeToString(p.Hash[:]))
return sum == p.Hash return sum == p.Hash
} }
// Put 将自己放回池中
func (p *Packet) Put() {
PutPacket(p)
}

21
gold/head/pool.go Normal file
View File

@@ -0,0 +1,21 @@
package head
import "sync"
var packetPool = sync.Pool{
New: func() interface{} {
return new(Packet)
},
}
// SelectPacket 从池中取出一个 Packet
func SelectPacket() *Packet {
return packetPool.Get().(*Packet)
}
// PutPacket 将 Packet 放回池中
func PutPacket(p *Packet) {
p.TeaTypeDataSZ = 0
p.Data = nil
packetPool.Put(p)
}

View File

@@ -2,7 +2,9 @@ package link
import ( import (
"net" "net"
"runtime"
"strconv" "strconv"
"sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -12,118 +14,118 @@ import (
// 监听本机 endpoint // 监听本机 endpoint
func (m *Me) listen() (conn *net.UDPConn, err error) { func (m *Me) listen() (conn *net.UDPConn, err error) {
conn, err = net.ListenUDP("udp", m.myend) conn, err = net.ListenUDP("udp", m.myend)
if err == nil { if err != nil {
go func() { return
listenbuff := make([]byte, 65536) }
for { var mu sync.Mutex
lbf := listenbuff for i := 0; i < runtime.NumCPU(); i++ {
n, addr, err := conn.ReadFromUDP(lbf) go m.listenthread(conn, &mu)
if err == nil {
lbf = lbf[:n]
packet := m.wait(lbf)
if packet != nil {
sz := packet.TeaTypeDataSZ & 0x00ffffff
r := int(sz) - len(packet.Data)
if r > 0 {
remain, err := readAll(conn, r)
if err == nil {
packet.Data = append(packet.Data, remain...)
}
}
p, ok := m.IsInPeer(packet.Src.String())
logrus.Debugln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
// logrus.Debugln("[link] recv:", hex.EncodeToString(lbf))
if ok {
if p.endpoint == nil || p.endpoint.String() != addr.String() {
logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String())
p.endpoint = addr
}
if p.IsToMe(packet.Dst) {
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>24), packet.Data)
if packet.IsVaildHash() {
switch packet.Proto {
case head.ProtoHello:
switch p.status {
case LINK_STATUS_DOWN:
n, err = p.Write(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false)
if err == nil {
logrus.Debugln("[link] send", n, "bytes hello ack packet")
p.status = LINK_STATUS_HALFUP
} else {
logrus.Errorln("[link] send hello ack packet error:", err)
}
case LINK_STATUS_HALFUP:
p.status = LINK_STATUS_UP
case LINK_STATUS_UP:
break
}
case head.ProtoNotify:
logrus.Infoln("[link] recv notify from", packet.Src)
go p.onNotify(packet.Data)
case head.ProtoQuery:
logrus.Infoln("[link] recv query from", packet.Src)
go p.onQuery(packet.Data)
case head.ProtoData:
if p.pipe != nil {
p.pipe <- packet
logrus.Debugln("[link] deliver to pipe of", p.peerip)
} else {
m.nic.Write(packet.Data)
logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic")
}
default:
logrus.Warnln("[link] recv unknown proto:", packet.Proto)
}
} else {
logrus.Debugln("[link] drop invalid packet")
}
} else if p.Accept(packet.Dst) {
if p.allowtrans {
// 转发
lnk := m.router.NextHop(packet.Dst.String())
if lnk != nil {
n, err = lnk.Write(packet, true)
if err == nil {
logrus.Debugln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
} else {
logrus.Errorln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
}
} else {
logrus.Warnln("[link] transfer drop packet: nil nexthop")
}
} else {
logrus.Warnln("[link] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
}
} else {
logrus.Warnln("[link] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
}
} else {
logrus.Warnln("[link] packet from", packet.Src, "to", packet.Dst, "is refused")
}
}
}
}
}()
} }
return return
} }
// 从 conn 读取 sz 字节数据 func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) {
func readAll(conn *net.UDPConn, sz int) ([]byte, error) { listenbuff := make([]byte, 65536)
i := 0 lbf := listenbuff
n := 0 for {
r := sz lbf = listenbuff
var err error mu.Lock()
remain := make([]byte, r) n, addr, err := conn.ReadFromUDP(lbf)
for sz > 0 { mu.Unlock()
n, _, err = conn.ReadFromUDP(remain[i:]) if err != nil {
if err == nil { continue
i += n }
r -= n lbf = lbf[:n]
} else { packet := m.wait(lbf)
logrus.Errorln("[link] read all err:", err) if packet == nil {
return nil, err continue
}
sz := packet.TeaTypeDataSZ & 0x00ffffff
r := int(sz) - len(packet.Data)
if r > 0 {
logrus.Warnln("[link] packet from endpoint", addr, "is smaller than it declared: drop it")
packet.Put()
continue
}
p, ok := m.IsInPeer(packet.Src.String())
logrus.Debugln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
// logrus.Debugln("[link] recv:", hex.EncodeToString(lbf))
if !ok {
logrus.Warnln("[link] packet from", packet.Src, "to", packet.Dst, "is refused")
packet.Put()
continue
}
if p.endpoint == nil || p.endpoint.String() != addr.String() {
logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String())
p.endpoint = addr
}
switch {
case p.IsToMe(packet.Dst):
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>24), packet.Data)
if !packet.IsVaildHash() {
logrus.Debugln("[link] drop invalid packet")
packet.Put()
continue
}
switch packet.Proto {
case head.ProtoHello:
switch p.status {
case LINK_STATUS_DOWN:
n, err = p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false)
if err == nil {
logrus.Debugln("[link] send", n, "bytes hello ack packet")
p.status = LINK_STATUS_HALFUP
} else {
logrus.Errorln("[link] send hello ack packet error:", err)
}
case LINK_STATUS_HALFUP:
p.status = LINK_STATUS_UP
case LINK_STATUS_UP:
}
packet.Put()
case head.ProtoNotify:
logrus.Infoln("[link] recv notify from", packet.Src)
go p.onNotify(packet.Data)
packet.Put()
case head.ProtoQuery:
logrus.Infoln("[link] recv query from", packet.Src)
go p.onQuery(packet.Data)
packet.Put()
case head.ProtoData:
if p.pipe != nil {
p.pipe <- packet
logrus.Debugln("[link] deliver to pipe of", p.peerip)
} else {
m.nic.Write(packet.Data)
logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic")
packet.Put()
}
default:
logrus.Warnln("[link] recv unknown proto:", packet.Proto)
packet.Put()
}
case p.Accept(packet.Dst):
if !p.allowtrans {
logrus.Warnln("[link] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
packet.Put()
continue
}
// 转发
lnk := m.router.NextHop(packet.Dst.String())
if lnk == nil {
logrus.Warnln("[link] transfer drop packet: nil nexthop")
packet.Put()
continue
}
n, err = lnk.WriteAndPut(packet, true)
if err == nil {
logrus.Debugln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
} else {
logrus.Errorln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
}
default:
logrus.Warnln("[link] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
packet.Put()
} }
} }
return remain, nil
} }

View File

@@ -177,7 +177,7 @@ func (m *Me) sendAllSameDst(packet []byte) (n int, rem []byte) {
logrus.Warnln("[me] drop packet to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), ": nil nexthop") logrus.Warnln("[me] drop packet to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), ": nil nexthop")
return return
} }
_, err := lnk.Write(head.NewPacket(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false) _, err := lnk.WriteAndPut(head.NewPacket(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false)
if err != nil { if err != nil {
logrus.Warnln("[me] write to peer", lnk.peerip, "err:", err) logrus.Warnln("[me] write to peer", lnk.peerip, "err:", err)
} }

View File

@@ -19,7 +19,7 @@ func (l *Link) keepAlive(dur int64) {
logrus.Infoln("[link.nat] start to keep alive") logrus.Infoln("[link.nat] start to keep alive")
t := time.NewTicker(time.Second * time.Duration(dur)) t := time.NewTicker(time.Second * time.Duration(dur))
for range t.C { for range t.C {
n, err := l.Write(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false) n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false)
if err == nil { if err == nil {
logrus.Infoln("[link] send", n, "bytes keep alive packet") logrus.Infoln("[link] send", n, "bytes keep alive packet")
} else { } else {
@@ -87,7 +87,7 @@ func (l *Link) onQuery(packet []byte) {
logrus.Infoln("[query] wrap", len(notify), "notify") logrus.Infoln("[query] wrap", len(notify), "notify")
w := helper.SelectWriter() w := helper.SelectWriter()
json.NewEncoder(w).Encode(&notify) json.NewEncoder(w).Encode(&notify)
l.Write(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false) l.WriteAndPut(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false)
helper.PutWriter(w) helper.PutWriter(w)
} }
} }
@@ -104,7 +104,7 @@ func (l *Link) sendquery(tick time.Duration, peers ...string) {
t := time.NewTicker(tick) t := time.NewTicker(tick)
for range t.C { for range t.C {
logrus.Infoln("[query] send query to", l.peerip) logrus.Infoln("[query] send query to", l.peerip)
_, err = l.Write(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false) _, err = l.WriteAndPut(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false)
if err != nil { if err != nil {
logrus.Errorln("[query] write err:", err) logrus.Errorln("[query] write err:", err)
} }

View File

@@ -46,7 +46,7 @@ func (m *Me) wait(data []byte) *head.Packet {
flags := binary.LittleEndian.Uint16(data[10:12]) flags := binary.LittleEndian.Uint16(data[10:12])
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12])) logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12]))
if flags == 0 || flags == 0x4000 { if flags == 0 || flags == 0x4000 {
h := &head.Packet{} h := head.SelectPacket()
_, err := h.Unmarshal(data) _, err := h.Unmarshal(data)
if err != nil { if err != nil {
logrus.Errorln("[recv] unmarshal err:", err) logrus.Errorln("[recv] unmarshal err:", err)
@@ -77,7 +77,7 @@ func (m *Me) wait(data []byte) *head.Packet {
return nil return nil
} }
logrus.Debugln("[recv] get new frag part of", hex.EncodeToString(hashd)) logrus.Debugln("[recv] get new frag part of", hex.EncodeToString(hashd))
h = &head.Packet{} h = head.SelectPacket()
_, err := h.Unmarshal(data) _, err := h.Unmarshal(data)
if err != nil { if err != nil {
logrus.Errorln("[recv] unmarshal err:", err) logrus.Errorln("[recv] unmarshal err:", err)

View File

@@ -9,14 +9,15 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Write 向 peer 发包 // WriteAndPut 向 peer 发包并将包放回缓存池
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) { func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
teatype := uint8(rand.Intn(16)) teatype := uint8(rand.Intn(16))
if len(p.Data) <= int(l.me.mtu) { if len(p.Data) <= int(l.me.mtu) {
if !istransfer { if !istransfer {
p.FillHash() p.FillHash()
p.Data = l.Encode(teatype, p.Data) p.Data = l.Encode(teatype, p.Data)
} }
defer p.Put()
return l.write(p, teatype, uint32(len(p.Data)), 0, istransfer, false) return l.write(p, teatype, uint32(len(p.Data)), 0, istransfer, false)
} }
if !istransfer { if !istransfer {
@@ -24,26 +25,28 @@ func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
p.Data = l.Encode(teatype, p.Data) p.Data = l.Encode(teatype, p.Data)
} }
data := p.Data data := p.Data
ttl := p.TTL
totl := uint32(len(data)) totl := uint32(len(data))
i := 0 i := 0
packet := head.SelectPacket()
*packet = *p
for ; int(totl)-i > int(l.me.mtu); i += int(l.me.mtu) { for ; int(totl)-i > int(l.me.mtu); i += int(l.me.mtu) {
logrus.Debugln("[link] split frag", i, ":", i+int(l.me.mtu), ", remain:", int(totl)-i-int(l.me.mtu)) logrus.Debugln("[link] split frag", i, ":", i+int(l.me.mtu), ", remain:", int(totl)-i-int(l.me.mtu))
packet := *p
packet.Data = data[:int(l.me.mtu)] packet.Data = data[:int(l.me.mtu)]
cnt, err := l.write(&packet, teatype, totl, uint16(uint(i)>>3), istransfer, true) cnt, err := l.write(packet, teatype, totl, uint16(uint(i)>>3), istransfer, true)
n += cnt n += cnt
if err != nil { if err != nil {
return n, err return n, err
} }
data = data[int(l.me.mtu):] data = data[int(l.me.mtu):]
packet.TTL = ttl
} }
packet.Put()
p.Data = data p.Data = data
cnt, err := l.write(p, teatype, totl, uint16(uint(i)>>3), istransfer, false) cnt, err := l.write(p, teatype, totl, uint16(uint(i)>>3), istransfer, false)
p.Put()
n += cnt n += cnt
if err != nil { return n, err
return n, err
}
return n, nil
} }
// write 向 peer 发一个包 // write 向 peer 发一个包

View File

@@ -90,7 +90,7 @@ func (s *Tunnel) handleWrite() {
logrus.Debugln("[tunnel] writing", len(b), "bytes...") logrus.Debugln("[tunnel] writing", len(b), "bytes...")
for len(b) > int(s.mtu) { for len(b) > int(s.mtu) {
logrus.Infoln("[tunnel] split buffer") logrus.Infoln("[tunnel] split buffer")
_, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false) _, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false)
if err != nil { if err != nil {
logrus.Errorln("[tunnel] write err:", err) logrus.Errorln("[tunnel] write err:", err)
return return
@@ -98,7 +98,7 @@ func (s *Tunnel) handleWrite() {
logrus.Debugln("[tunnel] write succeeded") logrus.Debugln("[tunnel] write succeeded")
b = b[s.mtu:] b = b[s.mtu:]
} }
_, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false) _, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false)
if err != nil { if err != nil {
logrus.Errorln("[tunnel] write err:", err) logrus.Errorln("[tunnel] write err:", err)
break break
@@ -116,5 +116,6 @@ func (s *Tunnel) handleRead() {
} }
logrus.Debugln("[tunnel] read recv", p.Data) logrus.Debugln("[tunnel] read recv", p.Data)
s.out <- p.Data s.out <- p.Data
p.Put()
} }
} }

View File

@@ -60,17 +60,6 @@ func TestTunnel(t *testing.T) {
t.Fatal("error: recv 4096 bytes data") t.Fatal("error: recv 4096 bytes data")
} }
sendb = make([]byte, 131072)
rand.Read(sendb)
tunnme.Write(sendb)
buf = make([]byte, 131072)
for i := 0; i < 32; i++ {
tunnpeer.Read(buf[i*4096:])
}
if string(sendb) != string(buf) {
t.Fatal("error: recv 131072 bytes data")
}
tunnme.Stop() tunnme.Stop()
tunnpeer.Stop() tunnpeer.Stop()
} }