1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-23 03:50:32 +08:00

fix: async wait

This commit is contained in:
源文雨
2023-08-05 13:53:09 +08:00
parent a3ae280a7f
commit 1caf27dfa9
9 changed files with 252 additions and 156 deletions

View File

@@ -101,6 +101,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
if p.rembytes > 0 {
p.rembytes -= copy(p.Data[flags<<3:], data[60:])
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
}
complete = p.rembytes == 0

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"unsafe"
@@ -26,146 +25,163 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
}
m.myend = conn.LocalAddr()
logrus.Infoln("[listen] at", m.myend)
var mu sync.Mutex
for i := 0; i < runtime.NumCPU()*4; i++ {
go m.listenthread(conn, &mu, i)
}
go func() {
recvtotlcnt := 0
recvloopcnt := 0
recvlooptime := time.Now().UnixMilli()
n := runtime.NumCPU()
if n > 64 {
n = 64 // 只用最多 64 核
}
logrus.Infoln("[listen] use cpu num:", n)
listenbuff := make([]byte, 65536*n)
hasntfinished := make([]bool, n)
for i := 0; err == nil; i++ {
i %= n
for hasntfinished[i] {
time.Sleep(time.Millisecond)
i++
i %= n
}
lbf := listenbuff[i*65536 : (i+1)*65536]
n, addr, err := conn.ReadFromUDP(lbf)
if err != nil {
logrus.Warnln("[listen] read from udp err, reconnect:", err)
conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.myend.String())))
if err != nil {
logrus.Errorln("[listen] reconnect udp err:", err)
return
}
i--
continue
}
recvtotlcnt += len(lbf)
recvloopcnt++
if recvloopcnt >= 4096 {
now := time.Now().UnixMilli()
logrus.Infof("[listen] recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime))
recvtotlcnt = 0
recvloopcnt = 0
recvlooptime = now
}
packet := m.wait(lbf[:n])
if packet == nil {
i--
continue
}
hasntfinished[i] = true
go m.listenthread(packet, addr, i, func() { hasntfinished[i] = false })
}
}()
return
}
func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex, index int) {
listenbuff := make([]byte, 65536)
lbf := listenbuff
recvtotlcnt := 0
recvloopcnt := 0
recvlooptime := time.Now().UnixMilli()
for {
lbf = listenbuff
mu.Lock()
n, addr, err := conn.ReadFromUDP(lbf)
mu.Unlock()
if err != nil {
logrus.Errorln("[listen] thread", index, "read from udp err:", err)
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
defer finish()
sz := packet.TeaTypeDataSZ & 0x0000ffff
r := int(sz) - len(packet.Data)
if r > 0 {
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it")
packet.Put()
return
}
p, ok := m.IsInPeer(packet.Src.String())
logrus.Debugln("[listen] @", index, "recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
if !ok {
logrus.Warnln("[listen] @", index, "packet from", packet.Src, "to", packet.Dst, "is refused")
packet.Put()
return
}
if p.endpoint == nil || p.endpoint.String() != addr.String() {
logrus.Infoln("[listen] @", index, "set endpoint of peer", p.peerip, "to", addr.String())
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.endpoint)), unsafe.Pointer(addr))
}
switch {
case p.IsToMe(packet.Dst):
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data)
if p.aead != nil {
addt := packet.AdditionalData()
packet.Data = p.DecodePreshared(addt, packet.Data)
if packet.Data == nil {
logrus.Debugln("[listen] @", index, "drop invalid preshared packet, addt:", addt)
packet.Put()
return
}
}
if p.usezstd {
dec, _ := zstd.NewReader(bytes.NewReader(packet.Data))
var err error
packet.Data, err = io.ReadAll(dec)
dec.Close()
if err != nil {
logrus.Debugln("[listen] @", index, "drop invalid zstd packet:", err)
packet.Put()
return
}
}
if !packet.IsVaildHash() {
logrus.Debugln("[listen] @", index, "drop invalid hash packet")
packet.Put()
return
}
lbf = lbf[:n]
recvtotlcnt += n
recvloopcnt++
if recvloopcnt >= 4096 {
now := time.Now().UnixMilli()
logrus.Infof("[listen] thread %d recv speed: %.2f B/s", index, float64(recvtotlcnt*1000)/float64(now-recvlooptime))
recvtotlcnt = 0
recvloopcnt = 0
recvlooptime = now
}
packet := m.wait(lbf)
if packet == nil {
continue
}
sz := packet.TeaTypeDataSZ & 0x0000ffff
r := int(sz) - len(packet.Data)
if r > 0 {
logrus.Warnln("[listen] thread", index, "packet from endpoint", addr, "is smaller than it declared: drop it")
packet.Put()
continue
}
p, ok := m.IsInPeer(packet.Src.String())
logrus.Debugln("[listen] thread", index, "recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
if !ok {
logrus.Warnln("[listen] thread", index, "packet from", packet.Src, "to", packet.Dst, "is refused")
packet.Put()
continue
}
if p.endpoint == nil || p.endpoint.String() != addr.String() {
logrus.Infoln("[listen] thread", index, "set endpoint of peer", p.peerip, "to", addr.String())
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.endpoint)), unsafe.Pointer(addr))
}
switch {
case p.IsToMe(packet.Dst):
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data)
if p.aead != nil {
addt := packet.AdditionalData()
packet.Data = p.DecodePreshared(addt, packet.Data)
if packet.Data == nil {
logrus.Debugln("[listen] thread", index, "drop invalid preshared packet, addt:", addt)
packet.Put()
continue
}
}
if p.usezstd {
dec, _ := zstd.NewReader(bytes.NewReader(packet.Data))
packet.Data, err = io.ReadAll(dec)
dec.Close()
if err != nil {
logrus.Debugln("[listen] thread", index, "drop invalid zstd packet:", err)
packet.Put()
continue
}
}
if !packet.IsVaildHash() {
logrus.Debugln("[listen] thread", index, "drop invalid hash packet")
packet.Put()
continue
}
switch packet.Proto {
case head.ProtoHello:
switch p.status {
case LINK_STATUS_DOWN:
n, err = p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false)
if err == nil {
logrus.Debugln("[listen] thread", index, "send", n, "bytes hello ack packet")
p.status = LINK_STATUS_HALFUP
} else {
logrus.Errorln("[listen] thread", index, "send hello ack packet error:", err)
}
case LINK_STATUS_HALFUP:
p.status = LINK_STATUS_UP
case LINK_STATUS_UP:
}
packet.Put()
case head.ProtoNotify:
logrus.Infoln("[listen] thread", index, "recv notify from", packet.Src)
go p.onNotify(packet.Data)
packet.Put()
case head.ProtoQuery:
logrus.Infoln("[listen] thread", index, "recv query from", packet.Src)
go p.onQuery(packet.Data)
packet.Put()
case head.ProtoData:
if p.pipe != nil {
p.pipe <- packet
logrus.Debugln("[listen] thread", index, "deliver to pipe of", p.peerip)
switch packet.Proto {
case head.ProtoHello:
switch p.status {
case LINK_STATUS_DOWN:
n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false)
if err == nil {
logrus.Debugln("[listen] @", index, "send", n, "bytes hello ack packet")
p.status = LINK_STATUS_HALFUP
} else {
m.nic.Write(packet.Data)
logrus.Debugln("[listen] thread", index, "deliver", len(packet.Data), "bytes data to nic")
packet.Put()
logrus.Errorln("[listen] @", index, "send hello ack packet error:", err)
}
default:
logrus.Warnln("[listen] thread", index, "recv unknown proto:", packet.Proto)
packet.Put()
case LINK_STATUS_HALFUP:
p.status = LINK_STATUS_UP
case LINK_STATUS_UP:
}
case p.Accept(packet.Dst):
if !p.allowtrans {
logrus.Warnln("[listen] thread", index, "refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
packet.Put()
continue
}
// 转发
lnk := m.router.NextHop(packet.Dst.String())
if lnk == nil {
logrus.Warnln("[listen] thread", index, "transfer drop packet: nil nexthop")
packet.Put()
continue
}
n, err = lnk.WriteAndPut(packet, true)
if err == nil {
logrus.Debugln("[listen] thread", index, "trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
packet.Put()
case head.ProtoNotify:
logrus.Infoln("[listen] @", index, "recv notify from", packet.Src)
go p.onNotify(packet.Data)
packet.Put()
case head.ProtoQuery:
logrus.Infoln("[listen] @", index, "recv query from", packet.Src)
go p.onQuery(packet.Data)
packet.Put()
case head.ProtoData:
if p.pipe != nil {
p.pipe <- packet
logrus.Debugln("[listen] @", index, "deliver to pipe of", p.peerip)
} else {
logrus.Errorln("[listen] thread", index, "trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
m.nic.Write(packet.Data)
logrus.Debugln("[listen] @", index, "deliver", len(packet.Data), "bytes data to nic")
packet.Put()
}
default:
logrus.Warnln("[listen] thread", index, "packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
logrus.Warnln("[listen] @", index, "recv unknown proto:", packet.Proto)
packet.Put()
}
case p.Accept(packet.Dst):
if !p.allowtrans {
logrus.Warnln("[listen] @", index, "refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
packet.Put()
return
}
// 转发
lnk := m.router.NextHop(packet.Dst.String())
if lnk == nil {
logrus.Warnln("[listen] @", index, "transfer drop packet: nil nexthop")
packet.Put()
return
}
n, err := lnk.WriteAndPut(packet, true)
if err == nil {
logrus.Debugln("[listen] @", index, "trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
} else {
logrus.Errorln("[listen] @", index, "trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
}
default:
logrus.Warnln("[listen] @", index, "packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
packet.Put()
}
}

View File

@@ -44,7 +44,7 @@ type Me struct {
// 本机未接收完全分片池
recving *ttl.Cache[[32]byte, *head.Packet]
// 抗重放攻击记录池
recved *ttl.Cache[uint64, uint8]
recved *ttl.Cache[uint64, bool]
// 本机上层配置
srcport, dstport, mtu uint16
// 报头掩码
@@ -101,7 +101,7 @@ func NewMe(cfg *MyConfig) (m Me) {
binary.BigEndian.PutUint64(buf[:], m.mask)
logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:]))
m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 30)
m.recved = ttl.NewCache[uint64, uint8](time.Second * 30)
m.recved = ttl.NewCache[uint64, bool](time.Second * 30)
return
}

View File

@@ -28,14 +28,15 @@ func (m *Me) wait(data []byte) *head.Packet {
data = m.xordec(data)
logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl)
flags := binary.LittleEndian.Uint16(data[10:12])
if flags&0x8000 == 0x8000 { // not a valid packet
if flags&0x8000 != 0 { // not a valid packet
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
return nil
}
crc := binary.LittleEndian.Uint64(data[52:60])
if m.recved.Get(crc) != 0 { // 是重放攻击
if m.recved.Get(crc) { // 是重放攻击
return nil
}
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12]))
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
if flags == 0 || flags == 0x4000 {
h := head.SelectPacket()
_, err := h.Unmarshal(data)
@@ -43,7 +44,7 @@ func (m *Me) wait(data []byte) *head.Packet {
logrus.Errorln("[recv] unmarshal err:", err)
return nil
}
m.recved.Set(crc, 1)
m.recved.Set(crc, true)
return h
}
@@ -56,8 +57,8 @@ func (m *Me) wait(data []byte) *head.Packet {
if err == nil {
if ok {
m.recving.Delete(hsh)
m.recved.Set(crc, 1)
logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "is reached")
m.recved.Set(crc, true)
logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "has reached")
return h
}
} else {

View File

@@ -17,6 +17,7 @@ import (
// WriteAndPut 向 peer 发包并将包放回缓存池
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
defer p.Put()
teatype := uint8(rand.Intn(16))
sndcnt := atomic.AddUintptr(&l.sendcount, 1)
mtu := l.mtu
@@ -33,22 +34,21 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
delta = 8
}
if len(p.Data) <= delta {
defer p.Put()
return l.write(p, teatype, uint16(sndcnt), uint32(len(p.Data)), 0, istransfer, false)
}
if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta {
return 0, errors.New("drop dont fragmnet big trans packet")
return 0, errors.New("drop don't fragmnet big trans packet")
}
data := p.Data
ttl := p.TTL
totl := uint32(len(data))
i := 0
pos := 0
packet := head.SelectPacket()
*packet = *p
for ; int(totl)-i > delta; i += delta {
logrus.Debugln("[send] split frag [", i, "~", i+delta, "], remain:", int(totl)-i-delta)
for ; int(totl)-pos > delta; pos += delta {
logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", int(totl)-pos-delta)
packet.Data = data[:delta]
cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, true)
cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, true)
n += cnt
if err != nil {
return n, err
@@ -57,10 +57,12 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
packet.TTL = ttl
}
packet.Put()
p.Data = data
cnt, err := l.write(p, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, false)
p.Put()
n += cnt
if len(data) > 0 {
p.Data = data
cnt := 0
cnt, err = l.write(p, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, false)
n += cnt
}
return n, err
}