diff --git a/gold/p2p/tcp/pdu.go b/gold/p2p/tcp/pdu.go index 25f5488..fa4d23f 100644 --- a/gold/p2p/tcp/pdu.go +++ b/gold/p2p/tcp/pdu.go @@ -2,19 +2,33 @@ package tcp import ( "encoding/binary" + "errors" "io" "net" "github.com/fumiama/WireGold/helper" ) +var ( + ErrInvalidMagic = errors.New("invalid magic") +) + type packetType uint8 const ( packetTypeKeepAlive packetType = iota packetTypeNormal + packetTypeTop ) +const magic = 0x12d3fde9 + +var magicbuf [4]byte + +func init() { + binary.LittleEndian.PutUint32(magicbuf[:], magic) +} + type packet struct { typ packetType len uint16 @@ -28,7 +42,7 @@ func (p *packet) pack() (net.Buffers, func()) { w.WriteByte(byte(p.typ)) w.WriteUInt16(p.len) }) - return net.Buffers{d, p.dat}, cl + return net.Buffers{magicbuf[:], d, p.dat}, cl } func (p *packet) Read(_ []byte) (int, error) { @@ -40,12 +54,21 @@ func (p *packet) Write(_ []byte) (int, error) { } func (p *packet) ReadFrom(r io.Reader) (n int64, err error) { - var buf [3]byte + var buf [4]byte cnt, err := io.ReadFull(r, buf[:]) n = int64(cnt) if err != nil { return } + if binary.LittleEndian.Uint32(buf[:]) != magic { + err = ErrInvalidMagic + return + } + cnt, err = io.ReadFull(r, buf[:3]) + n += int64(cnt) + if err != nil { + return + } p.typ = packetType(buf[0]) p.len = binary.LittleEndian.Uint16(buf[1:3]) w := helper.SelectWriter() diff --git a/gold/p2p/tcp/tcp.go b/gold/p2p/tcp/tcp.go index 22f7c54..08f0691 100644 --- a/gold/p2p/tcp/tcp.go +++ b/gold/p2p/tcp/tcp.go @@ -163,6 +163,12 @@ func (conn *Conn) receive(ep *EndPoint) { if err != nil { logrus.Debugln("[tcp] recv from", ep, "err:", err) + _ = tcpconn.CloseRead() + return + } + if r.pckt.typ >= packetTypeTop { + logrus.Debugln("[tcp] close reading invalid conn from", ep, "typ", r.pckt.typ, "len", r.pckt.len) + _ = tcpconn.CloseRead() return } logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)