mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-04 23:40:26 +08:00
feat(tcp): validate conn on accept
This commit is contained in:
@@ -5,8 +5,10 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -86,3 +88,39 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
|
||||
defer cl()
|
||||
return io.Copy(w, &buf)
|
||||
}
|
||||
|
||||
func isvalid(tcpconn *net.TCPConn) bool {
|
||||
pckt := packet{}
|
||||
|
||||
stopch := make(chan struct{})
|
||||
t := time.AfterFunc(time.Second, func() {
|
||||
stopch <- struct{}{}
|
||||
})
|
||||
|
||||
var err error
|
||||
copych := make(chan struct{})
|
||||
go func() {
|
||||
_, err = io.Copy(&pckt, tcpconn)
|
||||
copych <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopch:
|
||||
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
|
||||
return false
|
||||
case <-copych:
|
||||
t.Stop()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
|
||||
return false
|
||||
}
|
||||
if pckt.typ != packetTypeKeepAlive {
|
||||
logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -50,8 +50,8 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) {
|
||||
}
|
||||
ep.addr = lstn.Addr().(*net.TCPAddr)
|
||||
peerstimeout := ep.peerstimeout
|
||||
if peerstimeout < time.Second {
|
||||
peerstimeout = time.Second * 5
|
||||
if peerstimeout < time.Second*30 {
|
||||
peerstimeout = time.Second * 30
|
||||
}
|
||||
chansz := ep.recvchansize
|
||||
if chansz < 32 {
|
||||
@@ -112,21 +112,28 @@ func (conn *Conn) accept() {
|
||||
logrus.Info("[tcp] re-listen on", conn.addr)
|
||||
continue
|
||||
}
|
||||
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
||||
DialTimeout: conn.addr.dialtimeout,
|
||||
PeersTimeout: conn.addr.peerstimeout,
|
||||
ReceiveChannelSize: conn.addr.recvchansize,
|
||||
})
|
||||
logrus.Debugln("[tcp] accept from", ep)
|
||||
conn.peers.Set(ep.String(), tcpconn)
|
||||
go conn.receive(ep)
|
||||
go conn.receive(tcpconn, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) receive(ep *EndPoint) {
|
||||
peerstimeout := ep.peerstimeout
|
||||
if peerstimeout < time.Second {
|
||||
peerstimeout = time.Second * 5
|
||||
func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
|
||||
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
||||
DialTimeout: conn.addr.dialtimeout,
|
||||
PeersTimeout: conn.addr.peerstimeout,
|
||||
ReceiveChannelSize: conn.addr.recvchansize,
|
||||
})
|
||||
|
||||
if !hasvalidated {
|
||||
if !isvalid(tcpconn) {
|
||||
return
|
||||
}
|
||||
logrus.Debugln("[tcp] accept from", ep)
|
||||
conn.peers.Set(ep.String(), tcpconn)
|
||||
}
|
||||
|
||||
peerstimeout := conn.addr.peerstimeout
|
||||
if peerstimeout < time.Second*30 {
|
||||
peerstimeout = time.Second * 30
|
||||
}
|
||||
peerstimeout *= 2
|
||||
for {
|
||||
@@ -244,9 +251,16 @@ func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
|
||||
if !ok {
|
||||
return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
|
||||
}
|
||||
_, err = io.Copy(tcpconn, &packet{
|
||||
typ: packetTypeKeepAlive,
|
||||
})
|
||||
if err != nil {
|
||||
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err)
|
||||
return 0, err
|
||||
}
|
||||
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr())
|
||||
conn.peers.Set(tcpep.String(), tcpconn)
|
||||
go conn.receive(tcpep)
|
||||
go conn.receive(tcpconn, true)
|
||||
} else {
|
||||
logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user