1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-10 19:50:30 +08:00

fix(tcp): close invalid conn

This commit is contained in:
源文雨
2024-07-17 00:23:38 +09:00
parent 1c665c68fb
commit 8fa23be251
2 changed files with 31 additions and 2 deletions

View File

@@ -2,19 +2,33 @@ package tcp
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net" "net"
"github.com/fumiama/WireGold/helper" "github.com/fumiama/WireGold/helper"
) )
var (
ErrInvalidMagic = errors.New("invalid magic")
)
type packetType uint8 type packetType uint8
const ( const (
packetTypeKeepAlive packetType = iota packetTypeKeepAlive packetType = iota
packetTypeNormal packetTypeNormal
packetTypeTop
) )
const magic = 0x12d3fde9
var magicbuf [4]byte
func init() {
binary.LittleEndian.PutUint32(magicbuf[:], magic)
}
type packet struct { type packet struct {
typ packetType typ packetType
len uint16 len uint16
@@ -28,7 +42,7 @@ func (p *packet) pack() (net.Buffers, func()) {
w.WriteByte(byte(p.typ)) w.WriteByte(byte(p.typ))
w.WriteUInt16(p.len) w.WriteUInt16(p.len)
}) })
return net.Buffers{d, p.dat}, cl return net.Buffers{magicbuf[:], d, p.dat}, cl
} }
func (p *packet) Read(_ []byte) (int, error) { func (p *packet) Read(_ []byte) (int, error) {
@@ -40,12 +54,21 @@ func (p *packet) Write(_ []byte) (int, error) {
} }
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) { func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
var buf [3]byte var buf [4]byte
cnt, err := io.ReadFull(r, buf[:]) cnt, err := io.ReadFull(r, buf[:])
n = int64(cnt) n = int64(cnt)
if err != nil { if err != nil {
return return
} }
if binary.LittleEndian.Uint32(buf[:]) != magic {
err = ErrInvalidMagic
return
}
cnt, err = io.ReadFull(r, buf[:3])
n += int64(cnt)
if err != nil {
return
}
p.typ = packetType(buf[0]) p.typ = packetType(buf[0])
p.len = binary.LittleEndian.Uint16(buf[1:3]) p.len = binary.LittleEndian.Uint16(buf[1:3])
w := helper.SelectWriter() w := helper.SelectWriter()

View File

@@ -163,6 +163,12 @@ func (conn *Conn) receive(ep *EndPoint) {
if err != nil { if err != nil {
logrus.Debugln("[tcp] recv from", ep, "err:", err) logrus.Debugln("[tcp] recv from", ep, "err:", err)
_ = tcpconn.CloseRead()
return
}
if r.pckt.typ >= packetTypeTop {
logrus.Debugln("[tcp] close reading invalid conn from", ep, "typ", r.pckt.typ, "len", r.pckt.len)
_ = tcpconn.CloseRead()
return return
} }
logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len) logrus.Debugln("[tcp] dispatch packet from", ep, "typ", r.pckt.typ, "len", r.pckt.len)