diff --git a/config/cfg.go b/config/cfg.go index c48fac0..9e3e7f0 100644 --- a/config/cfg.go +++ b/config/cfg.go @@ -13,7 +13,6 @@ type EndPoint struct { Host string `yaml:"Host"` Port int64 `yaml:"Port"` Poly uint64 `yaml:"Poly"` // Poly 是 port 随机切换算法的生成多项式, 0 为禁用 - Protocol string `yaml:"Protocol"` // Protocol is udp/tcp ReconnectSeconds int64 `yaml:"ReconnectSeconds"` // ReconnectSeconds 断开重连间隔, 每次到时即向对端通报并切换到新的端口, 0 为禁用 FECMethod string `yaml:"FECMethod"` // FECMethod 可选 1/2 2/3 } diff --git a/gold/head/packet.go b/gold/head/packet.go index 73dd79a..68617e9 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -101,6 +101,7 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { if p.rembytes > 0 { p.rembytes -= copy(p.Data[flags<<3:], data[60:]) + logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes) } complete = p.rembytes == 0 diff --git a/gold/link/listen.go b/gold/link/listen.go index 68b864c..f780b5b 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -7,7 +7,6 @@ import ( "net/netip" "runtime" "strconv" - "sync" "sync/atomic" "time" "unsafe" @@ -26,146 +25,163 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { } m.myend = conn.LocalAddr() logrus.Infoln("[listen] at", m.myend) - var mu sync.Mutex - for i := 0; i < runtime.NumCPU()*4; i++ { - go m.listenthread(conn, &mu, i) - } + go func() { + recvtotlcnt := 0 + recvloopcnt := 0 + recvlooptime := time.Now().UnixMilli() + n := runtime.NumCPU() + if n > 64 { + n = 64 // 只用最多 64 核 + } + logrus.Infoln("[listen] use cpu num:", n) + listenbuff := make([]byte, 65536*n) + hasntfinished := make([]bool, n) + for i := 0; err == nil; i++ { + i %= n + for hasntfinished[i] { + time.Sleep(time.Millisecond) + i++ + i %= n + } + lbf := listenbuff[i*65536 : (i+1)*65536] + n, addr, err := conn.ReadFromUDP(lbf) + if err != nil { + logrus.Warnln("[listen] read from udp err, reconnect:", err) + conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.myend.String()))) + if err != nil { + logrus.Errorln("[listen] reconnect udp err:", err) + return + } + i-- + continue + } + recvtotlcnt += len(lbf) + recvloopcnt++ + if recvloopcnt >= 4096 { + now := time.Now().UnixMilli() + logrus.Infof("[listen] recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime)) + recvtotlcnt = 0 + recvloopcnt = 0 + recvlooptime = now + } + packet := m.wait(lbf[:n]) + if packet == nil { + i-- + continue + } + hasntfinished[i] = true + go m.listenthread(packet, addr, i, func() { hasntfinished[i] = false }) + } + }() return } -func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex, index int) { - listenbuff := make([]byte, 65536) - lbf := listenbuff - recvtotlcnt := 0 - recvloopcnt := 0 - recvlooptime := time.Now().UnixMilli() - for { - lbf = listenbuff - mu.Lock() - n, addr, err := conn.ReadFromUDP(lbf) - mu.Unlock() - if err != nil { - logrus.Errorln("[listen] thread", index, "read from udp err:", err) +func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) { + defer finish() + sz := packet.TeaTypeDataSZ & 0x0000ffff + r := int(sz) - len(packet.Data) + if r > 0 { + logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it") + packet.Put() + return + } + p, ok := m.IsInPeer(packet.Src.String()) + logrus.Debugln("[listen] @", index, "recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) + if !ok { + logrus.Warnln("[listen] @", index, "packet from", packet.Src, "to", packet.Dst, "is refused") + packet.Put() + return + } + if p.endpoint == nil || p.endpoint.String() != addr.String() { + logrus.Infoln("[listen] @", index, "set endpoint of peer", p.peerip, "to", addr.String()) + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.endpoint)), unsafe.Pointer(addr)) + } + switch { + case p.IsToMe(packet.Dst): + packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data) + if p.aead != nil { + addt := packet.AdditionalData() + packet.Data = p.DecodePreshared(addt, packet.Data) + if packet.Data == nil { + logrus.Debugln("[listen] @", index, "drop invalid preshared packet, addt:", addt) + packet.Put() + return + } + } + if p.usezstd { + dec, _ := zstd.NewReader(bytes.NewReader(packet.Data)) + var err error + packet.Data, err = io.ReadAll(dec) + dec.Close() + if err != nil { + logrus.Debugln("[listen] @", index, "drop invalid zstd packet:", err) + packet.Put() + return + } + } + if !packet.IsVaildHash() { + logrus.Debugln("[listen] @", index, "drop invalid hash packet") + packet.Put() return } - lbf = lbf[:n] - recvtotlcnt += n - recvloopcnt++ - if recvloopcnt >= 4096 { - now := time.Now().UnixMilli() - logrus.Infof("[listen] thread %d recv speed: %.2f B/s", index, float64(recvtotlcnt*1000)/float64(now-recvlooptime)) - recvtotlcnt = 0 - recvloopcnt = 0 - recvlooptime = now - } - packet := m.wait(lbf) - if packet == nil { - continue - } - sz := packet.TeaTypeDataSZ & 0x0000ffff - r := int(sz) - len(packet.Data) - if r > 0 { - logrus.Warnln("[listen] thread", index, "packet from endpoint", addr, "is smaller than it declared: drop it") - packet.Put() - continue - } - p, ok := m.IsInPeer(packet.Src.String()) - logrus.Debugln("[listen] thread", index, "recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) - if !ok { - logrus.Warnln("[listen] thread", index, "packet from", packet.Src, "to", packet.Dst, "is refused") - packet.Put() - continue - } - if p.endpoint == nil || p.endpoint.String() != addr.String() { - logrus.Infoln("[listen] thread", index, "set endpoint of peer", p.peerip, "to", addr.String()) - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.endpoint)), unsafe.Pointer(addr)) - } - switch { - case p.IsToMe(packet.Dst): - packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data) - if p.aead != nil { - addt := packet.AdditionalData() - packet.Data = p.DecodePreshared(addt, packet.Data) - if packet.Data == nil { - logrus.Debugln("[listen] thread", index, "drop invalid preshared packet, addt:", addt) - packet.Put() - continue - } - } - if p.usezstd { - dec, _ := zstd.NewReader(bytes.NewReader(packet.Data)) - packet.Data, err = io.ReadAll(dec) - dec.Close() - if err != nil { - logrus.Debugln("[listen] thread", index, "drop invalid zstd packet:", err) - packet.Put() - continue - } - } - if !packet.IsVaildHash() { - logrus.Debugln("[listen] thread", index, "drop invalid hash packet") - packet.Put() - continue - } - switch packet.Proto { - case head.ProtoHello: - switch p.status { - case LINK_STATUS_DOWN: - n, err = p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) - if err == nil { - logrus.Debugln("[listen] thread", index, "send", n, "bytes hello ack packet") - p.status = LINK_STATUS_HALFUP - } else { - logrus.Errorln("[listen] thread", index, "send hello ack packet error:", err) - } - case LINK_STATUS_HALFUP: - p.status = LINK_STATUS_UP - case LINK_STATUS_UP: - } - packet.Put() - case head.ProtoNotify: - logrus.Infoln("[listen] thread", index, "recv notify from", packet.Src) - go p.onNotify(packet.Data) - packet.Put() - case head.ProtoQuery: - logrus.Infoln("[listen] thread", index, "recv query from", packet.Src) - go p.onQuery(packet.Data) - packet.Put() - case head.ProtoData: - if p.pipe != nil { - p.pipe <- packet - logrus.Debugln("[listen] thread", index, "deliver to pipe of", p.peerip) + switch packet.Proto { + case head.ProtoHello: + switch p.status { + case LINK_STATUS_DOWN: + n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) + if err == nil { + logrus.Debugln("[listen] @", index, "send", n, "bytes hello ack packet") + p.status = LINK_STATUS_HALFUP } else { - m.nic.Write(packet.Data) - logrus.Debugln("[listen] thread", index, "deliver", len(packet.Data), "bytes data to nic") - packet.Put() + logrus.Errorln("[listen] @", index, "send hello ack packet error:", err) } - default: - logrus.Warnln("[listen] thread", index, "recv unknown proto:", packet.Proto) - packet.Put() + case LINK_STATUS_HALFUP: + p.status = LINK_STATUS_UP + case LINK_STATUS_UP: } - case p.Accept(packet.Dst): - if !p.allowtrans { - logrus.Warnln("[listen] thread", index, "refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) - packet.Put() - continue - } - // 转发 - lnk := m.router.NextHop(packet.Dst.String()) - if lnk == nil { - logrus.Warnln("[listen] thread", index, "transfer drop packet: nil nexthop") - packet.Put() - continue - } - n, err = lnk.WriteAndPut(packet, true) - if err == nil { - logrus.Debugln("[listen] thread", index, "trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + packet.Put() + case head.ProtoNotify: + logrus.Infoln("[listen] @", index, "recv notify from", packet.Src) + go p.onNotify(packet.Data) + packet.Put() + case head.ProtoQuery: + logrus.Infoln("[listen] @", index, "recv query from", packet.Src) + go p.onQuery(packet.Data) + packet.Put() + case head.ProtoData: + if p.pipe != nil { + p.pipe <- packet + logrus.Debugln("[listen] @", index, "deliver to pipe of", p.peerip) } else { - logrus.Errorln("[listen] thread", index, "trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) + m.nic.Write(packet.Data) + logrus.Debugln("[listen] @", index, "deliver", len(packet.Data), "bytes data to nic") + packet.Put() } default: - logrus.Warnln("[listen] thread", index, "packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") + logrus.Warnln("[listen] @", index, "recv unknown proto:", packet.Proto) packet.Put() } + case p.Accept(packet.Dst): + if !p.allowtrans { + logrus.Warnln("[listen] @", index, "refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + packet.Put() + return + } + // 转发 + lnk := m.router.NextHop(packet.Dst.String()) + if lnk == nil { + logrus.Warnln("[listen] @", index, "transfer drop packet: nil nexthop") + packet.Put() + return + } + n, err := lnk.WriteAndPut(packet, true) + if err == nil { + logrus.Debugln("[listen] @", index, "trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + } else { + logrus.Errorln("[listen] @", index, "trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) + } + default: + logrus.Warnln("[listen] @", index, "packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") + packet.Put() } } diff --git a/gold/link/me.go b/gold/link/me.go index 82e8d2f..d0a04f5 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -44,7 +44,7 @@ type Me struct { // 本机未接收完全分片池 recving *ttl.Cache[[32]byte, *head.Packet] // 抗重放攻击记录池 - recved *ttl.Cache[uint64, uint8] + recved *ttl.Cache[uint64, bool] // 本机上层配置 srcport, dstport, mtu uint16 // 报头掩码 @@ -101,7 +101,7 @@ func NewMe(cfg *MyConfig) (m Me) { binary.BigEndian.PutUint64(buf[:], m.mask) logrus.Infoln("[me] xor mask", hex.EncodeToString(buf[:])) m.recving = ttl.NewCache[[32]byte, *head.Packet](time.Second * 30) - m.recved = ttl.NewCache[uint64, uint8](time.Second * 30) + m.recved = ttl.NewCache[uint64, bool](time.Second * 30) return } diff --git a/gold/link/recv.go b/gold/link/recv.go index b8e28a7..50633a7 100644 --- a/gold/link/recv.go +++ b/gold/link/recv.go @@ -28,14 +28,15 @@ func (m *Me) wait(data []byte) *head.Packet { data = m.xordec(data) logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl) flags := binary.LittleEndian.Uint16(data[10:12]) - if flags&0x8000 == 0x8000 { // not a valid packet + if flags&0x8000 != 0 { // not a valid packet + logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) return nil } crc := binary.LittleEndian.Uint64(data[52:60]) - if m.recved.Get(crc) != 0 { // 是重放攻击 + if m.recved.Get(crc) { // 是重放攻击 return nil } - logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[10:12])) + logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11])) if flags == 0 || flags == 0x4000 { h := head.SelectPacket() _, err := h.Unmarshal(data) @@ -43,7 +44,7 @@ func (m *Me) wait(data []byte) *head.Packet { logrus.Errorln("[recv] unmarshal err:", err) return nil } - m.recved.Set(crc, 1) + m.recved.Set(crc, true) return h } @@ -56,8 +57,8 @@ func (m *Me) wait(data []byte) *head.Packet { if err == nil { if ok { m.recving.Delete(hsh) - m.recved.Set(crc, 1) - logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "is reached") + m.recved.Set(crc, true) + logrus.Debugln("[recv] all parts of", hex.EncodeToString(hashd), "has reached") return h } } else { diff --git a/gold/link/send.go b/gold/link/send.go index fbebaf3..0366747 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -17,6 +17,7 @@ import ( // WriteAndPut 向 peer 发包并将包放回缓存池 func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { + defer p.Put() teatype := uint8(rand.Intn(16)) sndcnt := atomic.AddUintptr(&l.sendcount, 1) mtu := l.mtu @@ -33,22 +34,21 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { delta = 8 } if len(p.Data) <= delta { - defer p.Put() return l.write(p, teatype, uint16(sndcnt), uint32(len(p.Data)), 0, istransfer, false) } if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta { - return 0, errors.New("drop dont fragmnet big trans packet") + return 0, errors.New("drop don't fragmnet big trans packet") } data := p.Data ttl := p.TTL totl := uint32(len(data)) - i := 0 + pos := 0 packet := head.SelectPacket() *packet = *p - for ; int(totl)-i > delta; i += delta { - logrus.Debugln("[send] split frag [", i, "~", i+delta, "], remain:", int(totl)-i-delta) + for ; int(totl)-pos > delta; pos += delta { + logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", int(totl)-pos-delta) packet.Data = data[:delta] - cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, true) + cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, true) n += cnt if err != nil { return n, err @@ -57,10 +57,12 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { packet.TTL = ttl } packet.Put() - p.Data = data - cnt, err := l.write(p, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, false) - p.Put() - n += cnt + if len(data) > 0 { + p.Data = data + cnt := 0 + cnt, err = l.write(p, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, false) + n += cnt + } return n, err } diff --git a/helper/writer.go b/helper/writer.go index 89016e2..ef0450c 100644 --- a/helper/writer.go +++ b/helper/writer.go @@ -82,7 +82,7 @@ func (w *Writer) WriteUInt64(v uint64) { } func (w *Writer) WriteString(v string) { - w.WriteUInt32(uint32(len(v) + 4)) + //w.WriteUInt32(uint32(len(v) + 4)) (*bytes.Buffer)(w).WriteString(v) } diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index d510e39..65f10d7 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -1,6 +1,7 @@ package tunnel import ( + "encoding/hex" "io" "net" @@ -82,7 +83,13 @@ func (s *Tunnel) Stop() { func (s *Tunnel) handleWrite() { for b := range s.in { - logrus.Debugln("[tunnel] write recv", b) + end := 64 + endl := "..." + if len(b) < 64 { + end = len(b) + endl = "." + } + logrus.Debugln("[tunnel] write send", hex.EncodeToString(b[:end]), endl) if b == nil { logrus.Errorln("[tunnel] write recv nil") break @@ -114,7 +121,13 @@ func (s *Tunnel) handleRead() { logrus.Errorln("[tunnel] read recv nil") break } - logrus.Debugln("[tunnel] read recv", p.Data) + end := 64 + endl := "..." + if len(p.Data) < 64 { + end = len(p.Data) + endl = "." + } + logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Data[:end]), endl) s.out <- p.Data p.Put() } diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 44bea4b..f1553c4 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "io" + "strings" "testing" "time" @@ -11,10 +12,12 @@ import ( "github.com/sirupsen/logrus" "github.com/fumiama/WireGold/gold/link" + "github.com/fumiama/WireGold/helper" ) func TestTunnel(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) + logrus.SetFormatter(&logFormat{enableColor: false}) selfpk, err := curve.New(nil) if err != nil { @@ -31,7 +34,7 @@ func TestTunnel(t *testing.T) { m := link.NewMe(&link.MyConfig{ MyIPwithMask: "192.168.1.2/32", - MyEndpoint: "127.0.0.1:21236", + MyEndpoint: "127.0.0.1:21246", PrivateKey: selfpk.Private(), SrcPort: 1, DstPort: 1, @@ -39,14 +42,14 @@ func TestTunnel(t *testing.T) { }) m.AddPeer(&link.PeerConfig{ PeerIP: "192.168.1.3", - EndPoint: "127.0.0.1:21237", + EndPoint: "127.0.0.1:21247", AllowedIPs: []string{"192.168.1.3/32"}, PubicKey: peerpk.Public(), MTU: 4096, }) p := link.NewMe(&link.MyConfig{ MyIPwithMask: "192.168.1.3/32", - MyEndpoint: "127.0.0.1:21237", + MyEndpoint: "127.0.0.1:21247", PrivateKey: peerpk.Private(), SrcPort: 1, DstPort: 1, @@ -54,7 +57,7 @@ func TestTunnel(t *testing.T) { }) p.AddPeer(&link.PeerConfig{ PeerIP: "192.168.1.2", - EndPoint: "127.0.0.1:21236", + EndPoint: "127.0.0.1:21246", AllowedIPs: []string{"192.168.1.2/32"}, PubicKey: selfpk.Public(), MTU: 4096, @@ -70,7 +73,7 @@ func TestTunnel(t *testing.T) { } tunnpeer.Start(1, 1, 4096) - time.Sleep(time.Second * 10) // wait link up + time.Sleep(time.Second) // wait link up sendb := ([]byte)("1234") tunnme.Write(sendb) @@ -104,3 +107,64 @@ func TestTunnel(t *testing.T) { tunnme.Stop() tunnpeer.Stop() } + +// logFormat specialize for go-cqhttp +type logFormat struct { + enableColor bool +} + +// Format implements logrus.Formatter +func (f logFormat) Format(entry *logrus.Entry) ([]byte, error) { + buf := helper.SelectWriter() + defer helper.PutWriter(buf) + + buf.WriteByte('[') + if f.enableColor { + buf.WriteString(getLogLevelColorCode(entry.Level)) + } + buf.WriteString(strings.ToUpper(entry.Level.String())) + if f.enableColor { + buf.WriteString(colorReset) + } + buf.WriteString("] ") + buf.WriteString(entry.Message) + buf.WriteString("\n") + + ret := make([]byte, len(buf.Bytes())) + copy(ret, buf.Bytes()) // copy buffer + return ret, nil +} + +const ( + colorCodePanic = "\x1b[1;31m" // color.Style{color.Bold, color.Red}.String() + colorCodeFatal = "\x1b[1;31m" // color.Style{color.Bold, color.Red}.String() + colorCodeError = "\x1b[31m" // color.Style{color.Red}.String() + colorCodeWarn = "\x1b[33m" // color.Style{color.Yellow}.String() + colorCodeInfo = "\x1b[37m" // color.Style{color.White}.String() + colorCodeDebug = "\x1b[32m" // color.Style{color.Green}.String() + colorCodeTrace = "\x1b[36m" // color.Style{color.Cyan}.String() + colorReset = "\x1b[0m" +) + +// getLogLevelColorCode 获取日志等级对应色彩code +func getLogLevelColorCode(level logrus.Level) string { + switch level { + case logrus.PanicLevel: + return colorCodePanic + case logrus.FatalLevel: + return colorCodeFatal + case logrus.ErrorLevel: + return colorCodeError + case logrus.WarnLevel: + return colorCodeWarn + case logrus.InfoLevel: + return colorCodeInfo + case logrus.DebugLevel: + return colorCodeDebug + case logrus.TraceLevel: + return colorCodeTrace + + default: + return colorCodeInfo + } +}