mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-12 12:50:28 +08:00
feat(tcp): validate conn on accept
This commit is contained in:
@@ -5,8 +5,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/helper"
|
"github.com/fumiama/WireGold/helper"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -86,3 +88,39 @@ func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
|
|||||||
defer cl()
|
defer cl()
|
||||||
return io.Copy(w, &buf)
|
return io.Copy(w, &buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isvalid(tcpconn *net.TCPConn) bool {
|
||||||
|
pckt := packet{}
|
||||||
|
|
||||||
|
stopch := make(chan struct{})
|
||||||
|
t := time.AfterFunc(time.Second, func() {
|
||||||
|
stopch <- struct{}{}
|
||||||
|
})
|
||||||
|
|
||||||
|
var err error
|
||||||
|
copych := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_, err = io.Copy(&pckt, tcpconn)
|
||||||
|
copych <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stopch:
|
||||||
|
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
|
||||||
|
return false
|
||||||
|
case <-copych:
|
||||||
|
t.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if pckt.typ != packetTypeKeepAlive {
|
||||||
|
logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ func (ep *EndPoint) Listen() (p2p.Conn, error) {
|
|||||||
}
|
}
|
||||||
ep.addr = lstn.Addr().(*net.TCPAddr)
|
ep.addr = lstn.Addr().(*net.TCPAddr)
|
||||||
peerstimeout := ep.peerstimeout
|
peerstimeout := ep.peerstimeout
|
||||||
if peerstimeout < time.Second {
|
if peerstimeout < time.Second*30 {
|
||||||
peerstimeout = time.Second * 5
|
peerstimeout = time.Second * 30
|
||||||
}
|
}
|
||||||
chansz := ep.recvchansize
|
chansz := ep.recvchansize
|
||||||
if chansz < 32 {
|
if chansz < 32 {
|
||||||
@@ -112,21 +112,28 @@ func (conn *Conn) accept() {
|
|||||||
logrus.Info("[tcp] re-listen on", conn.addr)
|
logrus.Info("[tcp] re-listen on", conn.addr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
go conn.receive(tcpconn, false)
|
||||||
DialTimeout: conn.addr.dialtimeout,
|
|
||||||
PeersTimeout: conn.addr.peerstimeout,
|
|
||||||
ReceiveChannelSize: conn.addr.recvchansize,
|
|
||||||
})
|
|
||||||
logrus.Debugln("[tcp] accept from", ep)
|
|
||||||
conn.peers.Set(ep.String(), tcpconn)
|
|
||||||
go conn.receive(ep)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) receive(ep *EndPoint) {
|
func (conn *Conn) receive(tcpconn *net.TCPConn, hasvalidated bool) {
|
||||||
peerstimeout := ep.peerstimeout
|
ep, _ := newEndpoint(tcpconn.RemoteAddr().String(), &Config{
|
||||||
if peerstimeout < time.Second {
|
DialTimeout: conn.addr.dialtimeout,
|
||||||
peerstimeout = time.Second * 5
|
PeersTimeout: conn.addr.peerstimeout,
|
||||||
|
ReceiveChannelSize: conn.addr.recvchansize,
|
||||||
|
})
|
||||||
|
|
||||||
|
if !hasvalidated {
|
||||||
|
if !isvalid(tcpconn) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logrus.Debugln("[tcp] accept from", ep)
|
||||||
|
conn.peers.Set(ep.String(), tcpconn)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerstimeout := conn.addr.peerstimeout
|
||||||
|
if peerstimeout < time.Second*30 {
|
||||||
|
peerstimeout = time.Second * 30
|
||||||
}
|
}
|
||||||
peerstimeout *= 2
|
peerstimeout *= 2
|
||||||
for {
|
for {
|
||||||
@@ -244,9 +251,16 @@ func (conn *Conn) WriteToPeer(b []byte, ep p2p.EndPoint) (n int, err error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
|
return 0, errors.New("expect *net.TCPConn but got " + reflect.ValueOf(cn).Type().String())
|
||||||
}
|
}
|
||||||
|
_, err = io.Copy(tcpconn, &packet{
|
||||||
|
typ: packetTypeKeepAlive,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, but write err:", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr())
|
logrus.Debugln("[tcp] dial to", tcpep.addr, "success, local:", tcpconn.LocalAddr())
|
||||||
conn.peers.Set(tcpep.String(), tcpconn)
|
conn.peers.Set(tcpep.String(), tcpconn)
|
||||||
go conn.receive(tcpep)
|
go conn.receive(tcpconn, true)
|
||||||
} else {
|
} else {
|
||||||
logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
|
logrus.Debugln("[tcp] reuse tcpconn from", tcpconn.LocalAddr(), "to", tcpconn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user