From 7d8be16fd38afa48d3d22578fdcd365370930992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 13 Apr 2022 11:34:14 +0800 Subject: [PATCH] parallel listen --- gold/head/packet.go | 23 +-- gold/head/pool.go | 21 +++ gold/link/listen.go | 218 ++++++++++++++------------- gold/link/me.go | 2 +- gold/link/nat.go | 6 +- gold/link/recv.go | 4 +- gold/link/send.go | 19 ++- upper/services/tunnel/tunnel.go | 5 +- upper/services/tunnel/tunnel_test.go | 11 -- 9 files changed, 165 insertions(+), 144 deletions(-) create mode 100644 gold/head/pool.go diff --git a/gold/head/packet.go b/gold/head/packet.go index 630c2c2..08208aa 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -42,16 +42,16 @@ type Packet struct { } // NewPacket 生成一个新包 -func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet { +func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) (p *Packet) { // logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data) - return &Packet{ - Proto: proto, - TTL: 16, - SrcPort: srcPort, - DstPort: dstPort, - Dst: dst, - Data: data, - } + p = SelectPacket() + p.Proto = proto + p.TTL = 16 + p.SrcPort = srcPort + p.DstPort = dstPort + p.Dst = dst + p.Data = data + return } // Unmarshal 将 data 的数据解码到自身 @@ -158,3 +158,8 @@ func (p *Packet) IsVaildHash() bool { logrus.Debugln("[packet] sum in packet:", hex.EncodeToString(p.Hash[:])) return sum == p.Hash } + +// Put 将自己放回池中 +func (p *Packet) Put() { + PutPacket(p) +} diff --git a/gold/head/pool.go b/gold/head/pool.go new file mode 100644 index 0000000..0a90123 --- /dev/null +++ b/gold/head/pool.go @@ -0,0 +1,21 @@ +package head + +import "sync" + +var packetPool = sync.Pool{ + New: func() interface{} { + return new(Packet) + }, +} + +// SelectPacket 从池中取出一个 Packet +func SelectPacket() *Packet { + return packetPool.Get().(*Packet) +} + +// PutPacket 将 Packet 放回池中 +func PutPacket(p *Packet) { + p.TeaTypeDataSZ = 0 + p.Data = nil + packetPool.Put(p) +} diff --git a/gold/link/listen.go b/gold/link/listen.go index 87d5353..611de72 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -2,7 +2,9 @@ package link import ( "net" + "runtime" "strconv" + "sync" "github.com/sirupsen/logrus" @@ -12,118 +14,118 @@ import ( // 监听本机 endpoint func (m *Me) listen() (conn *net.UDPConn, err error) { conn, err = net.ListenUDP("udp", m.myend) - if err == nil { - go func() { - listenbuff := make([]byte, 65536) - for { - lbf := listenbuff - n, addr, err := conn.ReadFromUDP(lbf) - if err == nil { - lbf = lbf[:n] - packet := m.wait(lbf) - if packet != nil { - sz := packet.TeaTypeDataSZ & 0x00ffffff - r := int(sz) - len(packet.Data) - if r > 0 { - remain, err := readAll(conn, r) - if err == nil { - packet.Data = append(packet.Data, remain...) - } - } - p, ok := m.IsInPeer(packet.Src.String()) - logrus.Debugln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) - // logrus.Debugln("[link] recv:", hex.EncodeToString(lbf)) - if ok { - if p.endpoint == nil || p.endpoint.String() != addr.String() { - logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String()) - p.endpoint = addr - } - if p.IsToMe(packet.Dst) { - packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>24), packet.Data) - if packet.IsVaildHash() { - switch packet.Proto { - case head.ProtoHello: - switch p.status { - case LINK_STATUS_DOWN: - n, err = p.Write(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) - if err == nil { - logrus.Debugln("[link] send", n, "bytes hello ack packet") - p.status = LINK_STATUS_HALFUP - } else { - logrus.Errorln("[link] send hello ack packet error:", err) - } - case LINK_STATUS_HALFUP: - p.status = LINK_STATUS_UP - case LINK_STATUS_UP: - break - } - case head.ProtoNotify: - logrus.Infoln("[link] recv notify from", packet.Src) - go p.onNotify(packet.Data) - case head.ProtoQuery: - logrus.Infoln("[link] recv query from", packet.Src) - go p.onQuery(packet.Data) - case head.ProtoData: - if p.pipe != nil { - p.pipe <- packet - logrus.Debugln("[link] deliver to pipe of", p.peerip) - } else { - m.nic.Write(packet.Data) - logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic") - } - default: - logrus.Warnln("[link] recv unknown proto:", packet.Proto) - } - } else { - logrus.Debugln("[link] drop invalid packet") - } - } else if p.Accept(packet.Dst) { - if p.allowtrans { - // 转发 - lnk := m.router.NextHop(packet.Dst.String()) - if lnk != nil { - n, err = lnk.Write(packet, true) - if err == nil { - logrus.Debugln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) - } else { - logrus.Errorln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) - } - } else { - logrus.Warnln("[link] transfer drop packet: nil nexthop") - } - } else { - logrus.Warnln("[link] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) - } - } else { - logrus.Warnln("[link] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") - } - } else { - logrus.Warnln("[link] packet from", packet.Src, "to", packet.Dst, "is refused") - } - } - } - } - }() + if err != nil { + return + } + var mu sync.Mutex + for i := 0; i < runtime.NumCPU(); i++ { + go m.listenthread(conn, &mu) } return } -// 从 conn 读取 sz 字节数据 -func readAll(conn *net.UDPConn, sz int) ([]byte, error) { - i := 0 - n := 0 - r := sz - var err error - remain := make([]byte, r) - for sz > 0 { - n, _, err = conn.ReadFromUDP(remain[i:]) - if err == nil { - i += n - r -= n - } else { - logrus.Errorln("[link] read all err:", err) - return nil, err +func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) { + listenbuff := make([]byte, 65536) + lbf := listenbuff + for { + lbf = listenbuff + mu.Lock() + n, addr, err := conn.ReadFromUDP(lbf) + mu.Unlock() + if err != nil { + continue + } + lbf = lbf[:n] + packet := m.wait(lbf) + if packet == nil { + continue + } + sz := packet.TeaTypeDataSZ & 0x00ffffff + r := int(sz) - len(packet.Data) + if r > 0 { + logrus.Warnln("[link] packet from endpoint", addr, "is smaller than it declared: drop it") + packet.Put() + continue + } + p, ok := m.IsInPeer(packet.Src.String()) + logrus.Debugln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) + // logrus.Debugln("[link] recv:", hex.EncodeToString(lbf)) + if !ok { + logrus.Warnln("[link] packet from", packet.Src, "to", packet.Dst, "is refused") + packet.Put() + continue + } + if p.endpoint == nil || p.endpoint.String() != addr.String() { + logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String()) + p.endpoint = addr + } + switch { + case p.IsToMe(packet.Dst): + packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>24), packet.Data) + if !packet.IsVaildHash() { + logrus.Debugln("[link] drop invalid packet") + packet.Put() + continue + } + switch packet.Proto { + case head.ProtoHello: + switch p.status { + case LINK_STATUS_DOWN: + n, err = p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) + if err == nil { + logrus.Debugln("[link] send", n, "bytes hello ack packet") + p.status = LINK_STATUS_HALFUP + } else { + logrus.Errorln("[link] send hello ack packet error:", err) + } + case LINK_STATUS_HALFUP: + p.status = LINK_STATUS_UP + case LINK_STATUS_UP: + } + packet.Put() + case head.ProtoNotify: + logrus.Infoln("[link] recv notify from", packet.Src) + go p.onNotify(packet.Data) + packet.Put() + case head.ProtoQuery: + logrus.Infoln("[link] recv query from", packet.Src) + go p.onQuery(packet.Data) + packet.Put() + case head.ProtoData: + if p.pipe != nil { + p.pipe <- packet + logrus.Debugln("[link] deliver to pipe of", p.peerip) + } else { + m.nic.Write(packet.Data) + logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic") + packet.Put() + } + default: + logrus.Warnln("[link] recv unknown proto:", packet.Proto) + packet.Put() + } + case p.Accept(packet.Dst): + if !p.allowtrans { + logrus.Warnln("[link] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + packet.Put() + continue + } + // 转发 + lnk := m.router.NextHop(packet.Dst.String()) + if lnk == nil { + logrus.Warnln("[link] transfer drop packet: nil nexthop") + packet.Put() + continue + } + n, err = lnk.WriteAndPut(packet, true) + if err == nil { + logrus.Debugln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + } else { + logrus.Errorln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) + } + default: + logrus.Warnln("[link] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") + packet.Put() } } - return remain, nil } diff --git a/gold/link/me.go b/gold/link/me.go index 5c543c9..9cfc0d2 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -177,7 +177,7 @@ func (m *Me) sendAllSameDst(packet []byte) (n int, rem []byte) { logrus.Warnln("[me] drop packet to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), ": nil nexthop") return } - _, err := lnk.Write(head.NewPacket(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false) + _, err := lnk.WriteAndPut(head.NewPacket(head.ProtoData, m.SrcPort(), lnk.peerip, m.DstPort(), packet), false) if err != nil { logrus.Warnln("[me] write to peer", lnk.peerip, "err:", err) } diff --git a/gold/link/nat.go b/gold/link/nat.go index 45b72f5..f51d2b9 100644 --- a/gold/link/nat.go +++ b/gold/link/nat.go @@ -19,7 +19,7 @@ func (l *Link) keepAlive(dur int64) { logrus.Infoln("[link.nat] start to keep alive") t := time.NewTicker(time.Second * time.Duration(dur)) for range t.C { - n, err := l.Write(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false) + n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false) if err == nil { logrus.Infoln("[link] send", n, "bytes keep alive packet") } else { @@ -87,7 +87,7 @@ func (l *Link) onQuery(packet []byte) { logrus.Infoln("[query] wrap", len(notify), "notify") w := helper.SelectWriter() json.NewEncoder(w).Encode(¬ify) - l.Write(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false) + l.WriteAndPut(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false) helper.PutWriter(w) } } @@ -104,7 +104,7 @@ func (l *Link) sendquery(tick time.Duration, peers ...string) { t := time.NewTicker(tick) for range t.C { logrus.Infoln("[query] send query to", l.peerip) - _, err = l.Write(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false) + _, err = l.WriteAndPut(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false) if err != nil { logrus.Errorln("[query] write err:", err) } diff --git a/gold/link/recv.go b/gold/link/recv.go index 74c7be3..d00794d 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -46,7 +46,7 @@ func (m *Me) wait(data []byte) *head.Packet { flags := binary.LittleEndian.Uint16(data[10:12]) logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12])) if flags == 0 || flags == 0x4000 { - h := &head.Packet{} + h := head.SelectPacket() _, err := h.Unmarshal(data) if err != nil { logrus.Errorln("[recv] unmarshal err:", err) @@ -77,7 +77,7 @@ func (m *Me) wait(data []byte) *head.Packet { return nil } logrus.Debugln("[recv] get new frag part of", hex.EncodeToString(hashd)) - h = &head.Packet{} + h = head.SelectPacket() _, err := h.Unmarshal(data) if err != nil { logrus.Errorln("[recv] unmarshal err:", err) diff --git a/gold/link/send.go b/gold/link/send.go index a13215d..3a61598 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -9,14 +9,15 @@ import ( "github.com/sirupsen/logrus" ) -// Write 向 peer 发包 -func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) { +// WriteAndPut 向 peer 发包并将包放回缓存池 +func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { teatype := uint8(rand.Intn(16)) if len(p.Data) <= int(l.me.mtu) { if !istransfer { p.FillHash() p.Data = l.Encode(teatype, p.Data) } + defer p.Put() return l.write(p, teatype, uint32(len(p.Data)), 0, istransfer, false) } if !istransfer { @@ -24,26 +25,28 @@ func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) { p.Data = l.Encode(teatype, p.Data) } data := p.Data + ttl := p.TTL totl := uint32(len(data)) i := 0 + packet := head.SelectPacket() + *packet = *p for ; int(totl)-i > int(l.me.mtu); i += int(l.me.mtu) { logrus.Debugln("[link] split frag", i, ":", i+int(l.me.mtu), ", remain:", int(totl)-i-int(l.me.mtu)) - packet := *p packet.Data = data[:int(l.me.mtu)] - cnt, err := l.write(&packet, teatype, totl, uint16(uint(i)>>3), istransfer, true) + cnt, err := l.write(packet, teatype, totl, uint16(uint(i)>>3), istransfer, true) n += cnt if err != nil { return n, err } data = data[int(l.me.mtu):] + packet.TTL = ttl } + packet.Put() p.Data = data cnt, err := l.write(p, teatype, totl, uint16(uint(i)>>3), istransfer, false) + p.Put() n += cnt - if err != nil { - return n, err - } - return n, nil + return n, err } // write 向 peer 发一个包 diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index c81b84e..497cc77 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -90,7 +90,7 @@ func (s *Tunnel) handleWrite() { logrus.Debugln("[tunnel] writing", len(b), "bytes...") for len(b) > int(s.mtu) { logrus.Infoln("[tunnel] split buffer") - _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false) + _, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false) if err != nil { logrus.Errorln("[tunnel] write err:", err) return @@ -98,7 +98,7 @@ func (s *Tunnel) handleWrite() { logrus.Debugln("[tunnel] write succeeded") b = b[s.mtu:] } - _, err := s.l.Write(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false) + _, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false) if err != nil { logrus.Errorln("[tunnel] write err:", err) break @@ -116,5 +116,6 @@ func (s *Tunnel) handleRead() { } logrus.Debugln("[tunnel] read recv", p.Data) s.out <- p.Data + p.Put() } } diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index d6be4e3..9a99a45 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -60,17 +60,6 @@ func TestTunnel(t *testing.T) { t.Fatal("error: recv 4096 bytes data") } - sendb = make([]byte, 131072) - rand.Read(sendb) - tunnme.Write(sendb) - buf = make([]byte, 131072) - for i := 0; i < 32; i++ { - tunnpeer.Read(buf[i*4096:]) - } - if string(sendb) != string(buf) { - t.Fatal("error: recv 131072 bytes data") - } - tunnme.Stop() tunnpeer.Stop() }