From a205d889ca815c6c267d64850007d54e034c32f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 22 Feb 2025 03:24:44 +0900 Subject: [PATCH] fix(nat): keep alive blocked by firewall --- gold/link/link.go | 13 ++++--------- gold/link/listen.go | 23 ++++++++++------------- gold/link/me.go | 33 ++++++++++++++++++++++++++++++++- gold/link/nat.go | 14 +++++++++++++- gold/link/send.go | 9 +++++++-- 5 files changed, 66 insertions(+), 26 deletions(-) diff --git a/gold/link/link.go b/gold/link/link.go index 3692562..8e76748 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -5,6 +5,7 @@ import ( "errors" "net" "sync/atomic" + "time" "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/p2p" @@ -38,8 +39,8 @@ type Link struct { keys [32]cipher.AEAD // 本机信息 me *Me - // 连接的状态,详见下方 const - status int8 + // 最后一次收到报文的时间 + lastalive *time.Time // 是否允许转发 allowtrans bool // 是否对数据进行 zstd 压缩 @@ -52,12 +53,6 @@ type Link struct { mturandomrange uint16 } -const ( - LINK_STATUS_DOWN = iota - LINK_STATUS_HALFUP - LINK_STATUS_UP -) - // Connect 初始化与 peer 的连接 func (m *Me) Connect(peer string) (*Link, error) { p, ok := m.IsInPeer(net.ParseIP(peer).String()) @@ -69,7 +64,7 @@ func (m *Me) Connect(peer string) (*Link, error) { // Close 关闭到 peer 的连接 func (l *Link) Close() { - l.status = LINK_STATUS_DOWN + l.Destroy() } // Destroy 从 connections 移除 peer diff --git a/gold/link/listen.go b/gold/link/listen.go index 7b4f21e..7180483 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -8,7 +8,9 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" "time" + "unsafe" "github.com/klauspost/compress/zstd" "github.com/sirupsen/logrus" @@ -180,6 +182,8 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish p.endpoint = addr } } + now := time.Now() + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.lastalive)), unsafe.Pointer(&now)) switch { case p.IsToMe(packet.Dst): if !p.Accept(packet.Src) { @@ -222,20 +226,13 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish } switch packet.Proto { case head.ProtoHello: - switch p.status { - case LINK_STATUS_DOWN: - n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) - if err == nil { - if config.ShowDebugLog { - logrus.Debugln("[listen] @", index, "send", n, "bytes hello ack packet") - } - p.status = LINK_STATUS_HALFUP - } else { - logrus.Errorln("[listen] @", index, "send hello ack packet error:", err) + n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) + if err == nil { + if config.ShowDebugLog { + logrus.Debugln("[listen] @", index, "send", n, "bytes hello ack packet") } - case LINK_STATUS_HALFUP: - p.status = LINK_STATUS_UP - case LINK_STATUS_UP: + } else { + logrus.Errorln("[listen] @", index, "send hello ack packet error:", err) } packet.Put() case head.ProtoNotify: diff --git a/gold/link/me.go b/gold/link/me.go index 275f92e..5782014 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "io" "net" + "reflect" "strconv" "sync" "time" @@ -22,6 +23,8 @@ import ( // Me 是本机的抽象 type Me struct { + // 用于自我重连 + cfg *MyConfig // 本机私钥 // 利用 Curve25519 生成 // https://pkg.go.dev/golang.org/x/crypto/curve25519 @@ -77,6 +80,7 @@ type NICConfig struct { // NewMe 设置本机参数 func NewMe(cfg *MyConfig) (m Me) { + m.cfg = cfg m.privKey = *cfg.PrivateKey var err error nw := cfg.Network @@ -128,6 +132,33 @@ func NewMe(cfg *MyConfig) (m Me) { return } +// Restart 重新连接 +func (m *Me) Restart() error { + oldconn := m.conn + m.conn = nil + if !reflect.ValueOf(oldconn).IsZero() { + _ = oldconn.Close() + } + var err error + nw := m.cfg.Network + if nw == "" { + nw = "udp" + } + m.networkconfigs = m.cfg.NetworkConfigs + m.ep, err = p2p.NewEndPoint(nw, m.cfg.MyEndpoint, m.networkconfigs...) + if err != nil { + return err + } + ip, cidr, err := net.ParseCIDR(m.cfg.MyIPwithMask) + if err != nil { + return err + } + m.me = ip + m.subnet = *cidr + m.conn, err = m.listen() + return err +} + func (m *Me) SrcPort() uint16 { return m.srcport } @@ -146,7 +177,7 @@ func (m *Me) EndPoint() p2p.EndPoint { func (m *Me) Close() error { m.connections = nil - if m.conn != nil { + if !reflect.ValueOf(m.conn).IsZero() { _ = m.conn.Close() m.conn = nil } diff --git a/gold/link/nat.go b/gold/link/nat.go index cacb3ca..a69e422 100644 --- a/gold/link/nat.go +++ b/gold/link/nat.go @@ -2,7 +2,9 @@ package link import ( "encoding/json" + "sync/atomic" "time" + "unsafe" "github.com/sirupsen/logrus" @@ -23,11 +25,21 @@ func (l *Link) keepAlive(dur int64) { if l.me.connections == nil { return } + la := (*time.Time)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&l.lastalive)))) + if la != nil && time.Since(*la) > 10*time.Second*time.Duration(dur) { // 可能已经被阻断, 断开重连 + logrus.Warnln("[nat] no response after 10 keep alive tries, re-connecting...") + err := l.me.Restart() + if err != nil { + logrus.Errorln("[nat] re-connect me err:", err) + } else { + logrus.Infoln("[nat] re-connect me succeeded") + } + } n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false) if err == nil { logrus.Infoln("[nat] send", n, "bytes keep alive packet") } else { - logrus.Errorln("[nat] send keep alive packet error:", err) + logrus.Warnln("[nat] send keep alive packet error:", err) } } } diff --git a/gold/link/send.go b/gold/link/send.go index 9d0d003..d1a411f 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -143,8 +143,13 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas bound = len(d) endl = "." } + conn := l.me.conn + if conn == nil { + return 0, io.ErrClosedPipe + } if config.ShowDebugLog { - logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.CRC64())) + + logrus.Debugln("[send] write", len(d), "bytes data from ep", conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.CRC64())) logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) } d = l.me.xorenc(d, seq) @@ -155,5 +160,5 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) } defer helper.PutBytes(d) - return l.me.conn.WriteToPeer(d, peerep) + return conn.WriteToPeer(d, peerep) }