mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-04 23:40:26 +08:00
add 分片
This commit is contained in:
@@ -12,9 +12,11 @@ import (
|
||||
|
||||
// Packet 是发送和接收的最小单位
|
||||
type Packet struct {
|
||||
// Ver 协议版本
|
||||
Ver uint16
|
||||
// DataSZ len(Data)
|
||||
// 不得超过 65507-head 字节
|
||||
DataSZ uint32
|
||||
DataSZ uint16
|
||||
// Proto 详见 head
|
||||
Proto uint8
|
||||
// TTL is time to live
|
||||
@@ -23,9 +25,11 @@ type Packet struct {
|
||||
SrcPort uint16
|
||||
// DstPort 目的端口
|
||||
DstPort uint16
|
||||
// Src 源 ip
|
||||
// Flags 高3位为标志(xDM),低13位为分片偏移
|
||||
Flags uint16
|
||||
// Src 源 ip (ipv4)
|
||||
Src net.IP
|
||||
// Dst 目的 ip
|
||||
// Dst 目的 ip (ipv4)
|
||||
Dst net.IP
|
||||
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要
|
||||
// 生成时 Hash 全 0
|
||||
@@ -33,12 +37,15 @@ type Packet struct {
|
||||
Hash [32]byte
|
||||
// Data 承载的数据
|
||||
Data []byte
|
||||
// 记录还有多少字节未到达
|
||||
rembytes uint16
|
||||
}
|
||||
|
||||
// NewPacket 生成一个新包
|
||||
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet {
|
||||
logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
|
||||
return &Packet{
|
||||
Ver: 1,
|
||||
Proto: proto,
|
||||
TTL: 16,
|
||||
SrcPort: srcPort,
|
||||
@@ -49,53 +56,69 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
|
||||
}
|
||||
|
||||
// Unmarshal 将 data 的数据解码到自身
|
||||
func (p *Packet) Unmarshal(data []byte) error {
|
||||
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||
if len(data) < 12 {
|
||||
return errors.New("data len < 12")
|
||||
err = errors.New("data len < 12")
|
||||
return
|
||||
}
|
||||
p.DataSZ = binary.LittleEndian.Uint32(data[:4])
|
||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||
p.Proto = uint8(pt)
|
||||
p.TTL = uint8(pt >> 8)
|
||||
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||
sdl := binary.LittleEndian.Uint16(data[10:12])
|
||||
srclen := uint8(sdl)
|
||||
dstlen := uint8(sdl >> 8)
|
||||
if len(data) < int(12+srclen+dstlen) {
|
||||
return errors.New("data src or dst len mismatch")
|
||||
if p.DataSZ == 0 && len(p.Data) == 0 {
|
||||
p.Ver = binary.LittleEndian.Uint16(data[:2])
|
||||
if p.Ver != 1 {
|
||||
err = errors.New("unknown protocol version")
|
||||
return
|
||||
}
|
||||
p.DataSZ = binary.LittleEndian.Uint16(data[2:4])
|
||||
p.Data = make([]byte, p.DataSZ)
|
||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||
p.Proto = uint8(pt)
|
||||
p.TTL = uint8(pt >> 8)
|
||||
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||
p.rembytes = p.DataSZ
|
||||
}
|
||||
if srclen > 0 {
|
||||
p.Src = make(net.IP, srclen)
|
||||
copy(p.Src, data[12:12+srclen])
|
||||
}
|
||||
if dstlen > 0 {
|
||||
p.Dst = make(net.IP, dstlen)
|
||||
copy(p.Dst, data[12+srclen:12+srclen+dstlen])
|
||||
}
|
||||
copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32])
|
||||
p.Data = data[12+srclen+dstlen+32:]
|
||||
return nil
|
||||
|
||||
p.Flags = binary.LittleEndian.Uint16(data[10:12])
|
||||
|
||||
p.Src = make(net.IP, 4)
|
||||
copy(p.Src, data[12:16])
|
||||
p.Dst = make(net.IP, 4)
|
||||
copy(p.Dst, data[16:20])
|
||||
copy(p.Hash[:], data[20:52])
|
||||
p.rembytes -= uint16(copy(p.Data[p.Flags<<3:], data[52:]))
|
||||
|
||||
complete = p.rembytes == 0
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Marshal 将自身数据编码为 []byte
|
||||
func (p *Packet) Marshal(src net.IP) []byte {
|
||||
// offset 必须为 8 的倍数,表示偏移的 8 位
|
||||
func (p *Packet) Marshal(src net.IP, offset uint16, dontfrag, hasmore bool) []byte {
|
||||
p.TTL--
|
||||
if p.TTL == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.DataSZ = uint32(len(p.Data))
|
||||
p.DataSZ = uint16(len(p.Data))
|
||||
if src != nil {
|
||||
p.Src = src
|
||||
offset >>= 3
|
||||
if dontfrag {
|
||||
offset |= 0x4000
|
||||
}
|
||||
if hasmore {
|
||||
offset |= 0x2000
|
||||
}
|
||||
p.Flags = offset
|
||||
}
|
||||
|
||||
packet := make([]byte, 52+len(p.Data))
|
||||
binary.LittleEndian.PutUint32(packet[:4], p.DataSZ)
|
||||
binary.LittleEndian.PutUint16(packet[:2], p.Ver)
|
||||
binary.LittleEndian.PutUint16(packet[2:4], p.DataSZ)
|
||||
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
|
||||
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
|
||||
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
|
||||
binary.LittleEndian.PutUint16(packet[10:12], 0x0404)
|
||||
binary.LittleEndian.PutUint16(packet[10:12], p.Flags)
|
||||
copy(packet[12:16], p.Src.To4())
|
||||
copy(packet[16:20], p.Dst.To4())
|
||||
copy(packet[20:52], p.Hash[:])
|
||||
|
||||
@@ -75,20 +75,20 @@ func (l *Link) Read() *head.Packet {
|
||||
|
||||
// Write 向 peer 发包
|
||||
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
if len(p.Data) <= (32768 - 64) {
|
||||
return l.write(p, istransfer)
|
||||
if len(p.Data) <= int(l.me.mtu) {
|
||||
return l.write(p, 0, istransfer, false)
|
||||
}
|
||||
data := p.Data
|
||||
offset := 0
|
||||
for len(data) > (32768 - 64) {
|
||||
for len(data) > int(l.me.mtu) {
|
||||
packet := *p
|
||||
packet.Data = data[offset*(32768-64) : (offset+1)*(32768-64)]
|
||||
i, err := l.write(&packet, istransfer)
|
||||
packet.Data = data[offset*int(l.me.mtu) : (offset+1)*int(l.me.mtu)]
|
||||
i, err := l.write(&packet, uint16(offset), istransfer, true)
|
||||
n += i
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
data = data[(offset+1)*(32768-64):]
|
||||
data = data[(offset+1)*int(l.me.mtu):]
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
@@ -107,14 +107,17 @@ func (l *Link) String() (n string) {
|
||||
}
|
||||
|
||||
// write 向 peer 发一个包
|
||||
func (l *Link) write(p *head.Packet, istransfer bool) (n int, err error) {
|
||||
func (l *Link) write(p *head.Packet, offset uint16, istransfer, hasmore bool) (n int, err error) {
|
||||
var d []byte
|
||||
if istransfer {
|
||||
d = p.Marshal(nil)
|
||||
if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) {
|
||||
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
|
||||
}
|
||||
d = p.Marshal(nil, 0, false, false)
|
||||
} else {
|
||||
p.FillHash()
|
||||
p.Data = l.Encode(p.Data)
|
||||
d = p.Marshal(l.me.me)
|
||||
d = p.Marshal(l.me.me, offset, false, hasmore)
|
||||
}
|
||||
if d == nil {
|
||||
return 0, errors.New("[link] ttl exceeded")
|
||||
|
||||
@@ -20,9 +20,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
n, addr, err := conn.ReadFromUDP(lbf)
|
||||
if err == nil {
|
||||
lbf = lbf[:n]
|
||||
packet := head.Packet{}
|
||||
err = packet.Unmarshal(lbf)
|
||||
if err == nil {
|
||||
packet := m.wait(lbf)
|
||||
if packet != nil {
|
||||
r := int(packet.DataSZ) - len(packet.Data)
|
||||
if r > 0 {
|
||||
remain, err := readAll(conn, r)
|
||||
@@ -60,16 +59,16 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
}
|
||||
case head.ProtoNotify:
|
||||
logrus.Infoln("[link] recv notify")
|
||||
p.onNotify(&packet)
|
||||
p.onNotify(packet)
|
||||
case head.ProtoQuery:
|
||||
logrus.Infoln("[link] recv query")
|
||||
p.onQuery(&packet)
|
||||
p.onQuery(packet)
|
||||
case head.ProtoData:
|
||||
if p.pipe != nil {
|
||||
p.pipe <- &packet
|
||||
p.pipe <- packet
|
||||
logrus.Infoln("[link] deliver to pipe of", p.peerip)
|
||||
} else {
|
||||
m.pipe <- &packet
|
||||
m.pipe <- packet
|
||||
logrus.Infoln("[link] deliver to pipe of me")
|
||||
}
|
||||
default:
|
||||
@@ -81,7 +80,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
||||
} else if p.Accept(packet.Dst) {
|
||||
if p.allowtrans {
|
||||
// 转发
|
||||
n, err = p.Write(&packet, true)
|
||||
n, err = p.Write(packet, true)
|
||||
if err == nil {
|
||||
logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
|
||||
} else {
|
||||
|
||||
@@ -32,10 +32,17 @@ type Me struct {
|
||||
pipe chan *head.Packet
|
||||
// 本机路由表
|
||||
router *Router
|
||||
// 本机未接收完全分片池
|
||||
recving map[[32]byte]*head.Packet
|
||||
recvmu sync.Mutex
|
||||
// 超时定时器
|
||||
clock map[*head.Packet]uint8
|
||||
// 本机上层配置
|
||||
srcport, dstport, mtu uint16
|
||||
}
|
||||
|
||||
// NewMe 设置本机参数
|
||||
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool) (m Me) {
|
||||
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool, srcport, dstport, mtu uint16) (m Me) {
|
||||
m.privKey = *privateKey
|
||||
var err error
|
||||
m.myend, err = net.ResolveUDPAddr("udp", myEndpoint)
|
||||
@@ -62,5 +69,20 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipei
|
||||
}
|
||||
m.router.SetDefault(nil)
|
||||
m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink)
|
||||
m.srcport = srcport
|
||||
m.dstport = dstport
|
||||
m.mtu = mtu
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Me) SrcPort() uint16 {
|
||||
return m.srcport
|
||||
}
|
||||
|
||||
func (m *Me) DstPort() uint16 {
|
||||
return m.dstport
|
||||
}
|
||||
|
||||
func (m *Me) MTU() uint16 {
|
||||
return m.mtu
|
||||
}
|
||||
|
||||
77
gold/link/recv.go
Normal file
77
gold/link/recv.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package link
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/WireGold/gold/head"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (m *Me) initrecvpool() {
|
||||
if m.recving == nil {
|
||||
m.recving = make(map[[32]byte]*head.Packet, 128)
|
||||
}
|
||||
// 超时定时器
|
||||
m.clock = make(map[*head.Packet]uint8, 128)
|
||||
var delhs []*head.Packet
|
||||
t := time.NewTicker(time.Second)
|
||||
for range t.C {
|
||||
m.recvmu.Lock()
|
||||
for k, v := range m.clock {
|
||||
if v > 10 { // 10s
|
||||
delete(m.recving, k.Hash)
|
||||
delhs = append(delhs, k)
|
||||
} else {
|
||||
m.clock[k]++
|
||||
}
|
||||
}
|
||||
for _, k := range delhs {
|
||||
delete(m.clock, k)
|
||||
logrus.Warnln("[recv] drop timeout packet from", k.Src)
|
||||
}
|
||||
delhs = delhs[:0]
|
||||
m.recvmu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Me) wait(data []byte) *head.Packet {
|
||||
flags := binary.LittleEndian.Uint16(data[10:12])
|
||||
if flags == 0 || flags == 0x4000 {
|
||||
h := &head.Packet{}
|
||||
_, err := h.Unmarshal(data)
|
||||
if err != nil {
|
||||
logrus.Errorln("[recv] unmarshal err:", err)
|
||||
return nil
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
m.recvmu.Lock()
|
||||
defer m.recvmu.Unlock()
|
||||
hashd := data[20:52]
|
||||
hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd)))
|
||||
h, ok := m.recving[hsh]
|
||||
if ok {
|
||||
ok, err := h.Unmarshal(data)
|
||||
if err == nil {
|
||||
if ok {
|
||||
return h
|
||||
}
|
||||
m.clock[h] = 0
|
||||
} else {
|
||||
logrus.Errorln("[recv] unmarshal err:", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
h = &head.Packet{}
|
||||
_, err := h.Unmarshal(data)
|
||||
if err != nil {
|
||||
logrus.Errorln("[recv] unmarshal err:", err)
|
||||
return nil
|
||||
}
|
||||
m.recving[hsh] = h
|
||||
m.clock[h] = 0
|
||||
return nil
|
||||
}
|
||||
10
lower/nic.go
10
lower/nic.go
@@ -58,8 +58,8 @@ func (nc *NIC) Start(m *link.Me) {
|
||||
logrus.Infoln("[lower] recv write", n, "bytes packet to nic")
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 65536) // 永远不可能超界
|
||||
for nc.hasstart { // 从 NIC 发送
|
||||
buf := make([]byte, m.MTU()+64) // 增加报头长度与 TEA 冗余
|
||||
for nc.hasstart { // 从 NIC 发送
|
||||
packet := buf
|
||||
n, err := nc.ifce.Read(packet)
|
||||
if err != nil {
|
||||
@@ -115,15 +115,13 @@ func send(m *link.Me, packet []byte) (n int, rem []byte) {
|
||||
packet = packet[:totl]
|
||||
n = int(totl)
|
||||
dst := waterutil.IPv4Destination(packet)
|
||||
srcport := waterutil.IPv4SourcePort(packet)
|
||||
dstport := waterutil.IPv4DestinationPort(packet)
|
||||
logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(srcport)), "to", dst.String()+":"+strconv.Itoa(int(dstport)))
|
||||
logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort())))
|
||||
lnk, err := m.Connect(dst.String())
|
||||
if err != nil {
|
||||
logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err)
|
||||
return
|
||||
}
|
||||
_, err = lnk.Write(head.NewPacket(head.ProtoData, srcport, dst, dstport, packet), false)
|
||||
_, err = lnk.Write(head.NewPacket(head.ProtoData, m.SrcPort(), dst, m.DstPort(), packet), false)
|
||||
if err != nil {
|
||||
logrus.Warnln("[lower] write to peer", dst.String(), "err:", err)
|
||||
}
|
||||
|
||||
91
main.go
91
main.go
@@ -1,22 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
base14 "github.com/fumiama/go-base16384"
|
||||
curve "github.com/fumiama/go-x25519"
|
||||
|
||||
"github.com/fumiama/WireGold/config"
|
||||
"github.com/fumiama/WireGold/gold/link"
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/fumiama/WireGold/lower"
|
||||
"github.com/fumiama/WireGold/upper"
|
||||
"github.com/fumiama/WireGold/upper/services/wg"
|
||||
)
|
||||
|
||||
const suffix32 = "㴄"
|
||||
|
||||
func main() {
|
||||
help := flag.Bool("h", false, "display this help")
|
||||
gen := flag.Bool("g", false, "generate key pair")
|
||||
@@ -44,16 +42,11 @@ func main() {
|
||||
os.Exit(0)
|
||||
}
|
||||
if helper.IsNotExist(*file) {
|
||||
f, err := os.Create(*file)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
f := new(bytes.Buffer)
|
||||
var r string
|
||||
fmt.Print("IP: ")
|
||||
fmt.Scanln(&r)
|
||||
if r == "" {
|
||||
f.Close()
|
||||
os.Remove(*file)
|
||||
fmt.Println("nil ip")
|
||||
return
|
||||
}
|
||||
@@ -63,8 +56,6 @@ func main() {
|
||||
fmt.Print("SubNet: ")
|
||||
fmt.Scanln(&r)
|
||||
if r == "" {
|
||||
f.Close()
|
||||
os.Remove(*file)
|
||||
fmt.Println("nil subnet")
|
||||
return
|
||||
}
|
||||
@@ -74,8 +65,6 @@ func main() {
|
||||
fmt.Print("PrivateKey: ")
|
||||
fmt.Scanln(&r)
|
||||
if r == "" {
|
||||
f.Close()
|
||||
os.Remove(*file)
|
||||
fmt.Println("nil private key")
|
||||
return
|
||||
}
|
||||
@@ -85,15 +74,18 @@ func main() {
|
||||
fmt.Print("EndPoint: ")
|
||||
fmt.Scanln(&r)
|
||||
if r == "" {
|
||||
f.Close()
|
||||
os.Remove(*file)
|
||||
fmt.Println("nil endpoint")
|
||||
return
|
||||
}
|
||||
f.WriteString("EndPoint: " + r + "\n")
|
||||
r = ""
|
||||
|
||||
f.Close()
|
||||
cfgf, err := os.Create(*file)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cfgf.Write(f.Bytes())
|
||||
cfgf.Close()
|
||||
}
|
||||
c := config.Parse(*file)
|
||||
if c.IP == "" {
|
||||
@@ -108,73 +100,18 @@ func main() {
|
||||
if c.EndPoint == "" {
|
||||
displayHelp("nil endpoint")
|
||||
}
|
||||
var key [32]byte
|
||||
k, err := base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32))
|
||||
w, err := wg.NewWireGold(&c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
n := copy(key[:], base14.Decode(k))
|
||||
if n != 32 {
|
||||
displayHelp("private key length is not 32")
|
||||
}
|
||||
|
||||
if *showp {
|
||||
c := curve.Get(key[:])
|
||||
pubk, err := base14.UTF16be2utf8(base14.Encode((*c.Public())[:]))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("PublicKey:", helper.BytesToString(pubk[:57]))
|
||||
fmt.Println("PublicKey:", w.PublicKey)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
cidrsmap := make(map[string]bool, 32)
|
||||
_, mysubnet, err := net.ParseCIDR(c.SubNet)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, p := range c.Peers {
|
||||
for _, ip := range p.AllowedIPs {
|
||||
ipnet, _, err := net.ParseCIDR(ip)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if !mysubnet.Contains(ipnet) {
|
||||
cidrsmap[ip] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
cidrs := make([]string, len(cidrsmap))
|
||||
i := 0
|
||||
for k := range cidrsmap {
|
||||
cidrs[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
nic := lower.NewNIC(c.IP, c.SubNet, cidrs...)
|
||||
me := link.NewMe(&key, c.IP+"/32", c.EndPoint, true)
|
||||
|
||||
for _, peer := range c.Peers {
|
||||
var peerkey [32]byte
|
||||
k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
n := copy(peerkey[:], base14.Decode(k))
|
||||
if n != 32 {
|
||||
panic("peer public key length is not 32")
|
||||
}
|
||||
me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true)
|
||||
}
|
||||
|
||||
nic.Up()
|
||||
defer func() {
|
||||
nic.Stop()
|
||||
nic.Down()
|
||||
nic.Destroy()
|
||||
}()
|
||||
|
||||
nic.Start(&me)
|
||||
defer w.Stop()
|
||||
w.Run(upper.ServiceWireGold, upper.ServiceWireGold, 32768-64)
|
||||
}
|
||||
|
||||
func displayHelp(hint string) {
|
||||
|
||||
@@ -8,9 +8,16 @@ const (
|
||||
ServiceNull = iota
|
||||
// ServiceTunnel 管道通信服务
|
||||
ServiceTunnel
|
||||
// ServiceWireGold 虚拟组网服务
|
||||
ServiceWireGold
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
Create(peer string, srcport, destport, mtu uint16) (Service, error)
|
||||
io.ReadWriteCloser
|
||||
// Start 无阻塞运行
|
||||
Start(srcport, destport, mtu uint16)
|
||||
// Run 阻塞运行
|
||||
Run(srcport, destport, mtu uint16)
|
||||
// Stop 停止
|
||||
Stop()
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
@@ -21,24 +21,36 @@ type Tunnel struct {
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func Create(me *link.Me, peer string, srcport, destport, mtu uint16) (s Tunnel, err error) {
|
||||
logrus.Infoln("[tunnel] create from", srcport, "to", destport)
|
||||
func Create(me *link.Me, peer string) (s Tunnel, err error) {
|
||||
s.l, err = me.Connect(peer)
|
||||
if err == nil {
|
||||
s.in = make(chan []byte, 4)
|
||||
s.out = make(chan []byte, 4)
|
||||
s.peerip = net.ParseIP(peer)
|
||||
s.src = srcport
|
||||
s.dest = destport
|
||||
s.mtu = mtu
|
||||
go s.handleWrite()
|
||||
go s.handleRead()
|
||||
} else {
|
||||
logrus.Errorln("[tunnel] create err:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Tunnel) Start(srcport, destport, mtu uint16) {
|
||||
logrus.Infoln("[tunnel] start from", srcport, "to", destport)
|
||||
s.src = srcport
|
||||
s.dest = destport
|
||||
s.mtu = mtu
|
||||
go s.handleWrite()
|
||||
go s.handleRead()
|
||||
}
|
||||
|
||||
func (s *Tunnel) Run(srcport, destport, mtu uint16) {
|
||||
logrus.Infoln("[tunnel] start from", srcport, "to", destport)
|
||||
s.src = srcport
|
||||
s.dest = destport
|
||||
s.mtu = mtu
|
||||
go s.handleWrite()
|
||||
s.handleRead()
|
||||
}
|
||||
|
||||
func (s *Tunnel) Write(p []byte) (int, error) {
|
||||
s.in <- p
|
||||
return len(p), nil
|
||||
@@ -63,10 +75,9 @@ func (s *Tunnel) Read(p []byte) (int, error) {
|
||||
return 0, errors.New("reading reaches nil")
|
||||
}
|
||||
|
||||
func (s *Tunnel) Close() error {
|
||||
func (s *Tunnel) Stop() {
|
||||
s.l.Close()
|
||||
close(s.in)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Tunnel) handleWrite() {
|
||||
|
||||
@@ -27,18 +27,20 @@ func TestTunnel(t *testing.T) {
|
||||
t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:]))
|
||||
t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:]))
|
||||
|
||||
m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false)
|
||||
m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false, 1, 1, 4096)
|
||||
m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", []string{"192.168.1.3/32"}, 0, false, false)
|
||||
p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false)
|
||||
p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false, 1, 1, 4096)
|
||||
p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", []string{"192.168.1.2/32"}, 0, false, false)
|
||||
tunnme, err := Create(&m, "192.168.1.3", 1, 1, 4096)
|
||||
tunnme, err := Create(&m, "192.168.1.3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tunnpeer, err := Create(&p, "192.168.1.2", 1, 1, 4096)
|
||||
tunnme.Start(1, 1, 4096)
|
||||
tunnpeer, err := Create(&p, "192.168.1.2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tunnpeer.Start(1, 1, 4096)
|
||||
|
||||
sendb := ([]byte)("1234")
|
||||
tunnme.Write(sendb)
|
||||
@@ -68,4 +70,7 @@ func TestTunnel(t *testing.T) {
|
||||
if string(sendb) != string(buf) {
|
||||
t.Fatal("error: recv 131072 bytes data")
|
||||
}
|
||||
|
||||
tunnme.Stop()
|
||||
tunnpeer.Stop()
|
||||
}
|
||||
|
||||
107
upper/services/wg/wg.go
Normal file
107
upper/services/wg/wg.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package wg
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
base14 "github.com/fumiama/go-base16384"
|
||||
curve "github.com/fumiama/go-x25519"
|
||||
|
||||
"github.com/fumiama/WireGold/config"
|
||||
"github.com/fumiama/WireGold/gold/link"
|
||||
"github.com/fumiama/WireGold/helper"
|
||||
"github.com/fumiama/WireGold/lower"
|
||||
)
|
||||
|
||||
const suffix32 = "㴄"
|
||||
|
||||
type WG struct {
|
||||
c *config.Config
|
||||
key [32]byte
|
||||
PublicKey string
|
||||
nic *lower.NIC
|
||||
me link.Me
|
||||
}
|
||||
|
||||
func NewWireGold(c *config.Config) (wg WG, err error) {
|
||||
wg.c = c
|
||||
|
||||
var k []byte
|
||||
k, err = base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n := copy(wg.key[:], base14.Decode(k))
|
||||
if n != 32 {
|
||||
err = errors.New("private key length is not 32")
|
||||
return
|
||||
}
|
||||
|
||||
cur := curve.Get(wg.key[:])
|
||||
pubk, err := base14.UTF16be2utf8(base14.Encode((*cur.Public())[:]))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
wg.PublicKey = helper.BytesToString(pubk[:57])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (wg *WG) Start(srcport, destport, mtu uint16) {
|
||||
wg.init(srcport, destport, mtu)
|
||||
wg.nic.Up()
|
||||
go wg.nic.Start(&wg.me)
|
||||
}
|
||||
|
||||
func (wg *WG) Run(srcport, destport, mtu uint16) {
|
||||
wg.init(srcport, destport, mtu)
|
||||
wg.nic.Up()
|
||||
wg.nic.Start(&wg.me)
|
||||
}
|
||||
|
||||
func (wg *WG) Stop() {
|
||||
wg.nic.Stop()
|
||||
wg.nic.Down()
|
||||
wg.nic.Destroy()
|
||||
}
|
||||
|
||||
func (wg *WG) init(srcport, destport, mtu uint16) {
|
||||
cidrsmap := make(map[string]bool, 32)
|
||||
_, mysubnet, err := net.ParseCIDR(wg.c.SubNet)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, p := range wg.c.Peers {
|
||||
for _, ip := range p.AllowedIPs {
|
||||
ipnet, _, err := net.ParseCIDR(ip)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if !mysubnet.Contains(ipnet) {
|
||||
cidrsmap[ip] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
cidrs := make([]string, len(cidrsmap))
|
||||
i := 0
|
||||
for k := range cidrsmap {
|
||||
cidrs[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
wg.nic = lower.NewNIC(wg.c.IP, wg.c.SubNet, cidrs...)
|
||||
wg.me = link.NewMe(&wg.key, wg.c.IP+"/32", wg.c.EndPoint, true, srcport, destport, mtu)
|
||||
|
||||
for _, peer := range wg.c.Peers {
|
||||
var peerkey [32]byte
|
||||
k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
n := copy(peerkey[:], base14.Decode(k))
|
||||
if n != 32 {
|
||||
panic("peer public key length is not 32")
|
||||
}
|
||||
wg.me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user