mirror of
https://github.com/fumiama/WireGold.git
synced 2026-06-23 12:00:34 +08:00
add 分片
This commit is contained in:
@@ -12,9 +12,11 @@ import (
|
|||||||
|
|
||||||
// Packet 是发送和接收的最小单位
|
// Packet 是发送和接收的最小单位
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
|
// Ver 协议版本
|
||||||
|
Ver uint16
|
||||||
// DataSZ len(Data)
|
// DataSZ len(Data)
|
||||||
// 不得超过 65507-head 字节
|
// 不得超过 65507-head 字节
|
||||||
DataSZ uint32
|
DataSZ uint16
|
||||||
// Proto 详见 head
|
// Proto 详见 head
|
||||||
Proto uint8
|
Proto uint8
|
||||||
// TTL is time to live
|
// TTL is time to live
|
||||||
@@ -23,9 +25,11 @@ type Packet struct {
|
|||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
// DstPort 目的端口
|
// DstPort 目的端口
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
// Src 源 ip
|
// Flags 高3位为标志(xDM),低13位为分片偏移
|
||||||
|
Flags uint16
|
||||||
|
// Src 源 ip (ipv4)
|
||||||
Src net.IP
|
Src net.IP
|
||||||
// Dst 目的 ip
|
// Dst 目的 ip (ipv4)
|
||||||
Dst net.IP
|
Dst net.IP
|
||||||
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要
|
// Hash 使用 BLAKE2 生成加密前 Packet 的摘要
|
||||||
// 生成时 Hash 全 0
|
// 生成时 Hash 全 0
|
||||||
@@ -33,12 +37,15 @@ type Packet struct {
|
|||||||
Hash [32]byte
|
Hash [32]byte
|
||||||
// Data 承载的数据
|
// Data 承载的数据
|
||||||
Data []byte
|
Data []byte
|
||||||
|
// 记录还有多少字节未到达
|
||||||
|
rembytes uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPacket 生成一个新包
|
// NewPacket 生成一个新包
|
||||||
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet {
|
func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []byte) *Packet {
|
||||||
logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
|
logrus.Debugln("[packet] new: [proto:", proto, ", srcport:", srcPort, ", dstport:", dstPort, ", dst:", dst, ", data:", data)
|
||||||
return &Packet{
|
return &Packet{
|
||||||
|
Ver: 1,
|
||||||
Proto: proto,
|
Proto: proto,
|
||||||
TTL: 16,
|
TTL: 16,
|
||||||
SrcPort: srcPort,
|
SrcPort: srcPort,
|
||||||
@@ -49,53 +56,69 @@ func NewPacket(proto uint8, srcPort uint16, dst net.IP, dstPort uint16, data []b
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal 将 data 的数据解码到自身
|
// Unmarshal 将 data 的数据解码到自身
|
||||||
func (p *Packet) Unmarshal(data []byte) error {
|
func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
|
||||||
if len(data) < 12 {
|
if len(data) < 12 {
|
||||||
return errors.New("data len < 12")
|
err = errors.New("data len < 12")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
p.DataSZ = binary.LittleEndian.Uint32(data[:4])
|
if p.DataSZ == 0 && len(p.Data) == 0 {
|
||||||
pt := binary.LittleEndian.Uint16(data[4:6])
|
p.Ver = binary.LittleEndian.Uint16(data[:2])
|
||||||
p.Proto = uint8(pt)
|
if p.Ver != 1 {
|
||||||
p.TTL = uint8(pt >> 8)
|
err = errors.New("unknown protocol version")
|
||||||
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
return
|
||||||
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
}
|
||||||
sdl := binary.LittleEndian.Uint16(data[10:12])
|
p.DataSZ = binary.LittleEndian.Uint16(data[2:4])
|
||||||
srclen := uint8(sdl)
|
p.Data = make([]byte, p.DataSZ)
|
||||||
dstlen := uint8(sdl >> 8)
|
pt := binary.LittleEndian.Uint16(data[4:6])
|
||||||
if len(data) < int(12+srclen+dstlen) {
|
p.Proto = uint8(pt)
|
||||||
return errors.New("data src or dst len mismatch")
|
p.TTL = uint8(pt >> 8)
|
||||||
|
p.SrcPort = binary.LittleEndian.Uint16(data[6:8])
|
||||||
|
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
|
||||||
|
p.rembytes = p.DataSZ
|
||||||
}
|
}
|
||||||
if srclen > 0 {
|
|
||||||
p.Src = make(net.IP, srclen)
|
p.Flags = binary.LittleEndian.Uint16(data[10:12])
|
||||||
copy(p.Src, data[12:12+srclen])
|
|
||||||
}
|
p.Src = make(net.IP, 4)
|
||||||
if dstlen > 0 {
|
copy(p.Src, data[12:16])
|
||||||
p.Dst = make(net.IP, dstlen)
|
p.Dst = make(net.IP, 4)
|
||||||
copy(p.Dst, data[12+srclen:12+srclen+dstlen])
|
copy(p.Dst, data[16:20])
|
||||||
}
|
copy(p.Hash[:], data[20:52])
|
||||||
copy(p.Hash[:], data[12+srclen+dstlen:12+srclen+dstlen+32])
|
p.rembytes -= uint16(copy(p.Data[p.Flags<<3:], data[52:]))
|
||||||
p.Data = data[12+srclen+dstlen+32:]
|
|
||||||
return nil
|
complete = p.rembytes == 0
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal 将自身数据编码为 []byte
|
// Marshal 将自身数据编码为 []byte
|
||||||
func (p *Packet) Marshal(src net.IP) []byte {
|
// offset 必须为 8 的倍数,表示偏移的 8 位
|
||||||
|
func (p *Packet) Marshal(src net.IP, offset uint16, dontfrag, hasmore bool) []byte {
|
||||||
p.TTL--
|
p.TTL--
|
||||||
if p.TTL == 0 {
|
if p.TTL == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.DataSZ = uint32(len(p.Data))
|
p.DataSZ = uint16(len(p.Data))
|
||||||
if src != nil {
|
if src != nil {
|
||||||
p.Src = src
|
p.Src = src
|
||||||
|
offset >>= 3
|
||||||
|
if dontfrag {
|
||||||
|
offset |= 0x4000
|
||||||
|
}
|
||||||
|
if hasmore {
|
||||||
|
offset |= 0x2000
|
||||||
|
}
|
||||||
|
p.Flags = offset
|
||||||
}
|
}
|
||||||
|
|
||||||
packet := make([]byte, 52+len(p.Data))
|
packet := make([]byte, 52+len(p.Data))
|
||||||
binary.LittleEndian.PutUint32(packet[:4], p.DataSZ)
|
binary.LittleEndian.PutUint16(packet[:2], p.Ver)
|
||||||
|
binary.LittleEndian.PutUint16(packet[2:4], p.DataSZ)
|
||||||
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
|
binary.LittleEndian.PutUint16(packet[4:6], (uint16(p.TTL)<<8)|uint16(p.Proto))
|
||||||
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
|
binary.LittleEndian.PutUint16(packet[6:8], p.SrcPort)
|
||||||
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
|
binary.LittleEndian.PutUint16(packet[8:10], p.DstPort)
|
||||||
binary.LittleEndian.PutUint16(packet[10:12], 0x0404)
|
binary.LittleEndian.PutUint16(packet[10:12], p.Flags)
|
||||||
copy(packet[12:16], p.Src.To4())
|
copy(packet[12:16], p.Src.To4())
|
||||||
copy(packet[16:20], p.Dst.To4())
|
copy(packet[16:20], p.Dst.To4())
|
||||||
copy(packet[20:52], p.Hash[:])
|
copy(packet[20:52], p.Hash[:])
|
||||||
|
|||||||
@@ -75,20 +75,20 @@ func (l *Link) Read() *head.Packet {
|
|||||||
|
|
||||||
// Write 向 peer 发包
|
// Write 向 peer 发包
|
||||||
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
|
func (l *Link) Write(p *head.Packet, istransfer bool) (n int, err error) {
|
||||||
if len(p.Data) <= (32768 - 64) {
|
if len(p.Data) <= int(l.me.mtu) {
|
||||||
return l.write(p, istransfer)
|
return l.write(p, 0, istransfer, false)
|
||||||
}
|
}
|
||||||
data := p.Data
|
data := p.Data
|
||||||
offset := 0
|
offset := 0
|
||||||
for len(data) > (32768 - 64) {
|
for len(data) > int(l.me.mtu) {
|
||||||
packet := *p
|
packet := *p
|
||||||
packet.Data = data[offset*(32768-64) : (offset+1)*(32768-64)]
|
packet.Data = data[offset*int(l.me.mtu) : (offset+1)*int(l.me.mtu)]
|
||||||
i, err := l.write(&packet, istransfer)
|
i, err := l.write(&packet, uint16(offset), istransfer, true)
|
||||||
n += i
|
n += i
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
data = data[(offset+1)*(32768-64):]
|
data = data[(offset+1)*int(l.me.mtu):]
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
@@ -107,14 +107,17 @@ func (l *Link) String() (n string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write 向 peer 发一个包
|
// write 向 peer 发一个包
|
||||||
func (l *Link) write(p *head.Packet, istransfer bool) (n int, err error) {
|
func (l *Link) write(p *head.Packet, offset uint16, istransfer, hasmore bool) (n int, err error) {
|
||||||
var d []byte
|
var d []byte
|
||||||
if istransfer {
|
if istransfer {
|
||||||
d = p.Marshal(nil)
|
if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.me.mtu) {
|
||||||
|
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
|
||||||
|
}
|
||||||
|
d = p.Marshal(nil, 0, false, false)
|
||||||
} else {
|
} else {
|
||||||
p.FillHash()
|
p.FillHash()
|
||||||
p.Data = l.Encode(p.Data)
|
p.Data = l.Encode(p.Data)
|
||||||
d = p.Marshal(l.me.me)
|
d = p.Marshal(l.me.me, offset, false, hasmore)
|
||||||
}
|
}
|
||||||
if d == nil {
|
if d == nil {
|
||||||
return 0, errors.New("[link] ttl exceeded")
|
return 0, errors.New("[link] ttl exceeded")
|
||||||
|
|||||||
@@ -20,9 +20,8 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
|||||||
n, addr, err := conn.ReadFromUDP(lbf)
|
n, addr, err := conn.ReadFromUDP(lbf)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
lbf = lbf[:n]
|
lbf = lbf[:n]
|
||||||
packet := head.Packet{}
|
packet := m.wait(lbf)
|
||||||
err = packet.Unmarshal(lbf)
|
if packet != nil {
|
||||||
if err == nil {
|
|
||||||
r := int(packet.DataSZ) - len(packet.Data)
|
r := int(packet.DataSZ) - len(packet.Data)
|
||||||
if r > 0 {
|
if r > 0 {
|
||||||
remain, err := readAll(conn, r)
|
remain, err := readAll(conn, r)
|
||||||
@@ -60,16 +59,16 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
|||||||
}
|
}
|
||||||
case head.ProtoNotify:
|
case head.ProtoNotify:
|
||||||
logrus.Infoln("[link] recv notify")
|
logrus.Infoln("[link] recv notify")
|
||||||
p.onNotify(&packet)
|
p.onNotify(packet)
|
||||||
case head.ProtoQuery:
|
case head.ProtoQuery:
|
||||||
logrus.Infoln("[link] recv query")
|
logrus.Infoln("[link] recv query")
|
||||||
p.onQuery(&packet)
|
p.onQuery(packet)
|
||||||
case head.ProtoData:
|
case head.ProtoData:
|
||||||
if p.pipe != nil {
|
if p.pipe != nil {
|
||||||
p.pipe <- &packet
|
p.pipe <- packet
|
||||||
logrus.Infoln("[link] deliver to pipe of", p.peerip)
|
logrus.Infoln("[link] deliver to pipe of", p.peerip)
|
||||||
} else {
|
} else {
|
||||||
m.pipe <- &packet
|
m.pipe <- packet
|
||||||
logrus.Infoln("[link] deliver to pipe of me")
|
logrus.Infoln("[link] deliver to pipe of me")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -81,7 +80,7 @@ func (m *Me) listen() (conn *net.UDPConn, err error) {
|
|||||||
} else if p.Accept(packet.Dst) {
|
} else if p.Accept(packet.Dst) {
|
||||||
if p.allowtrans {
|
if p.allowtrans {
|
||||||
// 转发
|
// 转发
|
||||||
n, err = p.Write(&packet, true)
|
n, err = p.Write(packet, true)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
|
logrus.Infoln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -32,10 +32,17 @@ type Me struct {
|
|||||||
pipe chan *head.Packet
|
pipe chan *head.Packet
|
||||||
// 本机路由表
|
// 本机路由表
|
||||||
router *Router
|
router *Router
|
||||||
|
// 本机未接收完全分片池
|
||||||
|
recving map[[32]byte]*head.Packet
|
||||||
|
recvmu sync.Mutex
|
||||||
|
// 超时定时器
|
||||||
|
clock map[*head.Packet]uint8
|
||||||
|
// 本机上层配置
|
||||||
|
srcport, dstport, mtu uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMe 设置本机参数
|
// NewMe 设置本机参数
|
||||||
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool) (m Me) {
|
func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipeinlink bool, srcport, dstport, mtu uint16) (m Me) {
|
||||||
m.privKey = *privateKey
|
m.privKey = *privateKey
|
||||||
var err error
|
var err error
|
||||||
m.myend, err = net.ResolveUDPAddr("udp", myEndpoint)
|
m.myend, err = net.ResolveUDPAddr("udp", myEndpoint)
|
||||||
@@ -62,5 +69,20 @@ func NewMe(privateKey *[32]byte, myipwithmask string, myEndpoint string, nopipei
|
|||||||
}
|
}
|
||||||
m.router.SetDefault(nil)
|
m.router.SetDefault(nil)
|
||||||
m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink)
|
m.loop = m.AddPeer(m.me.String(), nil, "127.0.0.1:56789", []string{myipwithmask}, 0, false, nopipeinlink)
|
||||||
|
m.srcport = srcport
|
||||||
|
m.dstport = dstport
|
||||||
|
m.mtu = mtu
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Me) SrcPort() uint16 {
|
||||||
|
return m.srcport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Me) DstPort() uint16 {
|
||||||
|
return m.dstport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Me) MTU() uint16 {
|
||||||
|
return m.mtu
|
||||||
|
}
|
||||||
|
|||||||
77
gold/link/recv.go
Normal file
77
gold/link/recv.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package link
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/fumiama/WireGold/gold/head"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *Me) initrecvpool() {
|
||||||
|
if m.recving == nil {
|
||||||
|
m.recving = make(map[[32]byte]*head.Packet, 128)
|
||||||
|
}
|
||||||
|
// 超时定时器
|
||||||
|
m.clock = make(map[*head.Packet]uint8, 128)
|
||||||
|
var delhs []*head.Packet
|
||||||
|
t := time.NewTicker(time.Second)
|
||||||
|
for range t.C {
|
||||||
|
m.recvmu.Lock()
|
||||||
|
for k, v := range m.clock {
|
||||||
|
if v > 10 { // 10s
|
||||||
|
delete(m.recving, k.Hash)
|
||||||
|
delhs = append(delhs, k)
|
||||||
|
} else {
|
||||||
|
m.clock[k]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, k := range delhs {
|
||||||
|
delete(m.clock, k)
|
||||||
|
logrus.Warnln("[recv] drop timeout packet from", k.Src)
|
||||||
|
}
|
||||||
|
delhs = delhs[:0]
|
||||||
|
m.recvmu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Me) wait(data []byte) *head.Packet {
|
||||||
|
flags := binary.LittleEndian.Uint16(data[10:12])
|
||||||
|
if flags == 0 || flags == 0x4000 {
|
||||||
|
h := &head.Packet{}
|
||||||
|
_, err := h.Unmarshal(data)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("[recv] unmarshal err:", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
m.recvmu.Lock()
|
||||||
|
defer m.recvmu.Unlock()
|
||||||
|
hashd := data[20:52]
|
||||||
|
hsh := *(*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&hashd)))
|
||||||
|
h, ok := m.recving[hsh]
|
||||||
|
if ok {
|
||||||
|
ok, err := h.Unmarshal(data)
|
||||||
|
if err == nil {
|
||||||
|
if ok {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
m.clock[h] = 0
|
||||||
|
} else {
|
||||||
|
logrus.Errorln("[recv] unmarshal err:", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h = &head.Packet{}
|
||||||
|
_, err := h.Unmarshal(data)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("[recv] unmarshal err:", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.recving[hsh] = h
|
||||||
|
m.clock[h] = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
10
lower/nic.go
10
lower/nic.go
@@ -58,8 +58,8 @@ func (nc *NIC) Start(m *link.Me) {
|
|||||||
logrus.Infoln("[lower] recv write", n, "bytes packet to nic")
|
logrus.Infoln("[lower] recv write", n, "bytes packet to nic")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
buf := make([]byte, 65536) // 永远不可能超界
|
buf := make([]byte, m.MTU()+64) // 增加报头长度与 TEA 冗余
|
||||||
for nc.hasstart { // 从 NIC 发送
|
for nc.hasstart { // 从 NIC 发送
|
||||||
packet := buf
|
packet := buf
|
||||||
n, err := nc.ifce.Read(packet)
|
n, err := nc.ifce.Read(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -115,15 +115,13 @@ func send(m *link.Me, packet []byte) (n int, rem []byte) {
|
|||||||
packet = packet[:totl]
|
packet = packet[:totl]
|
||||||
n = int(totl)
|
n = int(totl)
|
||||||
dst := waterutil.IPv4Destination(packet)
|
dst := waterutil.IPv4Destination(packet)
|
||||||
srcport := waterutil.IPv4SourcePort(packet)
|
logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(m.SrcPort())), "to", dst.String()+":"+strconv.Itoa(int(m.DstPort())))
|
||||||
dstport := waterutil.IPv4DestinationPort(packet)
|
|
||||||
logrus.Infoln("[lower] sending", len(packet), "bytes packet from :"+strconv.Itoa(int(srcport)), "to", dst.String()+":"+strconv.Itoa(int(dstport)))
|
|
||||||
lnk, err := m.Connect(dst.String())
|
lnk, err := m.Connect(dst.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err)
|
logrus.Warnln("[lower] connect to peer", dst.String(), "err:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = lnk.Write(head.NewPacket(head.ProtoData, srcport, dst, dstport, packet), false)
|
_, err = lnk.Write(head.NewPacket(head.ProtoData, m.SrcPort(), dst, m.DstPort(), packet), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Warnln("[lower] write to peer", dst.String(), "err:", err)
|
logrus.Warnln("[lower] write to peer", dst.String(), "err:", err)
|
||||||
}
|
}
|
||||||
|
|||||||
91
main.go
91
main.go
@@ -1,22 +1,20 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
base14 "github.com/fumiama/go-base16384"
|
base14 "github.com/fumiama/go-base16384"
|
||||||
curve "github.com/fumiama/go-x25519"
|
curve "github.com/fumiama/go-x25519"
|
||||||
|
|
||||||
"github.com/fumiama/WireGold/config"
|
"github.com/fumiama/WireGold/config"
|
||||||
"github.com/fumiama/WireGold/gold/link"
|
|
||||||
"github.com/fumiama/WireGold/helper"
|
"github.com/fumiama/WireGold/helper"
|
||||||
"github.com/fumiama/WireGold/lower"
|
"github.com/fumiama/WireGold/upper"
|
||||||
|
"github.com/fumiama/WireGold/upper/services/wg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const suffix32 = "㴄"
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
help := flag.Bool("h", false, "display this help")
|
help := flag.Bool("h", false, "display this help")
|
||||||
gen := flag.Bool("g", false, "generate key pair")
|
gen := flag.Bool("g", false, "generate key pair")
|
||||||
@@ -44,16 +42,11 @@ func main() {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
if helper.IsNotExist(*file) {
|
if helper.IsNotExist(*file) {
|
||||||
f, err := os.Create(*file)
|
f := new(bytes.Buffer)
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var r string
|
var r string
|
||||||
fmt.Print("IP: ")
|
fmt.Print("IP: ")
|
||||||
fmt.Scanln(&r)
|
fmt.Scanln(&r)
|
||||||
if r == "" {
|
if r == "" {
|
||||||
f.Close()
|
|
||||||
os.Remove(*file)
|
|
||||||
fmt.Println("nil ip")
|
fmt.Println("nil ip")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -63,8 +56,6 @@ func main() {
|
|||||||
fmt.Print("SubNet: ")
|
fmt.Print("SubNet: ")
|
||||||
fmt.Scanln(&r)
|
fmt.Scanln(&r)
|
||||||
if r == "" {
|
if r == "" {
|
||||||
f.Close()
|
|
||||||
os.Remove(*file)
|
|
||||||
fmt.Println("nil subnet")
|
fmt.Println("nil subnet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -74,8 +65,6 @@ func main() {
|
|||||||
fmt.Print("PrivateKey: ")
|
fmt.Print("PrivateKey: ")
|
||||||
fmt.Scanln(&r)
|
fmt.Scanln(&r)
|
||||||
if r == "" {
|
if r == "" {
|
||||||
f.Close()
|
|
||||||
os.Remove(*file)
|
|
||||||
fmt.Println("nil private key")
|
fmt.Println("nil private key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -85,15 +74,18 @@ func main() {
|
|||||||
fmt.Print("EndPoint: ")
|
fmt.Print("EndPoint: ")
|
||||||
fmt.Scanln(&r)
|
fmt.Scanln(&r)
|
||||||
if r == "" {
|
if r == "" {
|
||||||
f.Close()
|
|
||||||
os.Remove(*file)
|
|
||||||
fmt.Println("nil endpoint")
|
fmt.Println("nil endpoint")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.WriteString("EndPoint: " + r + "\n")
|
f.WriteString("EndPoint: " + r + "\n")
|
||||||
r = ""
|
r = ""
|
||||||
|
|
||||||
f.Close()
|
cfgf, err := os.Create(*file)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
cfgf.Write(f.Bytes())
|
||||||
|
cfgf.Close()
|
||||||
}
|
}
|
||||||
c := config.Parse(*file)
|
c := config.Parse(*file)
|
||||||
if c.IP == "" {
|
if c.IP == "" {
|
||||||
@@ -108,73 +100,18 @@ func main() {
|
|||||||
if c.EndPoint == "" {
|
if c.EndPoint == "" {
|
||||||
displayHelp("nil endpoint")
|
displayHelp("nil endpoint")
|
||||||
}
|
}
|
||||||
var key [32]byte
|
w, err := wg.NewWireGold(&c)
|
||||||
k, err := base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
n := copy(key[:], base14.Decode(k))
|
|
||||||
if n != 32 {
|
|
||||||
displayHelp("private key length is not 32")
|
|
||||||
}
|
|
||||||
|
|
||||||
if *showp {
|
if *showp {
|
||||||
c := curve.Get(key[:])
|
fmt.Println("PublicKey:", w.PublicKey)
|
||||||
pubk, err := base14.UTF16be2utf8(base14.Encode((*c.Public())[:]))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
fmt.Println("PublicKey:", helper.BytesToString(pubk[:57]))
|
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
cidrsmap := make(map[string]bool, 32)
|
defer w.Stop()
|
||||||
_, mysubnet, err := net.ParseCIDR(c.SubNet)
|
w.Run(upper.ServiceWireGold, upper.ServiceWireGold, 32768-64)
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
for _, p := range c.Peers {
|
|
||||||
for _, ip := range p.AllowedIPs {
|
|
||||||
ipnet, _, err := net.ParseCIDR(ip)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if !mysubnet.Contains(ipnet) {
|
|
||||||
cidrsmap[ip] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cidrs := make([]string, len(cidrsmap))
|
|
||||||
i := 0
|
|
||||||
for k := range cidrsmap {
|
|
||||||
cidrs[i] = k
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
|
|
||||||
nic := lower.NewNIC(c.IP, c.SubNet, cidrs...)
|
|
||||||
me := link.NewMe(&key, c.IP+"/32", c.EndPoint, true)
|
|
||||||
|
|
||||||
for _, peer := range c.Peers {
|
|
||||||
var peerkey [32]byte
|
|
||||||
k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
n := copy(peerkey[:], base14.Decode(k))
|
|
||||||
if n != 32 {
|
|
||||||
panic("peer public key length is not 32")
|
|
||||||
}
|
|
||||||
me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
nic.Up()
|
|
||||||
defer func() {
|
|
||||||
nic.Stop()
|
|
||||||
nic.Down()
|
|
||||||
nic.Destroy()
|
|
||||||
}()
|
|
||||||
|
|
||||||
nic.Start(&me)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func displayHelp(hint string) {
|
func displayHelp(hint string) {
|
||||||
|
|||||||
@@ -8,9 +8,16 @@ const (
|
|||||||
ServiceNull = iota
|
ServiceNull = iota
|
||||||
// ServiceTunnel 管道通信服务
|
// ServiceTunnel 管道通信服务
|
||||||
ServiceTunnel
|
ServiceTunnel
|
||||||
|
// ServiceWireGold 虚拟组网服务
|
||||||
|
ServiceWireGold
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service interface {
|
type Service interface {
|
||||||
Create(peer string, srcport, destport, mtu uint16) (Service, error)
|
// Start 无阻塞运行
|
||||||
io.ReadWriteCloser
|
Start(srcport, destport, mtu uint16)
|
||||||
|
// Run 阻塞运行
|
||||||
|
Run(srcport, destport, mtu uint16)
|
||||||
|
// Stop 停止
|
||||||
|
Stop()
|
||||||
|
io.ReadWriter
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,24 +21,36 @@ type Tunnel struct {
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(me *link.Me, peer string, srcport, destport, mtu uint16) (s Tunnel, err error) {
|
func Create(me *link.Me, peer string) (s Tunnel, err error) {
|
||||||
logrus.Infoln("[tunnel] create from", srcport, "to", destport)
|
|
||||||
s.l, err = me.Connect(peer)
|
s.l, err = me.Connect(peer)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.in = make(chan []byte, 4)
|
s.in = make(chan []byte, 4)
|
||||||
s.out = make(chan []byte, 4)
|
s.out = make(chan []byte, 4)
|
||||||
s.peerip = net.ParseIP(peer)
|
s.peerip = net.ParseIP(peer)
|
||||||
s.src = srcport
|
|
||||||
s.dest = destport
|
|
||||||
s.mtu = mtu
|
|
||||||
go s.handleWrite()
|
|
||||||
go s.handleRead()
|
|
||||||
} else {
|
} else {
|
||||||
logrus.Errorln("[tunnel] create err:", err)
|
logrus.Errorln("[tunnel] create err:", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Tunnel) Start(srcport, destport, mtu uint16) {
|
||||||
|
logrus.Infoln("[tunnel] start from", srcport, "to", destport)
|
||||||
|
s.src = srcport
|
||||||
|
s.dest = destport
|
||||||
|
s.mtu = mtu
|
||||||
|
go s.handleWrite()
|
||||||
|
go s.handleRead()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Tunnel) Run(srcport, destport, mtu uint16) {
|
||||||
|
logrus.Infoln("[tunnel] start from", srcport, "to", destport)
|
||||||
|
s.src = srcport
|
||||||
|
s.dest = destport
|
||||||
|
s.mtu = mtu
|
||||||
|
go s.handleWrite()
|
||||||
|
s.handleRead()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Tunnel) Write(p []byte) (int, error) {
|
func (s *Tunnel) Write(p []byte) (int, error) {
|
||||||
s.in <- p
|
s.in <- p
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
@@ -63,10 +75,9 @@ func (s *Tunnel) Read(p []byte) (int, error) {
|
|||||||
return 0, errors.New("reading reaches nil")
|
return 0, errors.New("reading reaches nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Tunnel) Close() error {
|
func (s *Tunnel) Stop() {
|
||||||
s.l.Close()
|
s.l.Close()
|
||||||
close(s.in)
|
close(s.in)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Tunnel) handleWrite() {
|
func (s *Tunnel) handleWrite() {
|
||||||
|
|||||||
@@ -27,18 +27,20 @@ func TestTunnel(t *testing.T) {
|
|||||||
t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:]))
|
t.Log("peer priv key:", hex.EncodeToString(peerpk.Private()[:]))
|
||||||
t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:]))
|
t.Log("peer publ key:", hex.EncodeToString(peerpk.Public()[:]))
|
||||||
|
|
||||||
m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false)
|
m := link.NewMe(selfpk.Private(), "192.168.1.2/32", "127.0.0.1:1236", false, 1, 1, 4096)
|
||||||
m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", []string{"192.168.1.3/32"}, 0, false, false)
|
m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", []string{"192.168.1.3/32"}, 0, false, false)
|
||||||
p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false)
|
p := link.NewMe(peerpk.Private(), "192.168.1.3/32", "127.0.0.1:1237", false, 1, 1, 4096)
|
||||||
p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", []string{"192.168.1.2/32"}, 0, false, false)
|
p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", []string{"192.168.1.2/32"}, 0, false, false)
|
||||||
tunnme, err := Create(&m, "192.168.1.3", 1, 1, 4096)
|
tunnme, err := Create(&m, "192.168.1.3")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tunnpeer, err := Create(&p, "192.168.1.2", 1, 1, 4096)
|
tunnme.Start(1, 1, 4096)
|
||||||
|
tunnpeer, err := Create(&p, "192.168.1.2")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
tunnpeer.Start(1, 1, 4096)
|
||||||
|
|
||||||
sendb := ([]byte)("1234")
|
sendb := ([]byte)("1234")
|
||||||
tunnme.Write(sendb)
|
tunnme.Write(sendb)
|
||||||
@@ -68,4 +70,7 @@ func TestTunnel(t *testing.T) {
|
|||||||
if string(sendb) != string(buf) {
|
if string(sendb) != string(buf) {
|
||||||
t.Fatal("error: recv 131072 bytes data")
|
t.Fatal("error: recv 131072 bytes data")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tunnme.Stop()
|
||||||
|
tunnpeer.Stop()
|
||||||
}
|
}
|
||||||
|
|||||||
107
upper/services/wg/wg.go
Normal file
107
upper/services/wg/wg.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package wg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
base14 "github.com/fumiama/go-base16384"
|
||||||
|
curve "github.com/fumiama/go-x25519"
|
||||||
|
|
||||||
|
"github.com/fumiama/WireGold/config"
|
||||||
|
"github.com/fumiama/WireGold/gold/link"
|
||||||
|
"github.com/fumiama/WireGold/helper"
|
||||||
|
"github.com/fumiama/WireGold/lower"
|
||||||
|
)
|
||||||
|
|
||||||
|
const suffix32 = "㴄"
|
||||||
|
|
||||||
|
type WG struct {
|
||||||
|
c *config.Config
|
||||||
|
key [32]byte
|
||||||
|
PublicKey string
|
||||||
|
nic *lower.NIC
|
||||||
|
me link.Me
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWireGold(c *config.Config) (wg WG, err error) {
|
||||||
|
wg.c = c
|
||||||
|
|
||||||
|
var k []byte
|
||||||
|
k, err = base14.UTF82utf16be(helper.StringToBytes(c.PrivateKey + suffix32))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n := copy(wg.key[:], base14.Decode(k))
|
||||||
|
if n != 32 {
|
||||||
|
err = errors.New("private key length is not 32")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cur := curve.Get(wg.key[:])
|
||||||
|
pubk, err := base14.UTF16be2utf8(base14.Encode((*cur.Public())[:]))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.PublicKey = helper.BytesToString(pubk[:57])
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *WG) Start(srcport, destport, mtu uint16) {
|
||||||
|
wg.init(srcport, destport, mtu)
|
||||||
|
wg.nic.Up()
|
||||||
|
go wg.nic.Start(&wg.me)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *WG) Run(srcport, destport, mtu uint16) {
|
||||||
|
wg.init(srcport, destport, mtu)
|
||||||
|
wg.nic.Up()
|
||||||
|
wg.nic.Start(&wg.me)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *WG) Stop() {
|
||||||
|
wg.nic.Stop()
|
||||||
|
wg.nic.Down()
|
||||||
|
wg.nic.Destroy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wg *WG) init(srcport, destport, mtu uint16) {
|
||||||
|
cidrsmap := make(map[string]bool, 32)
|
||||||
|
_, mysubnet, err := net.ParseCIDR(wg.c.SubNet)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
for _, p := range wg.c.Peers {
|
||||||
|
for _, ip := range p.AllowedIPs {
|
||||||
|
ipnet, _, err := net.ParseCIDR(ip)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if !mysubnet.Contains(ipnet) {
|
||||||
|
cidrsmap[ip] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cidrs := make([]string, len(cidrsmap))
|
||||||
|
i := 0
|
||||||
|
for k := range cidrsmap {
|
||||||
|
cidrs[i] = k
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.nic = lower.NewNIC(wg.c.IP, wg.c.SubNet, cidrs...)
|
||||||
|
wg.me = link.NewMe(&wg.key, wg.c.IP+"/32", wg.c.EndPoint, true, srcport, destport, mtu)
|
||||||
|
|
||||||
|
for _, peer := range wg.c.Peers {
|
||||||
|
var peerkey [32]byte
|
||||||
|
k, err := base14.UTF82utf16be(helper.StringToBytes(peer.PublicKey + suffix32))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
n := copy(peerkey[:], base14.Decode(k))
|
||||||
|
if n != 32 {
|
||||||
|
panic("peer public key length is not 32")
|
||||||
|
}
|
||||||
|
wg.me.AddPeer(peer.IP, &peerkey, peer.EndPoint, peer.AllowedIPs, peer.KeepAliveSeconds, peer.AllowTrans, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user