1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-04 23:40:26 +08:00

feat(tcp): validate conn on accept

This commit is contained in:
源文雨
2024-07-17 00:39:29 +09:00
parent 8fa23be251
commit 4ffacafb23
2 changed files with 67 additions and 15 deletions

View File

@@ -5,8 +5,10 @@ import (
"errors"
"io"
"net"
"time"
"github.com/fumiama/WireGold/helper"
"github.com/sirupsen/logrus"
)
var (
@@ -86,3 +88,39 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
defer cl()
return io.Copy(w, &buf)
}
func isvalid(tcpconn *net.TCPConn) bool {
pckt := packet{}
stopch := make(chan struct{})
t := time.AfterFunc(time.Second, func() {
stopch <- struct{}{}
})
var err error
copych := make(chan struct{})
go func() {
_, err = io.Copy(&pckt, tcpconn)
copych <- struct{}{}
}()
select {
case <-stopch:
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
return false
case <-copych:
t.Stop()
}
if err != nil {
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
return false
}
if pckt.typ != packetTypeKeepAlive {
logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
return false
}
logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
return true
}

View File

@@ -50,8 +50,8 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) {
}
ep.addr = lstn.Addr().(*net.TCPAddr)
peerstimeout := ep.peerstimeout
if peerstimeout < time.Second {
peerstimeout = time.Second * 5
if peerstimeout < time.Second*30 {
peerstimeout = time.Second * 30
}
chansz := ep.recvchansize
if chansz < 32 {
@@ -112,21 +112,28 @@ func (conn *Conn) accept() {
logrus.Info("[tcp] re-listen on", conn.addr)
continue
}
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
DialTimeout: conn.addr.dialtimeout,
PeersTimeout: conn.addr.peerstimeout,
ReceiveChannelSize: conn.addr.recvchansize,
})
logrus.Debugln("[tcp] accept from", ep)
conn.peers.Set(ep.String(), tcpconn)
go conn.receive(ep)
go conn.receive(tcpconn, false)
}
}
func (conn *Conn) receive(ep *EndPoint) {
peerstimeout := ep.peerstimeout
if peerstimeout < time.Second {
peerstimeout = time.Second * 5
func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
DialTimeout: conn.addr.dialtimeout,
PeersTimeout: conn.addr.peerstimeout,
ReceiveChannelSize: conn.addr.recvchansize,
})
if !hasvalidated {
if !isvalid(tcpconn) {
return
}
logrus.Debugln("[tcp] accept from", ep)
conn.peers.Set(ep.String(), tcpconn)
}
peerstimeout := conn.addr.peerstimeout
if peerstimeout < time.Second*30 {
peerstimeout = time.Second * 30
}
peerstimeout *= 2
for {
@@ -244,9 +251,16 @@ func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
if !ok {
return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
}
_, err = io.Copy(tcpconn, &packet{
typ: packetTypeKeepAlive,
})
if err != nil {
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err)
return 0, err
}
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr())
conn.peers.Set(tcpep.String(), tcpconn)
go conn.receive(tcpep)
go conn.receive(tcpconn, true)
} else {
logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
}