mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-06 08:20:25 +08:00
add pool & writer
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
blake2b "github.com/minio/blake2b-simd"
|
||||
)
|
||||
|
||||
@@ -57,15 +58,21 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
err = errors.New("data len < 12")
|
||||
return
|
||||
}
|
||||
|
||||
if p.DataSZ == 0 && len(p.Data) == 0 {
|
||||
p.DataSZ = binary.LittleEndian.Uint32(data[:4])
|
||||
p.Data = make([]byte, p.DataSZ)
|
||||
if int(p.DataSZ)+52 == len(data) {
|
||||
p.Data = data[52:]
|
||||
p.rembytes = 0
|
||||
} else {
|
||||
p.Data = make([]byte, p.DataSZ)
|
||||
p.rembytes = p.DataSZ
|
||||
}
|
||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||
p.Proto = uint8(pt)
|
||||
p.TTL = uint8(pt >> 8)
|
||||
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||
p.rembytes = p.DataSZ
|
||||
}
|
||||
|
||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
||||
@@ -78,7 +85,10 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
copy(p.Dst, data[16:20])
|
||||
copy(p.Hash[:], data[20:52])
|
||||
}
|
||||
p.rembytes -= uint32(copy(p.Data[flags<<3:], data[52:]))
|
||||
|
||||
if p.rembytes > 0 {
|
||||
p.rembytes -= uint32(copy(p.Data[flags<<3:], data[52:]))
|
||||
}
|
||||
|
||||
complete = p.rembytes == 0
|
||||
|
||||
@@ -87,10 +97,10 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
|
||||
// Marshal 将自身数据编码为 []byte
|
||||
// offset 必须为 8 的倍数,表示偏移的 8 位
|
||||
func (p *Packet) Marshal(src net.IP, datasz uint32, offset uint16, dontfrag, hasmore bool) []byte {
|
||||
func (p *Packet) Marshal(src net.IP, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) {
|
||||
p.TTL--
|
||||
if p.TTL == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if src != nil {
|
||||
@@ -105,20 +115,17 @@ func (p *Packet) Marshal(src net.IP, datasz uint32, offset uint16, dontfrag, has
|
||||
p.Flags = offset
|
||||
}
|
||||
|
||||
packet := make([]byte, 52+len(p.Data))
|
||||
binary.LittleEndian.PutUint32(packet[:4], p.DataSZ)
|
||||
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
|
||||
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
|
||||
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
|
||||
binary.LittleEndian.PutUint16(packet[10:12], p.Flags)
|
||||
copy(packet[12:16], p.Src.To4())
|
||||
copy(packet[16:20], p.Dst.To4())
|
||||
copy(packet[20:52], p.Hash[:])
|
||||
copy(packet[52:], p.Data)
|
||||
|
||||
// logrus.Debugln("[packet] marshaled packet:", hex.EncodeToString(packet))
|
||||
|
||||
return packet
|
||||
return helper.OpenWriterF(func(w *helper.Writer) {
|
||||
w.WriteUInt32(p.DataSZ)
|
||||
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
|
||||
w.WriteUInt16(p.SrcPort)
|
||||
w.WriteUInt16(p.DstPort)
|
||||
w.WriteUInt16(p.Flags)
|
||||
w.Write(p.Src.To4())
|
||||
w.Write(p.Dst.To4())
|
||||
w.Write(p.Hash[:])
|
||||
w.Write(p.Data)
|
||||
})
|
||||
}
|
||||
|
||||
// FillHash 生成 p.Data 的 Hash
|
||||
|
||||
@@ -125,13 +125,14 @@ func (l *Link) String() (n string) {
|
||||
// write 向 peer 发一个包
|
||||
func (l *Link) write(p *head.Packet, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) {
|
||||
var d []byte
|
||||
var cl func()
|
||||
if istransfer {
|
||||
if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) {
|
||||
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
|
||||
}
|
||||
d = p.Marshal(nil, 0, 0, false, false)
|
||||
d, cl = p.Marshal(nil, 0, 0, false, false)
|
||||
} else {
|
||||
d = p.Marshal(l.me.me, datasz, offset, false, hasmore)
|
||||
d, cl = p.Marshal(l.me.me, datasz, offset, false, hasmore)
|
||||
}
|
||||
if d == nil {
|
||||
return 0, errors.New("[link] ttl exceeded")
|
||||
@@ -143,6 +144,7 @@ func (l *Link) write(p *head.Packet, datasz uint32, offset uint16, istransfer, h
|
||||
}
|
||||
logrus.Debugln("[link] write", len(d), "bytes data from ep", l.me.myconn.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset))
|
||||
n, err = l.me.myconn.WriteToUDP(d, peerep)
|
||||
cl()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
logrus.Debugln("[link] deliver to pipe of", p.peerip)
|
||||
} else {
|
||||
m.nic.Write(packet.Data)
|
||||
logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to pipe of me")
|
||||
logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic")
|
||||
}
|
||||
default:
|
||||
logrus.Warnln("[link] recv unknown proto:", packet.Proto)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/head"
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/fumiama/WireGold/lower"
|
||||
"github.com/fumiama/water/waterutil"
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -38,6 +39,8 @@ type Me struct {
|
||||
nic lower.NICIO
|
||||
// 本机路由表
|
||||
router *Router
|
||||
// 本机发送缓冲区
|
||||
writer *helper.Writer
|
||||
// 本机未接收完全分片池
|
||||
recving map[[32]byte]*head.Packet
|
||||
recvmu sync.Mutex
|
||||
@@ -76,6 +79,7 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nic low
|
||||
m.srcport = srcport
|
||||
m.dstport = dstport
|
||||
m.mtu = mtu & 0xfff8
|
||||
m.writer = helper.SelectWriter()
|
||||
go m.initrecvpool()
|
||||
return
|
||||
}
|
||||
@@ -97,51 +101,27 @@ func (m *Me) Close() error {
|
||||
return m.nic.Close()
|
||||
}
|
||||
|
||||
func (m *Me) ListenFromNIC() {
|
||||
m.nic.Up()
|
||||
|
||||
// 双缓冲区
|
||||
buf := make([]byte, m.MTU()+68) // 增加报头长度与 TEA 冗余
|
||||
buf2 := make([]byte, m.MTU()+68) // 增加报头长度与 TEA 冗余
|
||||
|
||||
off := 0
|
||||
isrev := false
|
||||
for { // 从 NIC 发送
|
||||
var packet []byte
|
||||
if off > 0 && !isrev {
|
||||
packet = buf2
|
||||
} else {
|
||||
packet = buf
|
||||
}
|
||||
n, err := m.nic.Read(packet[off:])
|
||||
logrus.Debugln("[me] recv", n, "bytes to send from nic")
|
||||
if isrev {
|
||||
off = 0
|
||||
}
|
||||
if err != nil {
|
||||
logrus.Errorln("[me] send read from nic err:", err)
|
||||
break
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
packet = packet[:n]
|
||||
n, rem := m.sendAllSameDst(packet)
|
||||
for len(rem) > 20 && n > 0 {
|
||||
n, rem = m.sendAllSameDst(rem)
|
||||
}
|
||||
if len(rem) > 0 {
|
||||
logrus.Debugln("[me] remain", len(rem), "bytes to send")
|
||||
if off > 0 {
|
||||
off = copy(buf, rem)
|
||||
isrev = true
|
||||
} else {
|
||||
off = copy(buf2, rem)
|
||||
}
|
||||
} else {
|
||||
off = 0
|
||||
}
|
||||
func (m *Me) Write(packet []byte) (n int, err error) {
|
||||
m.writer.Write(packet)
|
||||
packet = m.writer.Bytes()
|
||||
logrus.Debugln("[me] writer eating", len(packet), "bytes...")
|
||||
n, packet = m.sendAllSameDst(packet)
|
||||
if len(packet) > 0 {
|
||||
w := helper.SelectWriter()
|
||||
w.Write(packet)
|
||||
helper.PutWriter(m.writer)
|
||||
m.writer = w
|
||||
logrus.Debugln("[me] writer remain", w.Len(), "bytes")
|
||||
} else if n > 0 {
|
||||
m.writer.Reset()
|
||||
logrus.Debugln("[me] writer becomes empty")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Me) ListenFromNIC() (written int64, err error) {
|
||||
m.nic.Up()
|
||||
return io.Copy(m, m.nic)
|
||||
}
|
||||
|
||||
type PacketID [2]byte
|
||||
@@ -172,28 +152,25 @@ func (m *Me) sendAllSameDst(packet []byte) (n int, rem []byte) {
|
||||
}
|
||||
}
|
||||
p := newpacketid(rem)
|
||||
for len(rem) > 20 && p.issame(rem) {
|
||||
totl := waterutil.IPv4TotalLength(rem)
|
||||
if int(totl) > len(rem) {
|
||||
suffix := make([]byte, int(totl)-len(rem))
|
||||
_, err := io.ReadFull(m.nic, suffix)
|
||||
if err != nil {
|
||||
return len(packet), nil
|
||||
}
|
||||
packet = append(packet, suffix...)
|
||||
n = len(packet)
|
||||
ptr := rem
|
||||
i := 0
|
||||
for len(ptr) > 20 && p.issame(ptr) {
|
||||
totl := waterutil.IPv4TotalLength(ptr)
|
||||
if int(totl) > len(ptr) {
|
||||
break
|
||||
}
|
||||
n += int(totl)
|
||||
rem = packet[n:]
|
||||
i += int(totl)
|
||||
ptr = rem[i:]
|
||||
logrus.Debugln("[me] wrap", totl, "bytes packet to send together")
|
||||
}
|
||||
if n == 0 {
|
||||
if i == 0 {
|
||||
return
|
||||
}
|
||||
packet = packet[:n]
|
||||
n += i
|
||||
packet = rem[:i]
|
||||
rem = rem[i:]
|
||||
dst := waterutil.IPv4Destination(packet)
|
||||
logrus.Debugln("[me] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort())))
|
||||
logrus.Debugln("[me] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort())), "remain:", len(rem), "bytes")
|
||||
lnk := m.router.NextHop(dst.String())
|
||||
if lnk == nil {
|
||||
logrus.Warnln("[me] drop packet: nil nexthop")
|
||||
|
||||
32
helper/pool.go
Normal file
32
helper/pool.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// https://github.com/Mrs4s/MiraiGo/blob/master/binary/pool.go
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Writer)
|
||||
},
|
||||
}
|
||||
|
||||
// SelectWriter 从池中取出一个 Writer
|
||||
func SelectWriter() *Writer {
|
||||
// 因为 bufferPool 定义有 New 函数
|
||||
// 所以 bufferPool.Get() 永不为 nil
|
||||
// 不用判空
|
||||
return bufferPool.Get().(*Writer)
|
||||
}
|
||||
|
||||
// PutWriter 将 Writer 放回池中
|
||||
func PutWriter(w *Writer) {
|
||||
// See https://golang.org/issue/23199
|
||||
const maxSize = 1 << 16
|
||||
if (*bytes.Buffer)(w).Cap() < maxSize { // 对于大Buffer直接丢弃
|
||||
w.Reset()
|
||||
bufferPool.Put(w)
|
||||
}
|
||||
}
|
||||
123
helper/writer.go
Normal file
123
helper/writer.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package helper
|
||||
|
||||
// https://github.com/Mrs4s/MiraiGo/blob/master/binary/writer.go
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// Writer 写入
|
||||
type Writer bytes.Buffer
|
||||
|
||||
func NewWriterF(f func(writer *Writer)) []byte {
|
||||
w := SelectWriter()
|
||||
f(w)
|
||||
b := append([]byte(nil), w.Bytes()...)
|
||||
w.put()
|
||||
return b
|
||||
}
|
||||
|
||||
// OpenWriterF must call func cl to close
|
||||
func OpenWriterF(f func(*Writer)) (b []byte, cl func()) {
|
||||
w := SelectWriter()
|
||||
f(w)
|
||||
return w.Bytes(), w.put
|
||||
}
|
||||
|
||||
func (w *Writer) FillUInt16() (pos int) {
|
||||
pos = w.Len()
|
||||
(*bytes.Buffer)(w).Write([]byte{0, 0})
|
||||
return
|
||||
}
|
||||
|
||||
func (w *Writer) WriteUInt16At(pos int, v uint16) {
|
||||
newdata := (*bytes.Buffer)(w).Bytes()[pos:]
|
||||
binary.LittleEndian.PutUint16(newdata, v)
|
||||
}
|
||||
|
||||
func (w *Writer) FillUInt32() (pos int) {
|
||||
pos = w.Len()
|
||||
(*bytes.Buffer)(w).Write([]byte{0, 0, 0, 0})
|
||||
return
|
||||
}
|
||||
|
||||
func (w *Writer) WriteUInt32At(pos int, v uint32) {
|
||||
newdata := (*bytes.Buffer)(w).Bytes()[pos:]
|
||||
binary.LittleEndian.PutUint32(newdata, v)
|
||||
}
|
||||
|
||||
func (w *Writer) Write(b []byte) (n int, err error) {
|
||||
return (*bytes.Buffer)(w).Write(b)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteHex(h string) {
|
||||
b, _ := hex.DecodeString(h)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteByte(b byte) error {
|
||||
return (*bytes.Buffer)(w).WriteByte(b)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteUInt16(v uint16) {
|
||||
b := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(b, v)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteUInt32(v uint32) {
|
||||
b := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(b, v)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteUInt64(v uint64) {
|
||||
b := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(b, v)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteString(v string) {
|
||||
w.WriteUInt32(uint32(len(v) + 4))
|
||||
(*bytes.Buffer)(w).WriteString(v)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteStringShort(v string) {
|
||||
w.WriteUInt16(uint16(len(v)))
|
||||
(*bytes.Buffer)(w).WriteString(v)
|
||||
}
|
||||
|
||||
func (w *Writer) WriteBool(b bool) {
|
||||
if b {
|
||||
w.WriteByte(0x01)
|
||||
} else {
|
||||
w.WriteByte(0x00)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Writer) WriteBytesShort(data []byte) {
|
||||
w.WriteUInt16(uint16(len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (w *Writer) Len() int {
|
||||
return (*bytes.Buffer)(w).Len()
|
||||
}
|
||||
|
||||
func (w *Writer) Bytes() []byte {
|
||||
return (*bytes.Buffer)(w).Bytes()
|
||||
}
|
||||
|
||||
func (w *Writer) Reset() {
|
||||
(*bytes.Buffer)(w).Reset()
|
||||
}
|
||||
|
||||
func (w *Writer) Grow(n int) {
|
||||
(*bytes.Buffer)(w).Grow(n)
|
||||
}
|
||||
|
||||
func (w *Writer) put() {
|
||||
PutWriter(w)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
base14 "github.com/fumiama/go-base16384"
|
||||
curve "github.com/fumiama/go-x25519"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/fumiama/WireGold/config"
|
||||
"github.com/fumiama/WireGold/gold/link"
|
||||
@@ -53,7 +54,10 @@ func (wg *WG) Start(srcport, destport, mtu uint16) {
|
||||
|
||||
func (wg *WG) Run(srcport, destport, mtu uint16) {
|
||||
wg.init(srcport, destport, mtu)
|
||||
wg.me.ListenFromNIC()
|
||||
_, err := wg.me.ListenFromNIC()
|
||||
if err != nil {
|
||||
logrus.Panicln(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (wg *WG) Stop() {
|
||||
|
||||
Reference in New Issue
Block a user