package tcp import ( "encoding/binary" "errors" "io" "net" "time" "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/internal/bin" "github.com/sirupsen/logrus" ) var ( ErrInvalidMagic = errors.New("invalid magic") ) type packetType uint8 const ( packetTypeKeepAlive packetType = iota packetTypeNormal packetTypeSubKeepAlive packetTypeTop ) var ( magicbuf = []byte("GET ") magic = binary.LittleEndian.Uint32(magicbuf) ) type packet struct { typ packetType len uint16 dat []byte io.ReaderFrom io.WriterTo } func (p *packet) pack() *net.Buffers { return &net.Buffers{magicbuf, bin.NewWriterF(func(w *bin.Writer) { w.WriteByte(byte(p.typ)) w.WriteUInt16(p.len) }).Trans(), p.dat} } func (p *packet) Read(_ []byte) (int, error) { panic("stub") } func (p *packet) Write(_ []byte) (int, error) { panic("stub") } func (p *packet) ReadFrom(r io.Reader) (n int64, err error) { var buf [4]byte cnt, err := io.ReadFull(r, buf[:]) n = int64(cnt) if err != nil { return } if binary.LittleEndian.Uint32(buf[:]) != magic { err = ErrInvalidMagic if config.ShowDebugLog { logrus.Debugf("[tcp] expect magic %08x but got %08x", magic, binary.LittleEndian.Uint32(buf[:])) } return } cnt, err = io.ReadFull(r, buf[:3]) n += int64(cnt) if err != nil { return } p.typ = packetType(buf[0]) p.len = binary.LittleEndian.Uint16(buf[1:3]) w := bin.SelectWriter() copied, err := io.CopyN(w, r, int64(p.len)) n += copied if err != nil { return } p.dat = w.ToBytes().Trans() return } func (p *packet) WriteTo(w io.Writer) (n int64, err error) { return io.Copy(w, p.pack()) } func isvalid(tcpconn *net.TCPConn, timeout time.Duration) (issub, ok bool) { pckt := packet{} stopch := make(chan struct{}) t := time.AfterFunc(timeout, func() { stopch <- struct{}{} }) var err error copych := make(chan struct{}) go func() { _, err = io.Copy(&pckt, tcpconn) copych <- struct{}{} }() select { case <-stopch: if config.ShowDebugLog { logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout") } return case <-copych: t.Stop() } if err != nil { if config.ShowDebugLog { logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err) } return } if pckt.typ != packetTypeKeepAlive && pckt.typ != packetTypeSubKeepAlive { if config.ShowDebugLog { logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr()) } return } if config.ShowDebugLog { logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr()) } return pckt.typ == packetTypeSubKeepAlive, true }