1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-25 05:20:15 +08:00

all: feats & optimizes

- feat(me): add new cfg param SpeedLoop
- chore(send): remove unnecessary err chk in write
- perf(listen): use mutex instead of bool checking
This commit is contained in:
源文雨
2024-07-11 17:44:58 +09:00
parent d695f14498
commit c0f31a70c8
6 changed files with 97 additions and 46 deletions

43
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: golang-ci
on: [push, pull_request]
jobs:
golang-ci:
name: CI
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.x
uses: actions/setup-go@master
with:
go-version: ^1.20
- name: Check out code into the Go module directory
uses: actions/checkout@master
- name: Get dependencies
run: go mod tidy
- name: Build
run: go build -v ./...
- name: Test
run: go test $(go list ./...)
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.x
uses: actions/setup-go@master
with:
go-version: ^1.20
- name: Check out code into the Go module directory
uses: actions/checkout@master
- name: golangci-lint
uses: golangci/golangci-lint-action@master
with:
version: latest

View File

@@ -24,6 +24,7 @@ type Config struct {
PrivateKey string `yaml:"PrivateKey"` PrivateKey string `yaml:"PrivateKey"`
EndPoint string `yaml:"EndPoint"` EndPoint string `yaml:"EndPoint"`
MTU int64 `yaml:"MTU"` MTU int64 `yaml:"MTU"`
SpeedLoop uint16 `yaml:"SpeedLoop"`
Mask uint64 `yaml:"Mask"` // Mask 是异或报文所用掩码, 必须保证各端统一 Mask uint64 `yaml:"Mask"` // Mask 是异或报文所用掩码, 必须保证各端统一
Peers []Peer `yaml:"Peers"` Peers []Peer `yaml:"Peers"`
} }

View File

@@ -7,6 +7,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
@@ -17,17 +18,17 @@ import (
"github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/gold/head"
) )
// 监听本机 endpoint // 监听本机 UDP endpoint
func (m *Me) listen() (conn *net.UDPConn, err error) { func (m *Me) listenudp() (conn *net.UDPConn, err error) {
conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.myend.String()))) conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.udpep.String())))
if err != nil { if err != nil {
return return
} }
m.myend = conn.LocalAddr() m.udpep = conn.LocalAddr()
logrus.Infoln("[listen] at", m.myend) logrus.Infoln("[listen] at", m.udpep)
go func() { go func() {
recvtotlcnt := 0 recvtotlcnt := uint64(0)
recvloopcnt := 0 recvloopcnt := uint16(0)
recvlooptime := time.Now().UnixMilli() recvlooptime := time.Now().UnixMilli()
n := runtime.NumCPU() n := runtime.NumCPU()
if n > 64 { if n > 64 {
@@ -35,42 +36,45 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
} }
logrus.Infoln("[listen] use cpu num:", n) logrus.Infoln("[listen] use cpu num:", n)
listenbuff := make([]byte, 65536*n) listenbuff := make([]byte, 65536*n)
hasntfinished := make([]bool, n) hasntfinished := make([]sync.Mutex, n)
for i := 0; err == nil; i++ { for i := 0; err == nil; i++ {
i %= n i %= n
for hasntfinished[i] { for !hasntfinished[i].TryLock() {
time.Sleep(time.Millisecond)
i++ i++
i %= n i %= n
if i == 0 { // looked up a full round
time.Sleep(time.Millisecond * 10)
}
} }
logrus.Debugln("[listen] lock index", i)
lbf := listenbuff[i*65536 : (i+1)*65536] lbf := listenbuff[i*65536 : (i+1)*65536]
n, addr, err := conn.ReadFromUDP(lbf) n, addr, err := conn.ReadFromUDP(lbf)
if err != nil { if err != nil {
logrus.Warnln("[listen] read from udp err, reconnect:", err) logrus.Warnln("[listen] read from udp err, reconnect:", err)
conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.myend.String()))) conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.udpep.String())))
if err != nil { if err != nil {
logrus.Errorln("[listen] reconnect udp err:", err) logrus.Errorln("[listen] reconnect udp err:", err)
return return
} }
hasntfinished[i].Unlock()
i-- i--
continue continue
} }
recvtotlcnt += n recvtotlcnt += uint64(n)
recvloopcnt++ recvloopcnt++
if recvloopcnt >= 4096 { if recvloopcnt%m.speedloop == 0 {
now := time.Now().UnixMilli() now := time.Now().UnixMilli()
logrus.Infof("[listen] recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime)) logrus.Infof("[listen] recv avg speed: %.2f KB/s", float64(recvtotlcnt)/float64(now-recvlooptime))
recvtotlcnt = 0 recvtotlcnt = 0
recvloopcnt = 0
recvlooptime = now recvlooptime = now
} }
packet := m.wait(lbf[:n]) packet := m.wait(lbf[:n])
if packet == nil { if packet == nil {
hasntfinished[i].Unlock()
i-- i--
continue continue
} }
hasntfinished[i] = true go m.listenthread(packet, addr, i, hasntfinished[i].Unlock)
go m.listenthread(packet, addr, i, func() { hasntfinished[i] = false })
} }
}() }()
return return

View File

@@ -27,16 +27,16 @@ type Me struct {
me net.IP me net.IP
// 本机子网 // 本机子网
subnet net.IPNet subnet net.IPNet
// 本机 endpoint // 本机 UDP endpoint
myend net.Addr udpep net.Addr
// 本机环回 link // 本机环回 link
loop *Link loop *Link
// 本机活跃的所有连接 // 本机活跃的所有连接
connections map[string]*Link connections map[string]*Link
// 读写同步锁 // 读写同步锁
connmapmu sync.RWMutex connmapmu sync.RWMutex
// 本机监听的 endpoint // 本机监听的 udp 连接, 用于向对端直接发送报文
myep *net.UDPConn udpconn *net.UDPConn
// 本机网卡 // 本机网卡
nic lower.NICIO nic lower.NICIO
// 本机路由表 // 本机路由表
@@ -46,25 +46,25 @@ type Me struct {
// 抗重放攻击记录池 // 抗重放攻击记录池
recved *ttl.Cache[uint64, bool] recved *ttl.Cache[uint64, bool]
// 本机上层配置 // 本机上层配置
srcport, dstport, mtu uint16 srcport, dstport, mtu, speedloop uint16
// 报头掩码 // 报头掩码
mask uint64 mask uint64
} }
type MyConfig struct { type MyConfig struct {
MyIPwithMask string MyIPwithMask string
MyEndpoint string MyEndpoint string
PrivateKey *[32]byte PrivateKey *[32]byte
NIC lower.NICIO NIC lower.NICIO
SrcPort, DstPort, MTU uint16 SrcPort, DstPort, MTU, SpeedLoop uint16
Mask uint64 Mask uint64
} }
// NewMe 设置本机参数 // NewMe 设置本机参数
func NewMe(cfg *MyConfig) (m Me) { func NewMe(cfg *MyConfig) (m Me) {
m.privKey = *cfg.PrivateKey m.privKey = *cfg.PrivateKey
var err error var err error
m.myend, err = net.ResolveUDPAddr("udp", cfg.MyEndpoint) m.udpep, err = net.ResolveUDPAddr("udp", cfg.MyEndpoint)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -74,7 +74,7 @@ func NewMe(cfg *MyConfig) (m Me) {
} }
m.me = ip m.me = ip
m.subnet = *cidr m.subnet = *cidr
m.myep, err = m.listen() m.udpconn, err = m.listenudp()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -96,6 +96,10 @@ func NewMe(cfg *MyConfig) (m Me) {
m.srcport = cfg.SrcPort m.srcport = cfg.SrcPort
m.dstport = cfg.DstPort m.dstport = cfg.DstPort
m.mtu = cfg.MTU & 0xfff8 m.mtu = cfg.MTU & 0xfff8
m.speedloop = cfg.SpeedLoop
if m.speedloop == 0 {
m.speedloop = 4096
}
m.mask = cfg.Mask m.mask = cfg.Mask
var buf [8]byte var buf [8]byte
binary.BigEndian.PutUint64(buf[:], m.mask) binary.BigEndian.PutUint64(buf[:], m.mask)

View File

@@ -98,23 +98,21 @@ func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz ui
if d == nil { if d == nil {
return 0, errors.New("[send] ttl exceeded") return 0, errors.New("[send] ttl exceeded")
} }
if err == nil { peerep := l.endpoint
peerep := l.endpoint if peerep == nil {
if peerep == nil { return 0, errors.New("[send] nil endpoint of " + p.Dst.String())
return 0, errors.New("[send] nil endpoint of " + p.Dst.String())
}
bound := 64
endl := "..."
if len(d) < bound {
bound = len(d)
endl = "."
}
logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.myep.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset))
logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl)
d = l.me.xorenc(d)
logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl)
n, err = l.me.myep.WriteToUDP(d, peerep)
cl()
} }
bound := 64
endl := "..."
if len(d) < bound {
bound = len(d)
endl = "."
}
logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.udpconn.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset))
logrus.Debugln("[send] data bytes", hex.EncodeToString(d[:bound]), endl)
d = l.me.xorenc(d)
logrus.Debugln("[send] data xored", hex.EncodeToString(d[:bound]), endl)
n, err = l.me.udpconn.WriteToUDP(d, peerep)
cl()
return return
} }

View File

@@ -97,6 +97,7 @@ func (wg *WG) init(srcport, dstport uint16) {
SrcPort: srcport, SrcPort: srcport,
DstPort: dstport, DstPort: dstport,
MTU: uint16(wg.c.MTU), MTU: uint16(wg.c.MTU),
SpeedLoop: wg.c.SpeedLoop,
Mask: wg.c.Mask, Mask: wg.c.Mask,
}) })