diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 5e83e89..ac3f11f 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -2,12 +2,14 @@ package link import ( "net" + "sync" "unsafe" tea "github.com/fumiama/gofastTEA" ) -var ( +// Me 是本机的抽象 +type Me struct { // 本机私钥 // 利用 Curve25519 生成 // https://pkg.go.dev/golang.org/x/crypto/curve25519 @@ -17,21 +19,34 @@ var ( me net.IP // 本机 endpoint myend *net.UDPAddr -) + // 本机活跃的所有连接 + connections map[string]*Link + // 读写同步锁 + connmapmu sync.RWMutex + // 本机监听的 endpoint + myconn *net.UDPConn + // 本机路由表 + router *Router +} -// SetMyself 设置本机参数 -func SetMyself(privateKey [32]byte, myIP string, myEndpoint string) { - privKey = privateKey +// NewMe 设置本机参数 +func NewMe(privateKey *[32]byte, myIP string, myEndpoint string) (m Me) { + m.privKey = *privateKey var err error - myend, err = net.ResolveUDPAddr("udp", myEndpoint) + m.myend, err = net.ResolveUDPAddr("udp", myEndpoint) if err != nil { panic(err) } - me = net.ParseIP(myIP) - myconn, err = listen() + m.me = net.ParseIP(myIP) + m.myconn, err = m.listen() if err != nil { panic(err) } + m.connections = make(map[string]*Link) + m.router = &Router{ + routetable: make(map[string][]*Link), + } + return } // Encode 使用 TEA 加密 @@ -44,7 +59,7 @@ func (l *Link) Encode(b []byte) (eb []byte, err error) { } else { // 在此处填写加密逻辑,密钥是l.key,输入是b,输出是eb // 不用写return,直接赋值给eb即可 - eb = (*tea.TEA)(unsafe.Pointer(&privKey)).Encrypt(b) + eb = (*tea.TEA)(unsafe.Pointer(l.key)).Encrypt(b) } return } @@ -59,7 +74,7 @@ func (l *Link) Decode(b []byte) (db []byte, err error) { } else { // 在此处填写解密逻辑,密钥是l.key,输入是b,输出是db // 不用写return,直接赋值给db即可 - db = (*tea.TEA)(unsafe.Pointer(&privKey)).Decrypt(b) + db = (*tea.TEA)(unsafe.Pointer(l.key)).Decrypt(b) } return } diff --git a/gold/link/link.go b/gold/link/link.go index 2f6373a..afc1cf7 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -3,7 +3,6 @@ package link import ( "errors" "net" - "sync" "github.com/sirupsen/logrus" @@ -35,6 +34,8 @@ type Link struct { status int // 连接所用对称加密密钥 key *[32]byte + // 本机信息 + me *Me } const ( @@ -43,18 +44,9 @@ const ( LINK_STATUS_UP ) -var ( - // 本机活跃的所有连接 - connections = make(map[string]*Link) - // 读写同步锁 - connmapmu sync.RWMutex - // 本机监听的 endpoint - myconn *net.UDPConn -) - // Connect 初始化与 peer 的连接 -func Connect(peer string) (*Link, error) { - p, ok := IsInPeer(net.ParseIP(peer).String()) +func (m *Me) Connect(peer string) (*Link, error) { + p, ok := m.IsInPeer(net.ParseIP(peer).String()) if ok { p.keepAlive() return p, nil @@ -64,9 +56,9 @@ func Connect(peer string) (*Link, error) { // Close 关闭到 peer 的连接 func (l *Link) Close() { - connmapmu.Lock() - delete(connections, l.peerip.String()) - connmapmu.Unlock() + l.me.connmapmu.Lock() + delete(l.me.connections, l.peerip.String()) + l.me.connmapmu.Unlock() l.status = LINK_STATUS_DOWN } @@ -80,10 +72,10 @@ func (l *Link) Write(p *head.Packet) (n int, err error) { p.Data, err = l.Encode(p.Data) if err == nil { var d []byte - d, err = p.Mashal(me.String(), l.peerip.String()) + d, err = p.Mashal(l.me.me.String(), l.peerip.String()) logrus.Debugln("[link] write data", string(d)) if err == nil { - n, err = myconn.WriteToUDP(d, l.NextHop(l.peerip).endpoint) + n, err = l.me.myconn.WriteToUDP(d, l.NextHop(l.peerip).endpoint) } } return diff --git a/gold/link/listen.go b/gold/link/listen.go index 997fbec..a59306b 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -8,8 +8,8 @@ import ( ) // 监听本机 endpoint -func listen() (conn *net.UDPConn, err error) { - conn, err = net.ListenUDP("udp", myend) +func (m *Me) listen() (conn *net.UDPConn, err error) { + conn, err = net.ListenUDP("udp", m.myend) if err == nil { go func() { listenbuff := make([]byte, 65536) @@ -28,7 +28,7 @@ func listen() (conn *net.UDPConn, err error) { packet.Data = append(packet.Data, remain...) } } - p, ok := IsInPeer(packet.Src) + p, ok := m.IsInPeer(packet.Src) logrus.Infoln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) logrus.Debugln("[link] recv:", string(lbf)) if p.pep == "" || p.pep != addr.String() { @@ -71,7 +71,7 @@ func listen() (conn *net.UDPConn, err error) { logrus.Infoln("[link] trans") } } else { - logrus.Infoln("[link] packet to", packet.Dst, "is refused", "(me:", me, ")") + logrus.Infoln("[link] packet to", packet.Dst, "is refused", "(me:", m.me, ")") } } } diff --git a/gold/link/peer.go b/gold/link/peer.go index c985f80..e89b72e 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -1,6 +1,7 @@ package link import ( + "fmt" "net" "unsafe" @@ -10,10 +11,10 @@ import ( ) // AddPeer 添加一个 peer -func AddPeer(peerip string, pubicKey *[32]byte, endPoint string, allowedIPs []string, keepAlive int64, allowTrans bool) (l *Link) { +func (m *Me) AddPeer(peerip string, pubicKey *[32]byte, endPoint string, allowedIPs []string, keepAlive int64, allowTrans bool) (l *Link) { peerip = net.ParseIP(peerip).String() var ok bool - l, ok = IsInPeer(peerip) + l, ok = m.IsInPeer(peerip) if ok { return } @@ -23,11 +24,13 @@ func AddPeer(peerip string, pubicKey *[32]byte, endPoint string, allowedIPs []st pipe: make(chan *head.Packet, 32), peerip: net.ParseIP(peerip), allowtrans: allowTrans, + me: m, } if pubicKey != nil { - c := curve.Get(privKey[:]) + c := curve.Get(m.privKey[:]) k, err := c.Shared(pubicKey) if err == nil { + fmt.Println(len(k)) l.key = (*[32]byte)(*(*unsafe.Pointer)(unsafe.Pointer(&k))) } } @@ -45,20 +48,20 @@ func AddPeer(peerip string, pubicKey *[32]byte, endPoint string, allowedIPs []st _, cidr, err := net.ParseCIDR(ipnet) if err == nil { l.allowedips = append(l.allowedips, cidr) - routetable[cidr.String()] = append(routetable[cidr.String()], l) + l.me.router.routetable[cidr.String()] = append(l.me.router.routetable[cidr.String()], l) } } } - connmapmu.Lock() - connections[peerip] = l - connmapmu.Unlock() + l.me.connmapmu.Lock() + l.me.connections[peerip] = l + l.me.connmapmu.Unlock() return } // IsInPeer 查找 peer 是否已经在册 -func IsInPeer(peer string) (p *Link, ok bool) { - connmapmu.RLock() - p, ok = connections[peer] - connmapmu.RUnlock() +func (m *Me) IsInPeer(peer string) (p *Link, ok bool) { + m.connmapmu.RLock() + p, ok = m.connections[peer] + m.connmapmu.RUnlock() return } diff --git a/gold/link/router.go b/gold/link/router.go index f25b342..6d2283d 100644 --- a/gold/link/router.go +++ b/gold/link/router.go @@ -5,10 +5,10 @@ import ( "sync" ) -var ( - routetable = make(map[string][]*Link) +type Router struct { + routetable map[string][]*Link routetablemu sync.RWMutex -) +} // Accept 判断是否应当接受 ip 发来的包 func (l *Link) Accept(ip net.IP) bool { @@ -22,7 +22,7 @@ func (l *Link) Accept(ip net.IP) bool { // IsToMe 判断是否是发给自己的包 func (l *Link) IsToMe(ip net.IP) bool { - return ip.Equal(me) + return ip.Equal(l.me.me) } // NextHop 得到前往 ip 的下一跳的 link diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index a3519ed..9cb9fd2 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -18,9 +18,9 @@ type Tunnel struct { dest uint16 } -func Create(peer string, srcport uint16, destport uint16) (s Tunnel, err error) { +func Create(me *link.Me, peer string, srcport uint16, destport uint16) (s Tunnel, err error) { logrus.Infoln("[tunnel] create from", srcport, "to", destport) - s.l, err = link.Connect(peer) + s.l, err = me.Connect(peer) if err == nil { s.in = make(chan []byte, 4) s.out = make(chan []byte, 4) diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 877daab..de80e11 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -3,25 +3,40 @@ package tunnel import ( "testing" - "github.com/fumiama/WireGold/gold/link" + curve "github.com/fumiama/go-x25519" "github.com/sirupsen/logrus" + + "github.com/fumiama/WireGold/gold/link" ) func TestTunnel(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) - link.SetMyself([32]byte{}, "192.168.1.2", "127.0.0.1:1236") - link.AddPeer("192.168.1.2", nil, "127.0.0.1:1236", nil, 0, false) - tunn, err := Create("192.168.1.2", 1, 1) + selfpk, err := curve.New(nil) if err != nil { - t.Error(err) - } else { - sendb := ([]byte)("1234") - tunn.Write(sendb) - p := make([]byte, 4) - tunn.Read(p) - if string(sendb) != string(p) { - t.Log("error: recv", p) - t.Fail() - } + panic(err) + } + peerpk, err := curve.New(nil) + if err != nil { + panic(err) + } + m := link.NewMe(selfpk.Private(), "192.168.1.2", "127.0.0.1:1236") + m.AddPeer("192.168.1.3", peerpk.Public(), "127.0.0.1:1237", nil, 0, false) + p := link.NewMe(peerpk.Private(), "192.168.1.3", "127.0.0.1:1237") + p.AddPeer("192.168.1.2", selfpk.Public(), "127.0.0.1:1236", nil, 0, false) + tunnme, err := Create(&m, "192.168.1.3", 1, 1) + if err != nil { + t.Fatal(err) + } + tunnpeer, err := Create(&p, "192.168.1.2", 1, 1) + if err != nil { + t.Fatal(err) + } + sendb := ([]byte)("1234") + tunnme.Write(sendb) + buf := make([]byte, 4) + tunnpeer.Read(buf) + if string(sendb) != string(buf) { + t.Log("error: recv", buf) + t.Fail() } }