diff --git a/gold/head/packet.go b/gold/head/packet.go index 5217603..53a96a7 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -6,6 +6,7 @@ import ( "net" "unsafe" + "github.com/fumiama/WireGold/helper" blake2b "github.com/minio/blake2b-simd" ) @@ -57,15 +58,21 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { err = errors.New("data len < 12") return } + if p.DataSZ == 0 && len(p.Data) == 0 { p.DataSZ = binary.LittleEndian.Uint32(data[:4]) - p.Data = make([]byte, p.DataSZ) + if int(p.DataSZ)+52 == len(data) { + p.Data = data[52:] + p.rembytes = 0 + } else { + p.Data = make([]byte, p.DataSZ) + p.rembytes = p.DataSZ + } pt := binary.LittleEndian.Uint16(data[4:6]) p.Proto = uint8(pt) p.TTL = uint8(pt >> 8) p.SrcPort = binary.LittleEndian.Uint16(data[6:8]) p.DstPort = binary.LittleEndian.Uint16(data[8:10]) - p.rembytes = p.DataSZ } flags := binary.LittleEndian.Uint16(data[10:12]) @@ -78,7 +85,10 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { copy(p.Dst, data[16:20]) copy(p.Hash[:], data[20:52]) } - p.rembytes -= uint32(copy(p.Data[flags<<3:], data[52:])) + + if p.rembytes > 0 { + p.rembytes -= uint32(copy(p.Data[flags<<3:], data[52:])) + } complete = p.rembytes == 0 @@ -87,10 +97,10 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { // Marshal 将自身数据编码为 []byte // offset 必须为 8 的倍数,表示偏移的 8 位 -func (p *Packet) Marshal(src net.IP, datasz uint32, offset uint16, dontfrag, hasmore bool) []byte { +func (p *Packet) Marshal(src net.IP, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) { p.TTL-- if p.TTL == 0 { - return nil + return nil, nil } if src != nil { @@ -105,20 +115,17 @@ func (p *Packet) Marshal(src net.IP, datasz uint32, offset uint16, dontfrag, has p.Flags = offset } - packet := make([]byte, 52+len(p.Data)) - binary.LittleEndian.PutUint32(packet[:4], p.DataSZ) - binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto)) - binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort) - binary.LittleEndian.PutUint16(packet[8:10], p.DstPort) - binary.LittleEndian.PutUint16(packet[10:12], p.Flags) - copy(packet[12:16], p.Src.To4()) - copy(packet[16:20], p.Dst.To4()) - copy(packet[20:52], p.Hash[:]) - copy(packet[52:], p.Data) - - // logrus.Debugln("[packet] marshaled packet:", hex.EncodeToString(packet)) - - return packet + return helper.OpenWriterF(func(w *helper.Writer) { + w.WriteUInt32(p.DataSZ) + w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto)) + w.WriteUInt16(p.SrcPort) + w.WriteUInt16(p.DstPort) + w.WriteUInt16(p.Flags) + w.Write(p.Src.To4()) + w.Write(p.Dst.To4()) + w.Write(p.Hash[:]) + w.Write(p.Data) + }) } // FillHash 生成 p.Data 的 Hash diff --git a/gold/link/link.go b/gold/link/link.go index 4a45d98..f4b7174 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -125,13 +125,14 @@ func (l *Link) String() (n string) { // write 向 peer 发一个包 func (l *Link) write(p *head.Packet, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) { var d []byte + var cl func() if istransfer { if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) { return len(p.Data), errors.New("drop dont fragmnet big trans packet") } - d = p.Marshal(nil, 0, 0, false, false) + d, cl = p.Marshal(nil, 0, 0, false, false) } else { - d = p.Marshal(l.me.me, datasz, offset, false, hasmore) + d, cl = p.Marshal(l.me.me, datasz, offset, false, hasmore) } if d == nil { return 0, errors.New("[link] ttl exceeded") @@ -143,6 +144,7 @@ func (l *Link) write(p *head.Packet, datasz uint32, offset uint16, istransfer, h } logrus.Debugln("[link] write", len(d), "bytes data from ep", l.me.myconn.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset)) n, err = l.me.myconn.WriteToUDP(d, peerep) + cl() } return } diff --git a/gold/link/listen.go b/gold/link/listen.go index d7f7923..23e8126 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -69,7 +69,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) { logrus.Debugln("[link] deliver to pipe of", p.peerip) } else { m.nic.Write(packet.Data) - logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to pipe of me") + logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic") } default: logrus.Warnln("[link] recv unknown proto:", packet.Proto) diff --git a/gold/link/me.go b/gold/link/me.go index ca9a5a2..8784941 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/fumiama/WireGold/gold/head" + "github.com/fumiama/WireGold/helper" "github.com/fumiama/WireGold/lower" "github.com/fumiama/water/waterutil" "github.com/sirupsen/logrus" @@ -38,6 +39,8 @@ type Me struct { nic lower.NICIO // 本机路由表 router *Router + // 本机发送缓冲区 + writer *helper.Writer // 本机未接收完全分片池 recving map[[32]byte]*head.Packet recvmu sync.Mutex @@ -76,6 +79,7 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nic low m.srcport = srcport m.dstport = dstport m.mtu = mtu & 0xfff8 + m.writer = helper.SelectWriter() go m.initrecvpool() return } @@ -97,51 +101,27 @@ func (m *Me) Close() error { return m.nic.Close() } -func (m *Me) ListenFromNIC() { - m.nic.Up() - - // 双缓冲区 - buf := make([]byte, m.MTU()+68) // 增加报头长度与 TEA 冗余 - buf2 := make([]byte, m.MTU()+68) // 增加报头长度与 TEA 冗余 - - off := 0 - isrev := false - for { // 从 NIC 发送 - var packet []byte - if off > 0 && !isrev { - packet = buf2 - } else { - packet = buf - } - n, err := m.nic.Read(packet[off:]) - logrus.Debugln("[me] recv", n, "bytes to send from nic") - if isrev { - off = 0 - } - if err != nil { - logrus.Errorln("[me] send read from nic err:", err) - break - } - if n == 0 { - continue - } - packet = packet[:n] - n, rem := m.sendAllSameDst(packet) - for len(rem) > 20 && n > 0 { - n, rem = m.sendAllSameDst(rem) - } - if len(rem) > 0 { - logrus.Debugln("[me] remain", len(rem), "bytes to send") - if off > 0 { - off = copy(buf, rem) - isrev = true - } else { - off = copy(buf2, rem) - } - } else { - off = 0 - } +func (m *Me) Write(packet []byte) (n int, err error) { + m.writer.Write(packet) + packet = m.writer.Bytes() + logrus.Debugln("[me] writer eating", len(packet), "bytes...") + n, packet = m.sendAllSameDst(packet) + if len(packet) > 0 { + w := helper.SelectWriter() + w.Write(packet) + helper.PutWriter(m.writer) + m.writer = w + logrus.Debugln("[me] writer remain", w.Len(), "bytes") + } else if n > 0 { + m.writer.Reset() + logrus.Debugln("[me] writer becomes empty") } + return +} + +func (m *Me) ListenFromNIC() (written int64, err error) { + m.nic.Up() + return io.Copy(m, m.nic) } type PacketID [2]byte @@ -172,28 +152,25 @@ func (m *Me) sendAllSameDst(packet []byte) (n int, rem []byte) { } } p := newpacketid(rem) - for len(rem) > 20 && p.issame(rem) { - totl := waterutil.IPv4TotalLength(rem) - if int(totl) > len(rem) { - suffix := make([]byte, int(totl)-len(rem)) - _, err := io.ReadFull(m.nic, suffix) - if err != nil { - return len(packet), nil - } - packet = append(packet, suffix...) - n = len(packet) + ptr := rem + i := 0 + for len(ptr) > 20 && p.issame(ptr) { + totl := waterutil.IPv4TotalLength(ptr) + if int(totl) > len(ptr) { break } - n += int(totl) - rem = packet[n:] + i += int(totl) + ptr = rem[i:] logrus.Debugln("[me] wrap", totl, "bytes packet to send together") } - if n == 0 { + if i == 0 { return } - packet = packet[:n] + n += i + packet = rem[:i] + rem = rem[i:] dst := waterutil.IPv4Destination(packet) - logrus.Debugln("[me] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort()))) + logrus.Debugln("[me] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), "remain:", len(rem), "bytes") lnk := m.router.NextHop(dst.String()) if lnk == nil { logrus.Warnln("[me] drop packet: nil nexthop") diff --git a/helper/pool.go b/helper/pool.go new file mode 100644 index 0000000..c2f4558 --- /dev/null +++ b/helper/pool.go @@ -0,0 +1,32 @@ +package helper + +import ( + "bytes" + "sync" +) + +// https://github.com/Mrs4s/MiraiGo/blob/master/binary/pool.go + +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(Writer) + }, +} + +// SelectWriter 从池中取出一个 Writer +func SelectWriter() *Writer { + // 因为 bufferPool 定义有 New 函数 + // 所以 bufferPool.Get() 永不为 nil + // 不用判空 + return bufferPool.Get().(*Writer) +} + +// PutWriter 将 Writer 放回池中 +func PutWriter(w *Writer) { + // See https://golang.org/issue/23199 + const maxSize = 1 << 16 + if (*bytes.Buffer)(w).Cap() < maxSize { // 对于大Buffer直接丢弃 + w.Reset() + bufferPool.Put(w) + } +} diff --git a/helper/writer.go b/helper/writer.go new file mode 100644 index 0000000..f36a9fd --- /dev/null +++ b/helper/writer.go @@ -0,0 +1,123 @@ +package helper + +// https://github.com/Mrs4s/MiraiGo/blob/master/binary/writer.go + +import ( + "bytes" + "encoding/binary" + "encoding/hex" +) + +// Writer 写入 +type Writer bytes.Buffer + +func NewWriterF(f func(writer *Writer)) []byte { + w := SelectWriter() + f(w) + b := append([]byte(nil), w.Bytes()...) + w.put() + return b +} + +// OpenWriterF must call func cl to close +func OpenWriterF(f func(*Writer)) (b []byte, cl func()) { + w := SelectWriter() + f(w) + return w.Bytes(), w.put +} + +func (w *Writer) FillUInt16() (pos int) { + pos = w.Len() + (*bytes.Buffer)(w).Write([]byte{0, 0}) + return +} + +func (w *Writer) WriteUInt16At(pos int, v uint16) { + newdata := (*bytes.Buffer)(w).Bytes()[pos:] + binary.LittleEndian.PutUint16(newdata, v) +} + +func (w *Writer) FillUInt32() (pos int) { + pos = w.Len() + (*bytes.Buffer)(w).Write([]byte{0, 0, 0, 0}) + return +} + +func (w *Writer) WriteUInt32At(pos int, v uint32) { + newdata := (*bytes.Buffer)(w).Bytes()[pos:] + binary.LittleEndian.PutUint32(newdata, v) +} + +func (w *Writer) Write(b []byte) (n int, err error) { + return (*bytes.Buffer)(w).Write(b) +} + +func (w *Writer) WriteHex(h string) { + b, _ := hex.DecodeString(h) + w.Write(b) +} + +func (w *Writer) WriteByte(b byte) error { + return (*bytes.Buffer)(w).WriteByte(b) +} + +func (w *Writer) WriteUInt16(v uint16) { + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, v) + w.Write(b) +} + +func (w *Writer) WriteUInt32(v uint32) { + b := make([]byte, 4) + binary.LittleEndian.PutUint32(b, v) + w.Write(b) +} + +func (w *Writer) WriteUInt64(v uint64) { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, v) + w.Write(b) +} + +func (w *Writer) WriteString(v string) { + w.WriteUInt32(uint32(len(v) + 4)) + (*bytes.Buffer)(w).WriteString(v) +} + +func (w *Writer) WriteStringShort(v string) { + w.WriteUInt16(uint16(len(v))) + (*bytes.Buffer)(w).WriteString(v) +} + +func (w *Writer) WriteBool(b bool) { + if b { + w.WriteByte(0x01) + } else { + w.WriteByte(0x00) + } +} + +func (w *Writer) WriteBytesShort(data []byte) { + w.WriteUInt16(uint16(len(data))) + w.Write(data) +} + +func (w *Writer) Len() int { + return (*bytes.Buffer)(w).Len() +} + +func (w *Writer) Bytes() []byte { + return (*bytes.Buffer)(w).Bytes() +} + +func (w *Writer) Reset() { + (*bytes.Buffer)(w).Reset() +} + +func (w *Writer) Grow(n int) { + (*bytes.Buffer)(w).Grow(n) +} + +func (w *Writer) put() { + PutWriter(w) +} diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go index f62f142..28561c1 100644 --- a/upper/services/wg/wg.go +++ b/upper/services/wg/wg.go @@ -6,6 +6,7 @@ import ( base14 "github.com/fumiama/go-base16384" curve "github.com/fumiama/go-x25519" + "github.com/sirupsen/logrus" "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/gold/link" @@ -53,7 +54,10 @@ func (wg *WG) Start(srcport, destport, mtu uint16) { func (wg *WG) Run(srcport, destport, mtu uint16) { wg.init(srcport, destport, mtu) - wg.me.ListenFromNIC() + _, err := wg.me.ListenFromNIC() + if err != nil { + logrus.Panicln(err) + } } func (wg *WG) Stop() {