mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-04 23:40:26 +08:00
fix(tunnel): add seq to prevent order mismatch
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
type Tunnel struct {
|
||||
l *link.Link
|
||||
in chan []byte
|
||||
out chan []byte
|
||||
out chan *head.Packet
|
||||
outcache []byte
|
||||
peerip net.IP
|
||||
src uint16
|
||||
@@ -26,7 +27,7 @@ func Create(me *link.Me, peer string) (s Tunnel, err error) {
|
||||
s.l, err = me.Connect(peer)
|
||||
if err == nil {
|
||||
s.in = make(chan []byte, 4)
|
||||
s.out = make(chan []byte, 4)
|
||||
s.out = make(chan *head.Packet, 4)
|
||||
s.peerip = net.ParseIP(peer)
|
||||
} else {
|
||||
logrus.Errorln("[tunnel] create err:", err)
|
||||
@@ -62,7 +63,16 @@ func (s *Tunnel) Read(p []byte) (int, error) {
|
||||
if s.outcache != nil {
|
||||
d = s.outcache
|
||||
} else {
|
||||
d = <-s.out
|
||||
pkt := <-s.out
|
||||
if pkt == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
defer pkt.Put()
|
||||
if len(pkt.Data) < 4 {
|
||||
logrus.Warnln("[tunnel] unexpected packet data len", len(pkt.Data), "content", pkt.Data)
|
||||
return 0, io.EOF
|
||||
}
|
||||
d = pkt.Data[4:]
|
||||
}
|
||||
if d != nil {
|
||||
if len(p) >= len(d) {
|
||||
@@ -79,9 +89,12 @@ func (s *Tunnel) Read(p []byte) (int, error) {
|
||||
func (s *Tunnel) Stop() {
|
||||
s.l.Close()
|
||||
close(s.in)
|
||||
close(s.out)
|
||||
}
|
||||
|
||||
func (s *Tunnel) handleWrite() {
|
||||
seq := uint32(0)
|
||||
buf := make([]byte, s.mtu)
|
||||
for b := range s.in {
|
||||
end := 64
|
||||
endl := "..."
|
||||
@@ -95,27 +108,45 @@ func (s *Tunnel) handleWrite() {
|
||||
break
|
||||
}
|
||||
logrus.Debugln("[tunnel] writing", len(b), "bytes...")
|
||||
for len(b) > int(s.mtu) {
|
||||
logrus.Infoln("[tunnel] split buffer")
|
||||
_, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false)
|
||||
for len(b) > int(s.mtu)-4 {
|
||||
logrus.Infoln("[tunnel] seq", seq, "split buffer")
|
||||
binary.LittleEndian.PutUint32(buf[:4], seq)
|
||||
seq++
|
||||
copy(buf[4:], b[:s.mtu-4])
|
||||
_, err := s.l.WriteAndPut(
|
||||
head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf), false,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.Errorln("[tunnel] write err:", err)
|
||||
logrus.Errorln("[tunnel] seq", seq-1, "write err:", err)
|
||||
return
|
||||
}
|
||||
logrus.Debugln("[tunnel] write succeeded")
|
||||
b = b[s.mtu:]
|
||||
logrus.Debugln("[tunnel] seq", seq-1, "write succeeded")
|
||||
b = b[s.mtu-4:]
|
||||
}
|
||||
_, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false)
|
||||
binary.LittleEndian.PutUint32(buf[:4], seq)
|
||||
seq++
|
||||
copy(buf[4:], b)
|
||||
_, err := s.l.WriteAndPut(
|
||||
head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf[:len(b)+4]), false,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.Errorln("[tunnel] write err:", err)
|
||||
logrus.Errorln("[tunnel] seq", seq-1, "write err:", err)
|
||||
break
|
||||
}
|
||||
logrus.Debugln("[tunnel] write succeeded")
|
||||
logrus.Debugln("[tunnel] seq", seq-1, "write succeeded")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Tunnel) handleRead() {
|
||||
seq := uint32(0)
|
||||
seqmap := make(map[uint32]*head.Packet)
|
||||
for {
|
||||
if p, ok := seqmap[seq]; ok {
|
||||
logrus.Debugln("[tunnel] dispatch cached seq", seq)
|
||||
delete(seqmap, seq)
|
||||
seq++
|
||||
s.out <- p
|
||||
}
|
||||
p := s.l.Read()
|
||||
if p == nil {
|
||||
logrus.Errorln("[tunnel] read recv nil")
|
||||
@@ -128,7 +159,14 @@ func (s *Tunnel) handleRead() {
|
||||
endl = "."
|
||||
}
|
||||
logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Data[:end]), endl)
|
||||
s.out <- p.Data
|
||||
p.Put()
|
||||
recvseq := binary.LittleEndian.Uint32(p.Data[:4])
|
||||
if recvseq == seq {
|
||||
logrus.Debugln("[tunnel] dispatch seq", seq)
|
||||
seq++
|
||||
s.out <- p
|
||||
continue
|
||||
}
|
||||
seqmap[recvseq] = p
|
||||
logrus.Debugln("[tunnel] cache seq", recvseq)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
@@ -92,26 +93,47 @@ func TestTunnel(t *testing.T) {
|
||||
rand.Read(sendb)
|
||||
tunnme.Write(sendb)
|
||||
buf = make([]byte, 4096)
|
||||
tunnpeer.Read(buf)
|
||||
_, err = io.ReadFull(&tunnpeer, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(sendb) != string(buf) {
|
||||
t.Fatal("error: recv 4096 bytes data")
|
||||
}
|
||||
|
||||
sendb = make([]byte, 65535)
|
||||
rand.Read(sendb)
|
||||
n, _ := tunnme.Write(sendb)
|
||||
t.Log("write", n, "bytes")
|
||||
buf = make([]byte, 65535)
|
||||
n, _ = io.ReadFull(&tunnpeer, buf)
|
||||
t.Log("read", n, "bytes")
|
||||
if string(sendb) != string(buf) {
|
||||
t.Fatal("error: recv 65535 bytes data")
|
||||
t.Log("expect", hex.EncodeToString(sendb))
|
||||
t.Log("got", hex.EncodeToString(buf))
|
||||
for i := 0; i < 32; i++ {
|
||||
rand.Read(sendb)
|
||||
n, _ := tunnme.Write(sendb)
|
||||
t.Log("loop", i, "write", n, "bytes")
|
||||
n, err = io.ReadFull(&tunnpeer, buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("loop", i, "read", n, "bytes")
|
||||
if string(sendb) != string(buf) {
|
||||
t.Fatal("loop", i, "error: recv 65535 bytes data")
|
||||
}
|
||||
}
|
||||
|
||||
tunnme.Stop()
|
||||
tunnpeer.Stop()
|
||||
rand.Read(sendb)
|
||||
tunnme.Write(sendb)
|
||||
rd := bytes.NewBuffer(nil)
|
||||
|
||||
tm := time.AfterFunc(time.Second*5, func() {
|
||||
tunnme.Stop()
|
||||
tunnpeer.Stop()
|
||||
})
|
||||
defer tm.Stop()
|
||||
|
||||
_, err = io.CopyBuffer(rd, &tunnpeer, make([]byte, 200))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(sendb) != rd.String() {
|
||||
t.Fatal("error: recv fragmented 4096 bytes data")
|
||||
}
|
||||
}
|
||||
|
||||
// logFormat specialize for go-cqhttp
|
||||
|
||||
Reference in New Issue
Block a user