1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-21 11:02:42 +08:00

fix(nat): keep alive blocked by firewall

This commit is contained in:
源文雨
2025-02-22 03:24:44 +09:00
parent 5c65302d67
commit a205d889ca
5 changed files with 66 additions and 26 deletions

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"net" "net"
"sync/atomic" "sync/atomic"
"time"
"github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/gold/p2p" "github.com/fumiama/WireGold/gold/p2p"
@@ -38,8 +39,8 @@ type Link struct {
keys [32]cipher.AEAD keys [32]cipher.AEAD
// 本机信息 // 本机信息
me *Me me *Me
// 连接的状态,详见下方 const // 最后一次收到报文的时间
status int8 lastalive *time.Time
// 是否允许转发 // 是否允许转发
allowtrans bool allowtrans bool
// 是否对数据进行 zstd 压缩 // 是否对数据进行 zstd 压缩
@@ -52,12 +53,6 @@ type Link struct {
mturandomrange uint16 mturandomrange uint16
} }
const (
LINK_STATUS_DOWN = iota
LINK_STATUS_HALFUP
LINK_STATUS_UP
)
// Connect 初始化与 peer 的连接 // Connect 初始化与 peer 的连接
func (m *Me) Connect(peer string) (*Link, error) { func (m *Me) Connect(peer string) (*Link, error) {
p, ok := m.IsInPeer(net.ParseIP(peer).String()) p, ok := m.IsInPeer(net.ParseIP(peer).String())
@@ -69,7 +64,7 @@ func (m *Me) Connect(peer string) (*Link, error) {
// Close 关闭到 peer 的连接 // Close 关闭到 peer 的连接
func (l *Link) Close() { func (l *Link) Close() {
l.status = LINK_STATUS_DOWN l.Destroy()
} }
// Destroy 从 connections 移除 peer // Destroy 从 connections 移除 peer

View File

@@ -8,7 +8,9 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
"unsafe"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -180,6 +182,8 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish
p.endpoint = addr p.endpoint = addr
} }
} }
now := time.Now()
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.lastalive)), unsafe.Pointer(&now))
switch { switch {
case p.IsToMe(packet.Dst): case p.IsToMe(packet.Dst):
if !p.Accept(packet.Src) { if !p.Accept(packet.Src) {
@@ -222,20 +226,13 @@ func (m *Me) dispatch(packet *head.Packet, addr p2p.EndPoint, index int, finish
} }
switch packet.Proto { switch packet.Proto {
case head.ProtoHello: case head.ProtoHello:
switch p.status { n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false)
case LINK_STATUS_DOWN: if err == nil {
n, err := p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) if config.ShowDebugLog {
if err == nil { logrus.Debugln("[listen] @", index, "send", n, "bytes hello ack packet")
if config.ShowDebugLog {
logrus.Debugln("[listen] @", index, "send", n, "bytes hello ack packet")
}
p.status = LINK_STATUS_HALFUP
} else {
logrus.Errorln("[listen] @", index, "send hello ack packet error:", err)
} }
case LINK_STATUS_HALFUP: } else {
p.status = LINK_STATUS_UP logrus.Errorln("[listen] @", index, "send hello ack packet error:", err)
case LINK_STATUS_UP:
} }
packet.Put() packet.Put()
case head.ProtoNotify: case head.ProtoNotify:

View File

@@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"io" "io"
"net" "net"
"reflect"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -22,6 +23,8 @@ import (
// Me 是本机的抽象 // Me 是本机的抽象
type Me struct { type Me struct {
// 用于自我重连
cfg *MyConfig
// 本机私钥 // 本机私钥
// 利用 Curve25519 生成 // 利用 Curve25519 生成
// https://pkg.go.dev/golang.org/x/crypto/curve25519 // https://pkg.go.dev/golang.org/x/crypto/curve25519
@@ -77,6 +80,7 @@ type NICConfig struct {
// NewMe 设置本机参数 // NewMe 设置本机参数
func NewMe(cfg *MyConfig) (m Me) { func NewMe(cfg *MyConfig) (m Me) {
m.cfg = cfg
m.privKey = *cfg.PrivateKey m.privKey = *cfg.PrivateKey
var err error var err error
nw := cfg.Network nw := cfg.Network
@@ -128,6 +132,33 @@ func NewMe(cfg *MyConfig) (m Me) {
return return
} }
// Restart 重新连接
func (m *Me) Restart() error {
oldconn := m.conn
m.conn = nil
if !reflect.ValueOf(oldconn).IsZero() {
_ = oldconn.Close()
}
var err error
nw := m.cfg.Network
if nw == "" {
nw = "udp"
}
m.networkconfigs = m.cfg.NetworkConfigs
m.ep, err = p2p.NewEndPoint(nw, m.cfg.MyEndpoint, m.networkconfigs...)
if err != nil {
return err
}
ip, cidr, err := net.ParseCIDR(m.cfg.MyIPwithMask)
if err != nil {
return err
}
m.me = ip
m.subnet = *cidr
m.conn, err = m.listen()
return err
}
func (m *Me) SrcPort() uint16 { func (m *Me) SrcPort() uint16 {
return m.srcport return m.srcport
} }
@@ -146,7 +177,7 @@ func (m *Me) EndPoint() p2p.EndPoint {
func (m *Me) Close() error { func (m *Me) Close() error {
m.connections = nil m.connections = nil
if m.conn != nil { if !reflect.ValueOf(m.conn).IsZero() {
_ = m.conn.Close() _ = m.conn.Close()
m.conn = nil m.conn = nil
} }

View File

@@ -2,7 +2,9 @@ package link
import ( import (
"encoding/json" "encoding/json"
"sync/atomic"
"time" "time"
"unsafe"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -23,11 +25,21 @@ func (l *Link) keepAlive(dur int64) {
if l.me.connections == nil { if l.me.connections == nil {
return return
} }
la := (*time.Time)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&l.lastalive))))
if la != nil && time.Since(*la) > 10*time.Second*time.Duration(dur) { // 可能已经被阻断, 断开重连
logrus.Warnln("[nat] no response after 10 keep alive tries, re-connecting...")
err := l.me.Restart()
if err != nil {
logrus.Errorln("[nat] re-connect me err:", err)
} else {
logrus.Infoln("[nat] re-connect me succeeded")
}
}
n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false) n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false)
if err == nil { if err == nil {
logrus.Infoln("[nat] send", n, "bytes keep alive packet") logrus.Infoln("[nat] send", n, "bytes keep alive packet")
} else { } else {
logrus.Errorln("[nat] send keep alive packet error:", err) logrus.Warnln("[nat] send keep alive packet error:", err)
} }
} }
} }

View File

@@ -143,8 +143,13 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas
bound = len(d) bound = len(d)
endl = "." endl = "."
} }
conn := l.me.conn
if conn == nil {
return 0, io.ErrClosedPipe
}
if config.ShowDebugLog { if config.ShowDebugLog {
logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.CRC64()))
logrus.Debugln("[send] write", len(d), "bytes data from ep", conn.LocalAddr(), "to", peerep, "offset", fmt.Sprintf("%04x", offset), "crc", fmt.Sprintf("%016x", p.CRC64()))
logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl) logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl)
} }
d = l.me.xorenc(d, seq) d = l.me.xorenc(d, seq)
@@ -155,5 +160,5 @@ func (l *Link) writeonce(p *head.Packet, teatype uint8, additional uint16, datas
logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl) logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl)
} }
defer helper.PutBytes(d) defer helper.PutBytes(d)
return l.me.conn.WriteToPeer(d, peerep) return conn.WriteToPeer(d, peerep)
} }