1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-04 23:40:26 +08:00
Files
WireGold/gold/p2p/tcp/tcp.go
2026-04-11 15:02:45 +08:00

509 lines
11 KiB
Go

package tcp
import (
"errors"
"io"
"net"
"reflect"
"strconv"
"sync"
"time"
"github.com/FloatTech/ttl"
"github.com/sirupsen/logrus"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/gold/p2p"
)
// EndPoint is a TCP address together with the tuning values used by the
// connections created from it.
type EndPoint struct {
// addr is the TCP address; Listen overwrites it with the listener's bound address.
addr *net.TCPAddr
// dialtimeout bounds outgoing dials; writeToPeer raises values below 1s to 1s.
dialtimeout time.Duration
// peerstimeout is the duration handed to the peers cache; Listen and receive
// raise values below 30s to 30s (receive additionally doubles it for its read timer).
peerstimeout time.Duration
// keepinterval is the keep-alive period; keep raises values below 10s to 10s.
keepinterval time.Duration
// recvchansize is the receive channel capacity; Listen raises values below 32 to 32.
recvchansize int
}
// String returns the textual form of the endpoint's TCP address.
//
// A nil receiver (or a nil address) yields "<nil>" instead of panicking:
// Conn.Close sets Conn.addr to nil while other goroutines may still format
// it for logging, and Equal already treats a nil receiver as valid.
func (ep *EndPoint) String() string {
	if ep == nil || ep.addr == nil {
		return "<nil>"
	}
	return ep.addr.String()
}
// Network reports the network name of the underlying TCP address.
func (ep *EndPoint) Network() string {
	tcpaddr := ep.addr
	return tcpaddr.Network()
}
// Equal reports whether ep and ep2 denote the same TCP address, comparing
// IP, port and zone.
//
// Two nil endpoints are equal; a nil endpoint never equals a non-nil one.
// ep2 is inspected after the type assertion as well, because an interface
// holding a nil *EndPoint is itself non-nil and would otherwise be
// dereferenced. Endpoints of any other implementation are never equal.
func (ep *EndPoint) Equal(ep2 p2p.EndPoint) bool {
	if ep2 == nil {
		return ep == nil
	}
	other, ok := ep2.(*EndPoint)
	if !ok {
		return false
	}
	if ep == nil || other == nil {
		return ep == nil && other == nil
	}
	if ep.addr == nil || other.addr == nil {
		// Avoid dereferencing a missing address; equal only if both are missing.
		return ep.addr == other.addr
	}
	return ep.addr.IP.Equal(other.addr.IP) &&
		ep.addr.Port == other.addr.Port &&
		ep.addr.Zone == other.addr.Zone
}
// Listen opens a TCP listener on the endpoint's address, wraps it in a Conn
// and starts the accept loop in a new goroutine.
//
// The endpoint's addr is replaced by the address the listener reports.
// peerstimeout is raised to at least 30s and recvchansize to at least 32
// before they are used.
func (ep *EndPoint) Listen() (p2p.Conn, error) {
	listener, err := net.ListenTCP(ep.addr.Network(), ep.addr)
	if err != nil {
		return nil, err
	}
	// Record the address the listener actually bound.
	ep.addr = listener.Addr().(*net.TCPAddr)

	cachettl := ep.peerstimeout
	if cachettl < 30*time.Second {
		cachettl = 30 * time.Second
	}
	recvcap := ep.recvchansize
	if recvcap < 32 {
		recvcap = 32
	}

	// Cache callback in slot 0: tune the socket options of a connection.
	tune := func(_ string, c *net.TCPConn) {
		_ = c.SetLinger(0)
		_ = c.SetNoDelay(true)
	}
	// Cache callback in slot 2: half-close the write side of a connection.
	halfclose := func(_ string, c *net.TCPConn) {
		closeErr := c.CloseWrite()
		if !config.ShowDebugLog {
			return
		}
		if closeErr != nil {
			logrus.Debugln("[tcp] close write from", c.LocalAddr(), "to", c.RemoteAddr(), "err:", closeErr)
			return
		}
		logrus.Debugln("[tcp] close write from", c.LocalAddr(), "to", c.RemoteAddr())
	}

	c := &Conn{
		addr:  ep,
		lstn:  listener,
		peers: ttl.NewCacheOn(cachettl, [4]func(string, *net.TCPConn){tune, nil, halfclose, nil}),
		recv:  make(chan *connrecv, recvcap),
		cplk:  &sync.Mutex{},
		sblk:  &sync.RWMutex{},
	}
	go c.accept()
	return c, nil
}
// connrecv is one received packet together with the remote endpoint and the
// TCP connection it arrived on; receive sends these over Conn.recv and
// ReadFromPeer consumes them.
type connrecv struct {
addr *EndPoint // cast from tcpconn.RemoteAddr()
// conn is the connection the packet was read from.
conn *net.TCPConn
// pckt is the decoded packet (filled by io.Copy in receive).
pckt packet
}
// subconn is an additional ("sub") TCP connection kept in Conn.subs.
type subconn struct {
// cplk is taken with TryLock by writeToPeer to claim this connection for one write.
cplk sync.Mutex
last time.Time // last active time
// conn is the underlying TCP connection.
conn *net.TCPConn
}
// Conn is a stateful connection disguised as a stateless one.
type Conn struct {
// addr is the local endpoint this Conn listens on; Close sets it to nil.
addr *EndPoint
// lstn is the TCP listener served by accept; Close sets it to nil.
lstn *net.TCPListener
// peers caches main TCP connections keyed by the endpoint's string form; Close sets it to nil.
peers *ttl.Cache[string, *net.TCPConn]
// recv carries received packets from receive to ReadFromPeer; Close closes it and sets it to nil.
recv chan *connrecv
// cplk is held by WriteToPeer while writing through the main connection path.
cplk *sync.Mutex
// sblk is locked around accesses to subs.
sblk *sync.RWMutex
// subs holds the sub connections (accepted or dialed).
subs []*subconn
// suberr is set by WriteToPeer once a sub write has failed.
suberr bool
}
// accept runs the accept loop of the listener, starting one receive
// goroutine per accepted connection.
//
// The loop ends when the listener is closed normally (net.ErrClosed) or the
// Conn has been closed. On any other accept error it closes the Conn,
// listens again on the same endpoint and replaces *conn with the new Conn.
func (conn *Conn) accept() {
	for {
		tcpconn, err := conn.lstn.AcceptTCP()
		if err != nil {
			if errors.Is(err, net.ErrClosed) { // normal close
				logrus.Infoln("[tcp] accept of", conn.addr, "got closed")
				return
			}
			// Keep our own reference to the endpoint: Close below sets
			// conn.addr to nil, and re-listening through conn.addr afterwards
			// would dereference that nil pointer.
			addr := conn.addr
			if addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
				return
			}
			logrus.Warnln("[tcp] accept on", addr, "err:", err)
			_ = conn.Close()
			newc, err := addr.Listen()
			if err != nil {
				logrus.Warnln("[tcp] re-listen on", addr, "err:", err)
				return
			}
			// NOTE(review): Listen has already started an accept goroutine
			// on newc, so after this copy two loops accept from the same
			// listener — confirm this is intended.
			*conn = *newc.(*Conn)
			logrus.Infoln("[tcp] re-listen on", conn.addr)
			continue
		}
		go conn.receive(tcpconn, false)
	}
}
// delsubs half-closes (CloseWrite) the connection of subs[i] and returns
// subs with that element removed. Removing a middle element shifts the tail
// in place, so the caller's backing array is modified.
func delsubs(i int, subs []*subconn) []*subconn {
	victim := subs[i].conn
	closeErr := victim.CloseWrite()
	switch {
	case !config.ShowDebugLog:
		// debug logging is off
	case closeErr != nil:
		logrus.Debugln("[tcp] close sub write from", victim.LocalAddr(), "to", victim.RemoteAddr(), "err:", closeErr)
	default:
		logrus.Debugln("[tcp] close sub write from", victim.LocalAddr(), "to", victim.RemoteAddr())
	}

	last := len(subs) - 1
	if i == 0 {
		return subs[1:]
	}
	if i == last {
		return subs[:last]
	}
	return append(subs[:i], subs[i+1:]...)
}
// receive reads packets from tcpconn until the connection fails, times out
// or the Conn is closed, and dispatches every packet over conn.recv.
//
// hasvalidated is false for accepted connections: they are first checked by
// isvalid, which also tells whether the peer opened a main or a sub
// connection, and are then registered in conn.peers or conn.subs. Dialed
// connections (hasvalidated true) were registered by the caller already.
//
// NOTE(review): with hasvalidated true, issub stays false, so a dialed sub
// connection is handled like a main connection here (it reads from whatever
// conn.peers holds for ep) — confirm this is intended.
// NOTE(review): when isvalid rejects the connection it is not closed here;
// presumably isvalid does that itself — verify.
func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
	if conn.peers == nil {
		return
	}
	// The second result is ignored as in the original code; the input is the
	// remote address of an established connection.
	ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
		DialTimeout:        conn.addr.dialtimeout,
		PeersTimeout:       conn.addr.peerstimeout,
		KeepInterval:       conn.addr.keepinterval,
		ReceiveChannelSize: conn.addr.recvchansize,
	})

	issub, ok := false, false

	// The read timeout is twice the (clamped) peers timeout.
	peerstimeout := conn.addr.peerstimeout
	if peerstimeout < time.Second*30 {
		peerstimeout = time.Second * 30
	}
	peerstimeout *= 2

	if !hasvalidated {
		issub, ok = isvalid(tcpconn, peerstimeout)
		if !ok {
			return
		}
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] accept from", ep, "issub:", issub)
		}
		if issub {
			conn.sblk.Lock()
			conn.subs = append(conn.subs, &subconn{conn: tcpconn, last: time.Now()})
			conn.sblk.Unlock()
		} else {
			if conn.peers == nil {
				return
			}
			conn.peers.Set(ep.String(), tcpconn)
		}
	}

	// Unregister the connection when this goroutine ends.
	if issub {
		defer func() {
			conn.sblk.Lock()
			subs := conn.subs
			for i, sub := range subs {
				if sub.conn == tcpconn {
					conn.subs = delsubs(i, conn.subs)
					break
				}
			}
			conn.sblk.Unlock()
		}()
	} else {
		if conn.peers == nil {
			return
		}
		defer conn.peers.Delete(ep.String())
	}

	go conn.keep(ep)

	for {
		r := &connrecv{addr: ep}
		if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
			return
		}
		if !issub {
			// Always read from the connection currently cached for ep.
			tcpconn = conn.peers.Get(ep.String())
			if tcpconn == nil {
				return
			}
		}
		r.conn = tcpconn

		t := time.NewTimer(peerstimeout)
		var err error
		// Buffered with capacity 1 so the reader goroutine can always deliver
		// its completion signal and exit, even after the timeout branch below
		// has returned and nobody receives from the channel any more. With an
		// unbuffered channel that goroutine would block forever.
		copych := make(chan struct{}, 1)
		go func() {
			_, err = io.Copy(&r.pckt, tcpconn)
			copych <- struct{}{}
		}()
		select {
		case <-t.C:
			if config.ShowDebugLog {
				logrus.Debugln("[tcp] recv from", ep, "timeout")
			}
			_ = tcpconn.CloseRead()
			return
		case <-copych:
			t.Stop()
		}

		if conn.addr == nil || conn.lstn == nil || conn.peers == nil || conn.recv == nil {
			return
		}
		if err != nil {
			if config.ShowDebugLog {
				logrus.Debugln("[tcp] recv from", ep, "err:", err)
			}
			// These errors end the stream; any other error is retried.
			if errors.Is(err, net.ErrClosed) ||
				errors.Is(err, io.ErrClosedPipe) ||
				errors.Is(err, io.EOF) ||
				errors.Is(err, ErrInvalidMagic) {
				_ = tcpconn.CloseRead()
				return
			}
			continue
		}
		if r.pckt.typ >= packetTypeTop {
			if config.ShowDebugLog {
				logrus.Debugln("[tcp] close reading invalid conn from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
			}
			_ = tcpconn.CloseRead()
			return
		}
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
		}
		conn.recv <- r
	}
}
// keep periodically sends keep-alive packets: one on the main connection
// cached for ep, and one on every sub connection that has been idle for at
// least one interval. The interval is ep.keepinterval, raised to at least 10s.
//
// A main connection whose keep-alive fails is deleted from the peers cache;
// at most one failing sub connection is removed per tick. The loop ends once
// the Conn is observed closed (conn.addr or conn.peers is nil).
func (conn *Conn) keep(ep *EndPoint) {
	keepinterval := ep.keepinterval
	if keepinterval < time.Second*10 {
		keepinterval = time.Second * 10
	}
	t := time.NewTicker(keepinterval)
	defer t.Stop()
	for range t.C {
		// Snapshot the cache so a concurrent Close cannot turn it into nil
		// between the check and the calls below.
		peers := conn.peers
		if conn.addr == nil || peers == nil {
			return
		}
		tcpconn := peers.Get(ep.String())
		if tcpconn != nil {
			_, err := io.Copy(tcpconn, &packet{typ: packetTypeKeepAlive})
			if conn.addr == nil {
				return
			}
			if err != nil {
				logrus.Warnln("[tcp] keep main conn alive from", conn, "to", ep, "err:", err)
				peers.Delete(ep.String())
			} else if config.ShowDebugLog {
				logrus.Debugln("[tcp] keep main conn alive from", conn, "to", ep)
			}
		}

		// Probe the sub connections under the read lock only. The lock is
		// always released before returning, and a failing connection is
		// merely remembered here: conn.subs must not be modified while only
		// the read lock is held.
		var dead *subconn
		closed := false
		conn.sblk.RLock()
		for _, sub := range conn.subs {
			if time.Since(sub.last) < keepinterval {
				if config.ShowDebugLog {
					logrus.Debugln("[tcp] skip to keep busy sub conn from", conn, "to", ep)
				}
				continue
			}
			_, err := io.Copy(sub.conn, &packet{typ: packetTypeSubKeepAlive})
			if conn.addr == nil {
				closed = true
				break
			}
			if err != nil {
				logrus.Warnln("[tcp] keep sub conn alive from", conn, "to", sub.conn.RemoteAddr(), "err:", err)
				dead = sub // del 1 link at once
				break
			}
			if config.ShowDebugLog {
				logrus.Debugln("[tcp] keep sub conn alive from", conn, "to", ep)
			}
		}
		conn.sblk.RUnlock()
		if closed {
			return
		}
		if dead != nil {
			// Remove under the write lock, looking the entry up again by
			// identity because the slice may have changed in between.
			conn.sblk.Lock()
			for i, sub := range conn.subs {
				if sub == dead {
					conn.subs = delsubs(i, conn.subs)
					break
				}
			}
			conn.sblk.Unlock()
		}
	}
}
// Close shuts the Conn down: the addr, lstn, peers and recv fields are set
// to nil first (other methods treat nil fields as "closed"), then the
// listener is closed, the peers cache destroyed and the receive channel
// closed. Sub connections in conn.subs are not touched here. It always
// returns nil, and calling it again is a no-op.
func (conn *Conn) Close() error {
	listener, cache, queue := conn.lstn, conn.peers, conn.recv
	conn.addr, conn.lstn, conn.peers, conn.recv = nil, nil, nil, nil

	if listener != nil {
		_ = listener.Close()
	}
	if cache != nil {
		cache.Destroy()
	}
	if queue != nil {
		close(queue)
	}
	return nil
}
// String returns the textual form of the local endpoint.
//
// Close sets conn.addr to nil while keep may still be formatting the Conn
// for a log line, so a nil Conn or nil address yields "<nil>" instead of
// panicking.
func (conn *Conn) String() string {
	if conn == nil {
		return "<nil>"
	}
	addr := conn.addr // single read: Close may set the field to nil concurrently
	if addr == nil {
		return "<nil>"
	}
	return addr.String()
}
// LocalAddr returns the endpoint stored in conn.addr as a p2p.EndPoint.
// After Close the stored pointer is nil, so the returned interface then
// wraps a nil *EndPoint.
func (conn *Conn) LocalAddr() p2p.EndPoint {
	var local p2p.EndPoint = conn.addr
	return local
}
// ReadFromPeer blocks until a normal data packet arrives, copies its payload
// into b and returns the number of bytes copied together with the sender's
// endpoint. Payload beyond len(b) is dropped.
//
// Every packet taken from the queue, whatever its type, refreshes the peers
// cache entry of its sender; packets that are not packetTypeNormal are
// otherwise skipped. net.ErrClosed is returned once the Conn is closed.
func (conn *Conn) ReadFromPeer(b []byte) (int, p2p.EndPoint, error) {
	for {
		// Close sets conn.recv to nil, and receiving from a nil channel
		// blocks forever, so look at a snapshot first.
		recv := conn.recv
		if recv == nil {
			return 0, nil, net.ErrClosed
		}
		p := <-recv
		if p == nil { // channel closed by Close
			return 0, nil, net.ErrClosed
		}
		// Use a snapshot so the cache cannot become nil between check and use.
		peers := conn.peers
		if peers == nil {
			return 0, nil, net.ErrClosed
		}
		peers.Set(p.addr.String(), p.conn)
		if p.pckt.typ == packetTypeNormal {
			return copy(b, p.pckt.dat), p.addr, nil
		}
	}
}
// writeToPeer sends b to tcpep as one normal packet. The caller must hold
// conn.cplk when issub is false.
//
// With issub false the main connection cached for tcpep is used; with issub
// true the first sub connection whose lock can be taken is used. If no
// connection is available a new one is dialed, introduced to the peer with
// a (sub) keep-alive packet, registered, and given its own receive
// goroutine. A failed write drops the connection that was used and is
// retried once on a fresh connection.
//
// It returns the number of payload bytes written (the bytes written minus
// the 3 bytes the original code attributes to the packet header, never
// negative) and the error of the last attempt. net.ErrClosed is returned
// when the Conn has been closed.
func (conn *Conn) writeToPeer(b []byte, tcpep *EndPoint, issub bool) (n int, err error) {
	// Snapshot the cache: Close sets conn.peers to nil.
	peers := conn.peers
	if peers == nil {
		return 0, net.ErrClosed
	}
	retried := false
	ok := false
	var (
		tcpconn *net.TCPConn
		subc    *subconn // sub connection claimed (locked) for this write
	)
RECONNECT:
	if issub {
		conn.sblk.RLock()
		for _, sub := range conn.subs {
			if sub.cplk.TryLock() {
				tcpconn = sub.conn
				subc = sub
				break
			}
		}
		conn.sblk.RUnlock()
	} else {
		tcpconn = peers.Get(tcpep.String())
	}
	if tcpconn == nil {
		dialtimeout := tcpep.dialtimeout
		if dialtimeout < time.Second {
			dialtimeout = time.Second
		}
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] dial to", tcpep.addr, "timeout", dialtimeout, "issub", issub)
		}
		var cn net.Conn
		// must use another port to send because there's no existing conn
		cn, err = net.DialTimeout(tcpep.Network(), tcpep.addr.String(), dialtimeout)
		if err != nil {
			return 0, err
		}
		tcpconn, ok = cn.(*net.TCPConn)
		if !ok {
			_ = cn.Close() // do not leak the dialed socket
			return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
		}
		// Tell the peer which kind of connection this is.
		pkt := &packet{}
		if issub {
			pkt.typ = packetTypeSubKeepAlive
		} else {
			pkt.typ = packetTypeKeepAlive
		}
		_, err = io.Copy(tcpconn, pkt)
		if err != nil {
			if config.ShowDebugLog {
				logrus.Debugln("[tcp] dial to", tcpep.addr, "issub", issub, "success, but write err:", err)
			}
			_ = tcpconn.Close() // never registered, so nobody else will close it
			return 0, err
		}
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr(), "issub", issub)
		}
		if !issub {
			peers.Set(tcpep.String(), tcpconn)
		} else {
			// Claim the new sub connection before publishing it, so that no
			// other writer can pick it up while our write is in flight, and
			// so that the error path below drops this connection rather than
			// the peer's main connection.
			subc = &subconn{conn: tcpconn, last: time.Now()}
			subc.cplk.Lock()
			conn.sblk.Lock()
			conn.subs = append(conn.subs, subc)
			conn.sblk.Unlock()
		}
		go conn.receive(tcpconn, true)
	} else if config.ShowDebugLog {
		logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
	}
	cnt, err := io.Copy(tcpconn, &packet{
		typ: packetTypeNormal,
		len: uint16(len(b)),
		dat: b,
	})
	if err != nil {
		if subc == nil {
			peers.Delete(tcpep.String())
		} else {
			conn.sblk.Lock()
			for i, sub := range conn.subs {
				if sub == subc {
					conn.subs = delsubs(i, conn.subs)
					break
				}
			}
			conn.sblk.Unlock()
			// Release and forget the dropped connection so the retry starts clean.
			subc.cplk.Unlock()
			subc = nil
		}
		if !retried {
			logrus.Warnln("[tcp] reconnect due to write to", tcpconn.RemoteAddr(), "err:", err)
			retried = true
			tcpconn = nil
			goto RECONNECT
		}
	}
	if subc != nil {
		subc.last = time.Now()
		subc.cplk.Unlock()
	}
	n = int(cnt) - 3
	if n < 0 {
		// Nothing (or only part of the header) was written.
		n = 0
	}
	return n, err
}
// WriteToPeer sends b to the peer at ep and returns the number of payload
// bytes written.
//
// ep must be a *EndPoint, otherwise p2p.ErrEndpointTypeMistatch is returned;
// payloads of 65536 bytes or more are rejected. If the main write lock is
// free the data goes through the main connection. If it is busy, a sub
// connection is tried first — unless an earlier sub write failed and no sub
// connection exists — and only when that fails does the call wait for the
// main lock.
func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
	tcpep, ok := ep.(*EndPoint)
	if !ok {
		return 0, p2p.ErrEndpointTypeMistatch
	}
	if len(b) >= 65536 {
		return 0, errors.New("data size " + strconv.Itoa(len(b)) + " is too large")
	}
	if !conn.cplk.TryLock() {
		// conn.subs is guarded by sblk everywhere else; read it the same way.
		conn.sblk.RLock()
		hassubs := len(conn.subs) > 0
		conn.sblk.RUnlock()
		if !conn.suberr || hassubs {
			if config.ShowDebugLog {
				// Debugln, like every other log call in this file: Debug
				// would print the endpoint glued to the message.
				logrus.Debugln("[tcp] try sub write to", tcpep)
			}
			n, err = conn.writeToPeer(b, tcpep, true) // try sub write
			if err == nil {
				return
			}
			conn.suberr = true // fast fail
		}
		conn.cplk.Lock() // add to main queue
	}
	defer conn.cplk.Unlock()
	return conn.writeToPeer(b, tcpep, false)
}