diff --git a/config/cfg.go b/config/cfg.go index 319b93c..352fe15 100644 --- a/config/cfg.go +++ b/config/cfg.go @@ -15,7 +15,7 @@ type Config struct { PrivateKey string `yaml:"PrivateKey"` Network string `yaml:"Network"` // Network udp, tcp or ws (WIP) EndPoint string `yaml:"EndPoint"` - MTU int64 `yaml:"MTU"` + MTU int64 `yaml:"MTU"` // MTU of nic (will minus packet header len) SpeedLoop uint16 `yaml:"SpeedLoop"` Mask uint64 `yaml:"Mask"` // Mask 是异或报文所用掩码, 必须保证各端统一 Peers []Peer `yaml:"Peers"` @@ -34,7 +34,7 @@ type Peer struct { AllowTrans bool `yaml:"AllowTrans"` UseZstd bool `yaml:"UseZstd"` DoublePacket bool `yaml:"DoublePacket"` - MTU int64 `yaml:"MTU"` + MTU int64 `yaml:"MTU"` // MTU of PDU passed to p2p MTURandomRange int64 `yaml:"MTURandomRange"` } diff --git a/gold/link/me.go b/gold/link/me.go index 30ccb81..8ed8b68 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -39,7 +39,7 @@ type Me struct { // 本机监听的连接端点, 也用于向对端直接发送报文 conn p2p.Conn // 本机网卡 - nic lower.NICIO + nic *lower.NICIO // 本机路由表 router *Router // 本机未接收完全分片池 @@ -60,11 +60,17 @@ type MyConfig struct { Network string NetworkConfigs []any PrivateKey *[32]byte - NIC lower.NICIO + NICConfig *NICConfig SrcPort, DstPort, MTU, SpeedLoop uint16 Mask uint64 } +type NICConfig struct { + IP net.IP + SubNet *net.IPNet + CIDRs []string +} + // NewMe 设置本机参数 func NewMe(cfg *MyConfig) (m Me) { m.privKey = *cfg.PrivateKey @@ -89,7 +95,6 @@ func NewMe(cfg *MyConfig) (m Me) { panic(err) } m.connections = make(map[string]*Link) - m.nic = cfg.NIC m.router = &Router{ list: make([]*net.IPNet, 1, 16), table: make(map[string]*Link, 16), @@ -98,7 +103,13 @@ func NewMe(cfg *MyConfig) (m Me) { m.router.SetDefault(nil) m.srcport = cfg.SrcPort m.dstport = cfg.DstPort - m.mtu = cfg.MTU & 0xfff8 + m.mtu = (cfg.MTU - head.PacketHeadLen) & 0xfff8 + if cfg.NICConfig != nil { + m.nic = lower.NewNIC( + cfg.NICConfig.IP, cfg.NICConfig.SubNet, + strconv.FormatUint(uint64(m.MTU()), 10), cfg.NICConfig.CIDRs..., + ) + } m.speedloop = cfg.SpeedLoop if m.speedloop == 0 { m.speedloop = 4096 diff --git a/gold/link/peer.go b/gold/link/peer.go index 169f93a..102f3a4 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -33,7 +33,7 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) { if ok { return } - if m.mtu == 0 { + if cfg.MTU == 0 { panic("invalid mtu for peer " + cfg.PeerIP) } l = &Link{ diff --git a/lower/nic.go b/lower/nic.go index 2d0f1c2..0bdeb1b 100644 --- a/lower/nic.go +++ b/lower/nic.go @@ -1,7 +1,6 @@ package lower import ( - "io" "net" "os" "os/exec" @@ -11,14 +10,8 @@ import ( "github.com/sirupsen/logrus" ) -type NICIO interface { - io.ReadWriteCloser - Up() - Down() -} - -// NIC 虚拟网卡 -type NIC struct { +// NICIO 虚拟网卡 +type NICIO struct { ifce *water.Interface ip net.IP subnet *net.IPNet @@ -31,7 +24,7 @@ type NIC struct { // 网卡地址为 ip, 所属子网为 subnet // 以本网卡为下一跳的所有子网为 cidrs // cidrs 不包括本网卡 subnet -func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) NICIO { +func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) *NICIO { ifce, err := water.New(water.Config{DeviceType: water.TUN}) if err != nil { logrus.Error(err) @@ -41,7 +34,7 @@ func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) NICIO { if bitsn != 32 { panic("mask len " + strconv.Itoa(bitsn) + " is not supported") } - n := &NIC{ + n := &NICIO{ ifce: ifce, ip: ip, subnet: subnet, @@ -53,16 +46,16 @@ func NewNIC(ip net.IP, subnet *net.IPNet, mtu string, cidrs ...string) NICIO { } // Read 匹配 PacketsIO Interface -func (nc *NIC) Read(buf []byte) (int, error) { +func (nc *NICIO) Read(buf []byte) (int, error) { return nc.ifce.Read(buf) } -func (nc *NIC) Write(packet []byte) (int, error) { +func (nc *NICIO) Write(packet []byte) (int, error) { return nc.ifce.Write(packet) } // Close 关闭网卡 -func (n *NIC) Close() error { +func (n *NICIO) Close() error { return n.ifce.Close() } diff --git a/lower/tun_darwin.go b/lower/tun_darwin.go index 9a72df4..c0e50da 100644 --- a/lower/tun_darwin.go +++ b/lower/tun_darwin.go @@ -5,7 +5,7 @@ package lower import "net" -func (n *NIC) Up() { +func (n *NICIO) Up() { execute("ifconfig", n.ifce.Name(), "mtu", n.mtu) // max: 9159 execute( "ifconfig", n.ifce.Name(), @@ -19,7 +19,7 @@ func (n *NIC) Up() { } } -func (n *NIC) Down() { +func (n *NICIO) Down() { execute("route", "delete", n.subnet.String(), "-interface", n.ifce.Name()) for _, c := range n.cidrs { execute("route", "delete", c, "-interface", n.ifce.Name()) diff --git a/lower/tun_linux.go b/lower/tun_linux.go index 145ecd1..8475b3a 100644 --- a/lower/tun_linux.go +++ b/lower/tun_linux.go @@ -3,7 +3,7 @@ package lower -func (n *NIC) Up() { +func (n *NICIO) Up() { execute("/sbin/ip", "link", "set", "dev", n.ifce.Name(), "mtu", n.mtu) execute("/sbin/ip", "addr", "add", n.rawipnet, "dev", n.ifce.Name()) execute("/sbin/ip", "link", "set", "dev", n.ifce.Name(), "up") @@ -12,7 +12,7 @@ func (n *NIC) Up() { } } -func (n *NIC) Down() { +func (n *NICIO) Down() { for _, c := range n.cidrs { execute("/sbin/ip", "route", "del", c, "dev", n.ifce.Name()) } diff --git a/lower/tun_stub.go b/lower/tun_stub.go index 362bfb2..f102671 100644 --- a/lower/tun_stub.go +++ b/lower/tun_stub.go @@ -3,10 +3,10 @@ package lower -func (n *NIC) Up() { +func (n *NICIO) Up() { panic("not support lower on this os now") } -func (n *NIC) Down() { +func (n *NICIO) Down() { panic("not support lower on this os now") } diff --git a/lower/tun_windows.go b/lower/tun_windows.go index f9a2851..919d4c1 100644 --- a/lower/tun_windows.go +++ b/lower/tun_windows.go @@ -5,7 +5,7 @@ package lower import "net" -func (n *NIC) Up() { +func (n *NICIO) Up() { execute("cmd", "/c", "netsh interface ip set address name=\""+n.ifce.Name()+"\" source=static addr=\""+n.ip.String()+"\" mask=\""+(net.IP)(n.subnet.Mask).String()+"\" gateway=none") execute("cmd", "/c", "netsh interface ipv4 set subinterface \""+n.ifce.Name()+"\" mtu="+n.mtu) for _, c := range n.cidrs { @@ -17,7 +17,7 @@ func (n *NIC) Up() { } } -func (n *NIC) Down() { +func (n *NICIO) Down() { // execute("netsh", "interface", "set", "interface", n.ifce.Name(), "disabled") for _, c := range n.cidrs { ip, _, err := net.ParseCIDR(c) diff --git a/main.go b/main.go index e8cce83..c511b73 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "github.com/sirupsen/logrus" "github.com/fumiama/WireGold/config" + "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/helper" "github.com/fumiama/WireGold/upper" "github.com/fumiama/WireGold/upper/services/wg" @@ -144,8 +145,8 @@ func main() { if c.EndPoint == "" { displayHelp("nil endpoint") } - if c.MTU == 0 { - displayHelp("nil mtu") + if c.MTU <= head.PacketHeadLen { + displayHelp("invalid mtu") } w, err := wg.NewWireGold(&c) if err != nil { diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go index 1cf81df..5a8b33a 100644 --- a/upper/services/wg/wg.go +++ b/upper/services/wg/wg.go @@ -17,7 +17,6 @@ import ( "github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/gold/link" "github.com/fumiama/WireGold/helper" - "github.com/fumiama/WireGold/lower" ) const suffix32 = "㴄" @@ -104,12 +103,16 @@ func (wg *WG) init(srcport, dstport uint16) { MyEndpoint: wg.c.EndPoint, Network: wg.c.Network, PrivateKey: &wg.key, - NIC: lower.NewNIC(myip, mysubnet, strconv.FormatInt(wg.c.MTU, 10), cidrs...), - SrcPort: srcport, - DstPort: dstport, - MTU: uint16(wg.c.MTU), - SpeedLoop: wg.c.SpeedLoop, - Mask: wg.c.Mask, + NICConfig: &link.NICConfig{ + IP: myip, + SubNet: mysubnet, + CIDRs: cidrs, + }, + SrcPort: srcport, + DstPort: dstport, + MTU: uint16(wg.c.MTU), + SpeedLoop: wg.c.SpeedLoop, + Mask: wg.c.Mask, }) for _, peer := range wg.c.Peers {