1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-23 12:00:34 +08:00

add 分片

This commit is contained in:
fumiama
2021-12-31 12:34:54 +08:00
parent 7a30bfb1e6
commit d108bb81b4
11 changed files with 336 additions and 147 deletions

View File

@@ -12,9 +12,11 @@ import (
// Packet 是发送和接收的最小单位 // Packet 是发送和接收的最小单位
type Packet struct { type Packet struct {
// Ver 协议版本
Ver uint16
// DataSZ len(Data) // DataSZ len(Data)
// 不得超过 65507-head 字节 // 不得超过 65507-head 字节
DataSZ uint32 DataSZ uint16
// Proto 详见 head // Proto 详见 head
Proto uint8 Proto uint8
// TTL is time to live // TTL is time to live
@@ -23,9 +25,11 @@ type Packet struct {
SrcPort uint16 SrcPort uint16
// DstPort 目的端口 // DstPort 目的端口
DstPort uint16 DstPort uint16
// Src 源 ip // Flags 高3位为标志(xDM)低13位为分片偏移
Flags uint16
// Src 源 ip (ipv4)
Src net.IP Src net.IP
// Dst 目的 ip // Dst 目的 ip (ipv4)
Dst net.IP Dst net.IP
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要 // Hash 使用 BLAKE2 生成加密前 Packet 的摘要
// 生成时 Hash 全 0 // 生成时 Hash 全 0
@@ -33,12 +37,15 @@ type Packet struct {
Hash [32]byte Hash [32]byte
// Data 承载的数据 // Data 承载的数据
Data []byte Data []byte
// 记录还有多少字节未到达
rembytes uint16
} }
// NewPacket 生成一个新包 // NewPacket 生成一个新包
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet { func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet {
logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data) logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
return &Packet{ return &Packet{
Ver: 1,
Proto: proto, Proto: proto,
TTL: 16, TTL: 16,
SrcPort: srcPort, SrcPort: srcPort,
@@ -49,53 +56,69 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
} }
// Unmarshal 将 data 的数据解码到自身 // Unmarshal 将 data 的数据解码到自身
func (p *Packet) Unmarshal(data []byte) error { func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
if len(data) < 12 { if len(data) < 12 {
return errors.New("data len < 12") err = errors.New("data len < 12")
return
} }
p.DataSZ = binary.LittleEndian.Uint32(data[:4]) if p.DataSZ == 0 && len(p.Data) == 0 {
pt := binary.LittleEndian.Uint16(data[4:6]) p.Ver = binary.LittleEndian.Uint16(data[:2])
p.Proto = uint8(pt) if p.Ver != 1 {
p.TTL = uint8(pt >> 8) err = errors.New("unknown protocol version")
p.SrcPort = binary.LittleEndian.Uint16(data[6:8]) return
p.DstPort = binary.LittleEndian.Uint16(data[8:10]) }
sdl := binary.LittleEndian.Uint16(data[10:12]) p.DataSZ = binary.LittleEndian.Uint16(data[2:4])
srclen := uint8(sdl) p.Data = make([]byte, p.DataSZ)
dstlen := uint8(sdl >> 8) pt := binary.LittleEndian.Uint16(data[4:6])
if len(data) < int(12+srclen+dstlen) { p.Proto = uint8(pt)
return errors.New("data src or dst len mismatch") p.TTL = uint8(pt >> 8)
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
p.rembytes = p.DataSZ
} }
if srclen > 0 {
p.Src = make(net.IP, srclen) p.Flags = binary.LittleEndian.Uint16(data[10:12])
copy(p.Src, data[12:12+srclen])
} p.Src = make(net.IP, 4)
if dstlen > 0 { copy(p.Src, data[12:16])
p.Dst = make(net.IP, dstlen) p.Dst = make(net.IP, 4)
copy(p.Dst, data[12+srclen:12+srclen+dstlen]) copy(p.Dst, data[16:20])
} copy(p.Hash[:], data[20:52])
copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32]) p.rembytes -= uint16(copy(p.Data[p.Flags<<3:], data[52:]))
p.Data = data[12+srclen+dstlen+32:]
return nil complete = p.rembytes == 0
return
} }
// Marshal 将自身数据编码为 []byte // Marshal 将自身数据编码为 []byte
func (p *Packet) Marshal(src net.IP) []byte { // offset 必须为 8 的倍数,表示偏移的 8 位
func (p *Packet) Marshal(src net.IP, offset uint16, dontfrag, hasmore bool) []byte {
p.TTL-- p.TTL--
if p.TTL == 0 { if p.TTL == 0 {
return nil return nil
} }
p.DataSZ = uint32(len(p.Data)) p.DataSZ = uint16(len(p.Data))
if src != nil { if src != nil {
p.Src = src p.Src = src
offset >>= 3
if dontfrag {
offset |= 0x4000
}
if hasmore {
offset |= 0x2000
}
p.Flags = offset
} }
packet := make([]byte, 52+len(p.Data)) packet := make([]byte, 52+len(p.Data))
binary.LittleEndian.PutUint32(packet[:4], p.DataSZ) binary.LittleEndian.PutUint16(packet[:2], p.Ver)
binary.LittleEndian.PutUint16(packet[2:4], p.DataSZ)
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto)) binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort) binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort) binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
binary.LittleEndian.PutUint16(packet[10:12], 0x0404) binary.LittleEndian.PutUint16(packet[10:12], p.Flags)
copy(packet[12:16], p.Src.To4()) copy(packet[12:16], p.Src.To4())
copy(packet[16:20], p.Dst.To4()) copy(packet[16:20], p.Dst.To4())
copy(packet[20:52], p.Hash[:]) copy(packet[20:52], p.Hash[:])

View File

@@ -75,20 +75,20 @@ func (l *Link) Read() *head.Packet {
// Write 向 peer 发包 // Write 向 peer 发包
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) { func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
if len(p.Data) <= (32768 - 64) { if len(p.Data) <= int(l.me.mtu) {
return l.write(p, istransfer) return l.write(p, 0, istransfer, false)
} }
data := p.Data data := p.Data
offset := 0 offset := 0
for len(data) > (32768 - 64) { for len(data) > int(l.me.mtu) {
packet := *p packet := *p
packet.Data = data[offset*(32768-64) : (offset+1)*(32768-64)] packet.Data = data[offset*int(l.me.mtu) : (offset+1)*int(l.me.mtu)]
i, err := l.write(&packet, istransfer) i, err := l.write(&packet, uint16(offset), istransfer, true)
n += i n += i
if err != nil { if err != nil {
return n, err return n, err
} }
data = data[(offset+1)*(32768-64):] data = data[(offset+1)*int(l.me.mtu):]
} }
return n, nil return n, nil
} }
@@ -107,14 +107,17 @@ func (l *Link) String() (n string) {
} }
// write 向 peer 发一个包 // write 向 peer 发一个包
func (l *Link) write(p *head.Packet, istransfer bool) (n int, err error) { func (l *Link) write(p *head.Packet, offset uint16, istransfer, hasmore bool) (n int, err error) {
var d []byte var d []byte
if istransfer { if istransfer {
d = p.Marshal(nil) if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) {
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
}
d = p.Marshal(nil, 0, false, false)
} else { } else {
p.FillHash() p.FillHash()
p.Data = l.Encode(p.Data) p.Data = l.Encode(p.Data)
d = p.Marshal(l.me.me) d = p.Marshal(l.me.me, offset, false, hasmore)
} }
if d == nil { if d == nil {
return 0, errors.New("[link] ttl exceeded") return 0, errors.New("[link] ttl exceeded")

View File

@@ -20,9 +20,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
n, addr, err := conn.ReadFromUDP(lbf) n, addr, err := conn.ReadFromUDP(lbf)
if err == nil { if err == nil {
lbf = lbf[:n] lbf = lbf[:n]
packet := head.Packet{} packet := m.wait(lbf)
err = packet.Unmarshal(lbf) if packet != nil {
if err == nil {
r := int(packet.DataSZ) - len(packet.Data) r := int(packet.DataSZ) - len(packet.Data)
if r > 0 { if r > 0 {
remain, err := readAll(conn, r) remain, err := readAll(conn, r)
@@ -60,16 +59,16 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
} }
case head.ProtoNotify: case head.ProtoNotify:
logrus.Infoln("[link] recv notify") logrus.Infoln("[link] recv notify")
p.onNotify(&packet) p.onNotify(packet)
case head.ProtoQuery: case head.ProtoQuery:
logrus.Infoln("[link] recv query") logrus.Infoln("[link] recv query")
p.onQuery(&packet) p.onQuery(packet)
case head.ProtoData: case head.ProtoData:
if p.pipe != nil { if p.pipe != nil {
p.pipe <- &packet p.pipe <- packet
logrus.Infoln("[link] deliver to pipe of", p.peerip) logrus.Infoln("[link] deliver to pipe of", p.peerip)
} else { } else {
m.pipe <- &packet m.pipe <- packet
logrus.Infoln("[link] deliver to pipe of me") logrus.Infoln("[link] deliver to pipe of me")
} }
default: default:
@@ -81,7 +80,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
} else if p.Accept(packet.Dst) { } else if p.Accept(packet.Dst) {
if p.allowtrans { if p.allowtrans {
// 转发 // 转发
n, err = p.Write(&packet, true) n, err = p.Write(packet, true)
if err == nil { if err == nil {
logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
} else { } else {

View File

@@ -32,10 +32,17 @@ type Me struct {
pipe chan *head.Packet pipe chan *head.Packet
// 本机路由表 // 本机路由表
router *Router router *Router
// 本机未接收完全分片池
recving map[[32]byte]*head.Packet
recvmu sync.Mutex
// 超时定时器
clock map[*head.Packet]uint8
// 本机上层配置
srcport, dstport, mtu uint16
} }
// NewMe 设置本机参数 // NewMe 设置本机参数
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool) (m Me) { func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool, srcport, dstport, mtu uint16) (m Me) {
m.privKey = *privateKey m.privKey = *privateKey
var err error var err error
m.myend, err = net.ResolveUDPAddr("udp", myEndpoint) m.myend, err = net.ResolveUDPAddr("udp", myEndpoint)
@@ -62,5 +69,20 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipei
} }
m.router.SetDefault(nil) m.router.SetDefault(nil)
m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink) m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink)
m.srcport = srcport
m.dstport = dstport
m.mtu = mtu
return return
} }
func (m *Me) SrcPort() uint16 {
return m.srcport
}
func (m *Me) DstPort() uint16 {
return m.dstport
}
func (m *Me) MTU() uint16 {
return m.mtu
}

77
gold/link/recv.go Normal file
View File

@@ -0,0 +1,77 @@
package link
import (
"encoding/binary"
"time"
"unsafe"
"github.com/fumiama/WireGold/gold/head"
"github.com/sirupsen/logrus"
)
func (m *Me) initrecvpool() {
if m.recving == nil {
m.recving = make(map[[32]byte]*head.Packet, 128)
}
// 超时定时器
m.clock = make(map[*head.Packet]uint8, 128)
var delhs []*head.Packet
t := time.NewTicker(time.Second)
for range t.C {
m.recvmu.Lock()
for k, v := range m.clock {
if v > 10 { // 10s
delete(m.recving, k.Hash)
delhs = append(delhs, k)
} else {
m.clock[k]++
}
}
for _, k := range delhs {
delete(m.clock, k)
logrus.Warnln("[recv] drop timeout packet from", k.Src)
}
delhs = delhs[:0]
m.recvmu.Unlock()
}
}
func (m *Me) wait(data []byte) *head.Packet {
flags := binary.LittleEndian.Uint16(data[10:12])
if flags == 0 || flags == 0x4000 {
h := &head.Packet{}
_, err := h.Unmarshal(data)
if err != nil {
logrus.Errorln("[recv] unmarshal err:", err)
return nil
}
return h
}
m.recvmu.Lock()
defer m.recvmu.Unlock()
hashd := data[20:52]
hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd)))
h, ok := m.recving[hsh]
if ok {
ok, err := h.Unmarshal(data)
if err == nil {
if ok {
return h
}
m.clock[h] = 0
} else {
logrus.Errorln("[recv] unmarshal err:", err)
}
return nil
}
h = &head.Packet{}
_, err := h.Unmarshal(data)
if err != nil {
logrus.Errorln("[recv] unmarshal err:", err)
return nil
}
m.recving[hsh] = h
m.clock[h] = 0
return nil
}

View File

@@ -58,8 +58,8 @@ func (nc *NIC) Start(m *link.Me) {
logrus.Infoln("[lower] recv write", n, "bytes packet to nic") logrus.Infoln("[lower] recv write", n, "bytes packet to nic")
} }
}() }()
buf := make([]byte, 65536) // 永远不可能超界 buf := make([]byte, m.MTU()+64) // 增加报头长度与 TEA 冗余
for nc.hasstart { // 从 NIC 发送 for nc.hasstart { // 从 NIC 发送
packet := buf packet := buf
n, err := nc.ifce.Read(packet) n, err := nc.ifce.Read(packet)
if err != nil { if err != nil {
@@ -115,15 +115,13 @@ func send(m *link.Me, packet []byte) (n int, rem []byte) {
packet = packet[:totl] packet = packet[:totl]
n = int(totl) n = int(totl)
dst := waterutil.IPv4Destination(packet) dst := waterutil.IPv4Destination(packet)
srcport := waterutil.IPv4SourcePort(packet) logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort())))
dstport := waterutil.IPv4DestinationPort(packet)
logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(srcport)), "to", dst.String()+":"+strconv.Itoa(int(dstport)))
lnk, err := m.Connect(dst.String()) lnk, err := m.Connect(dst.String())
if err != nil { if err != nil {
logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err) logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err)
return return
} }
_, err = lnk.Write(head.NewPacket(head.ProtoData, srcport, dst, dstport, packet), false) _, err = lnk.Write(head.NewPacket(head.ProtoData, m.SrcPort(), dst, m.DstPort(), packet), false)
if err != nil { if err != nil {
logrus.Warnln("[lower] write to peer", dst.String(), "err:", err) logrus.Warnln("[lower] write to peer", dst.String(), "err:", err)
} }

91
main.go
View File

@@ -1,22 +1,20 @@
package main package main
import ( import (
"bytes"
"flag" "flag"
"fmt" "fmt"
"net"
"os" "os"
base14 "github.com/fumiama/go-base16384" base14 "github.com/fumiama/go-base16384"
curve "github.com/fumiama/go-x25519" curve "github.com/fumiama/go-x25519"
"github.com/fumiama/WireGold/config" "github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/gold/link"
"github.com/fumiama/WireGold/helper" "github.com/fumiama/WireGold/helper"
"github.com/fumiama/WireGold/lower" "github.com/fumiama/WireGold/upper"
"github.com/fumiama/WireGold/upper/services/wg"
) )
const suffix32 = "㴄"
func main() { func main() {
help := flag.Bool("h", false, "display this help") help := flag.Bool("h", false, "display this help")
gen := flag.Bool("g", false, "generate key pair") gen := flag.Bool("g", false, "generate key pair")
@@ -44,16 +42,11 @@ func main() {
os.Exit(0) os.Exit(0)
} }
if helper.IsNotExist(*file) { if helper.IsNotExist(*file) {
f, err := os.Create(*file) f := new(bytes.Buffer)
if err != nil {
panic(err)
}
var r string var r string
fmt.Print("IP: ") fmt.Print("IP: ")
fmt.Scanln(&r) fmt.Scanln(&r)
if r == "" { if r == "" {
f.Close()
os.Remove(*file)
fmt.Println("nil ip") fmt.Println("nil ip")
return return
} }
@@ -63,8 +56,6 @@ func main() {
fmt.Print("SubNet: ") fmt.Print("SubNet: ")
fmt.Scanln(&r) fmt.Scanln(&r)
if r == "" { if r == "" {
f.Close()
os.Remove(*file)
fmt.Println("nil subnet") fmt.Println("nil subnet")
return return
} }
@@ -74,8 +65,6 @@ func main() {
fmt.Print("PrivateKey: ") fmt.Print("PrivateKey: ")
fmt.Scanln(&r) fmt.Scanln(&r)
if r == "" { if r == "" {
f.Close()
os.Remove(*file)
fmt.Println("nil private key") fmt.Println("nil private key")
return return
} }
@@ -85,15 +74,18 @@ func main() {
fmt.Print("EndPoint: ") fmt.Print("EndPoint: ")
fmt.Scanln(&r) fmt.Scanln(&r)
if r == "" { if r == "" {
f.Close()
os.Remove(*file)
fmt.Println("nil endpoint") fmt.Println("nil endpoint")
return return
} }
f.WriteString("EndPoint: " + r + "\n") f.WriteString("EndPoint: " + r + "\n")
r = "" r = ""
f.Close() cfgf, err := os.Create(*file)
if err != nil {
panic(err)
}
cfgf.Write(f.Bytes())
cfgf.Close()
} }
c := config.Parse(*file) c := config.Parse(*file)
if c.IP == "" { if c.IP == "" {
@@ -108,73 +100,18 @@ func main() {
if c.EndPoint == "" { if c.EndPoint == "" {
displayHelp("nil endpoint") displayHelp("nil endpoint")
} }
var key [32]byte w, err := wg.NewWireGold(&c)
k, err := base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32))
if err != nil { if err != nil {
panic(err) panic(err)
} }
n := copy(key[:], base14.Decode(k))
if n != 32 {
displayHelp("private key length is not 32")
}
if *showp { if *showp {
c := curve.Get(key[:]) fmt.Println("PublicKey:", w.PublicKey)
pubk, err := base14.UTF16be2utf8(base14.Encode((*c.Public())[:]))
if err != nil {
panic(err)
}
fmt.Println("PublicKey:", helper.BytesToString(pubk[:57]))
os.Exit(0) os.Exit(0)
} }
cidrsmap := make(map[string]bool, 32) defer w.Stop()
_, mysubnet, err := net.ParseCIDR(c.SubNet) w.Run(upper.ServiceWireGold, upper.ServiceWireGold, 32768-64)
if err != nil {
panic(err)
}
for _, p := range c.Peers {
for _, ip := range p.AllowedIPs {
ipnet, _, err := net.ParseCIDR(ip)
if err != nil {
panic(err)
}
if !mysubnet.Contains(ipnet) {
cidrsmap[ip] = true
}
}
}
cidrs := make([]string, len(cidrsmap))
i := 0
for k := range cidrsmap {
cidrs[i] = k
i++
}
nic := lower.NewNIC(c.IP, c.SubNet, cidrs...)
me := link.NewMe(&key, c.IP+"/32", c.EndPoint, true)
for _, peer := range c.Peers {
var peerkey [32]byte
k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32))
if err != nil {
panic(err)
}
n := copy(peerkey[:], base14.Decode(k))
if n != 32 {
panic("peer public key length is not 32")
}
me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true)
}
nic.Up()
defer func() {
nic.Stop()
nic.Down()
nic.Destroy()
}()
nic.Start(&me)
} }
func displayHelp(hint string) { func displayHelp(hint string) {

View File

@@ -8,9 +8,16 @@ const (
ServiceNull = iota ServiceNull = iota
// ServiceTunnel 管道通信服务 // ServiceTunnel 管道通信服务
ServiceTunnel ServiceTunnel
// ServiceWireGold 虚拟组网服务
ServiceWireGold
) )
type Service interface { type Service interface {
Create(peer string, srcport, destport, mtu uint16) (Service, error) // Start 无阻塞运行
io.ReadWriteCloser Start(srcport, destport, mtu uint16)
// Run 阻塞运行
Run(srcport, destport, mtu uint16)
// Stop 停止
Stop()
io.ReadWriter
} }

View File

@@ -21,24 +21,36 @@ type Tunnel struct {
mtu uint16 mtu uint16
} }
func Create(me *link.Me, peer string, srcport, destport, mtu uint16) (s Tunnel, err error) { func Create(me *link.Me, peer string) (s Tunnel, err error) {
logrus.Infoln("[tunnel] create from", srcport, "to", destport)
s.l, err = me.Connect(peer) s.l, err = me.Connect(peer)
if err == nil { if err == nil {
s.in = make(chan []byte, 4) s.in = make(chan []byte, 4)
s.out = make(chan []byte, 4) s.out = make(chan []byte, 4)
s.peerip = net.ParseIP(peer) s.peerip = net.ParseIP(peer)
s.src = srcport
s.dest = destport
s.mtu = mtu
go s.handleWrite()
go s.handleRead()
} else { } else {
logrus.Errorln("[tunnel] create err:", err) logrus.Errorln("[tunnel] create err:", err)
} }
return return
} }
func (s *Tunnel) Start(srcport, destport, mtu uint16) {
logrus.Infoln("[tunnel] start from", srcport, "to", destport)
s.src = srcport
s.dest = destport
s.mtu = mtu
go s.handleWrite()
go s.handleRead()
}
func (s *Tunnel) Run(srcport, destport, mtu uint16) {
logrus.Infoln("[tunnel] start from", srcport, "to", destport)
s.src = srcport
s.dest = destport
s.mtu = mtu
go s.handleWrite()
s.handleRead()
}
func (s *Tunnel) Write(p []byte) (int, error) { func (s *Tunnel) Write(p []byte) (int, error) {
s.in <- p s.in <- p
return len(p), nil return len(p), nil
@@ -63,10 +75,9 @@ func (s *Tunnel) Read(p []byte) (int, error) {
return 0, errors.New("reading reaches nil") return 0, errors.New("reading reaches nil")
} }
func (s *Tunnel) Close() error { func (s *Tunnel) Stop() {
s.l.Close() s.l.Close()
close(s.in) close(s.in)
return nil
} }
func (s *Tunnel) handleWrite() { func (s *Tunnel) handleWrite() {

View File

@@ -27,18 +27,20 @@ func TestTunnel(t *testing.T) {
t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:])) t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:]))
t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:])) t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:]))
m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false) m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false, 1, 1, 4096)
m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", []string{"192.168.1.3/32"}, 0, false, false) m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", []string{"192.168.1.3/32"}, 0, false, false)
p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false) p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false, 1, 1, 4096)
p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", []string{"192.168.1.2/32"}, 0, false, false) p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", []string{"192.168.1.2/32"}, 0, false, false)
tunnme, err := Create(&m, "192.168.1.3", 1, 1, 4096) tunnme, err := Create(&m, "192.168.1.3")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tunnpeer, err := Create(&p, "192.168.1.2", 1, 1, 4096) tunnme.Start(1, 1, 4096)
tunnpeer, err := Create(&p, "192.168.1.2")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tunnpeer.Start(1, 1, 4096)
sendb := ([]byte)("1234") sendb := ([]byte)("1234")
tunnme.Write(sendb) tunnme.Write(sendb)
@@ -68,4 +70,7 @@ func TestTunnel(t *testing.T) {
if string(sendb) != string(buf) { if string(sendb) != string(buf) {
t.Fatal("error: recv 131072 bytes data") t.Fatal("error: recv 131072 bytes data")
} }
tunnme.Stop()
tunnpeer.Stop()
} }

107
upper/services/wg/wg.go Normal file
View File

@@ -0,0 +1,107 @@
package wg
import (
"errors"
"net"
base14 "github.com/fumiama/go-base16384"
curve "github.com/fumiama/go-x25519"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/gold/link"
"github.com/fumiama/WireGold/helper"
"github.com/fumiama/WireGold/lower"
)
const suffix32 = "㴄"
type WG struct {
c *config.Config
key [32]byte
PublicKey string
nic *lower.NIC
me link.Me
}
func NewWireGold(c *config.Config) (wg WG, err error) {
wg.c = c
var k []byte
k, err = base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32))
if err != nil {
return
}
n := copy(wg.key[:], base14.Decode(k))
if n != 32 {
err = errors.New("private key length is not 32")
return
}
cur := curve.Get(wg.key[:])
pubk, err := base14.UTF16be2utf8(base14.Encode((*cur.Public())[:]))
if err != nil {
return
}
wg.PublicKey = helper.BytesToString(pubk[:57])
return
}
func (wg *WG) Start(srcport, destport, mtu uint16) {
wg.init(srcport, destport, mtu)
wg.nic.Up()
go wg.nic.Start(&wg.me)
}
func (wg *WG) Run(srcport, destport, mtu uint16) {
wg.init(srcport, destport, mtu)
wg.nic.Up()
wg.nic.Start(&wg.me)
}
func (wg *WG) Stop() {
wg.nic.Stop()
wg.nic.Down()
wg.nic.Destroy()
}
func (wg *WG) init(srcport, destport, mtu uint16) {
cidrsmap := make(map[string]bool, 32)
_, mysubnet, err := net.ParseCIDR(wg.c.SubNet)
if err != nil {
panic(err)
}
for _, p := range wg.c.Peers {
for _, ip := range p.AllowedIPs {
ipnet, _, err := net.ParseCIDR(ip)
if err != nil {
panic(err)
}
if !mysubnet.Contains(ipnet) {
cidrsmap[ip] = true
}
}
}
cidrs := make([]string, len(cidrsmap))
i := 0
for k := range cidrsmap {
cidrs[i] = k
i++
}
wg.nic = lower.NewNIC(wg.c.IP, wg.c.SubNet, cidrs...)
wg.me = link.NewMe(&wg.key, wg.c.IP+"/32", wg.c.EndPoint, true, srcport, destport, mtu)
for _, peer := range wg.c.Peers {
var peerkey [32]byte
k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32))
if err != nil {
panic(err)
}
n := copy(peerkey[:], base14.Decode(k))
if n != 32 {
panic("peer public key length is not 32")
}
wg.me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true)
}
}