diff --git a/README.md b/README.md index 33fc543..f3f59d1 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ Peers: IP: "192.168.233.2" SubNet: 192.168.233.0/24 PublicKey: 徯萃嵾爻燸攗窍褃冔蒔犡緇袿屿組待族砇嘀 + PresharedKey: 瀸敀爅崾嘊嵜紼樴稍毯攣矐訷蟷扛嬋庩崛昀 EndPoint: 1.2.3.4:56789 AllowedIPs: ["192.168.233.2/32"] KeepAliveSeconds: 0 @@ -53,9 +54,10 @@ Peers: IP: "192.168.233.3" SubNet: 192.168.233.0/24 PublicKey: 牢喨粷詸衭譛浾蘹櫠砙杹蟫瑳叩刋橋経挵蘀 + PresharedKey: 竅琚喫従痸告烈兇厕趭萨假蔛瀇譄施烸蝫瘀 EndPoint: "" AllowedIPs: ["192.168.233.3/32"] MTU: 752 KeepAliveSeconds: 0 AllowTrans: false -``` \ No newline at end of file +``` diff --git a/config/cfg.go b/config/cfg.go index d65301a..9c11d8e 100644 --- a/config/cfg.go +++ b/config/cfg.go @@ -23,6 +23,7 @@ type Peer struct { IP string `yaml:"IP"` SubNet string `yaml:"SubNet"` PublicKey string `yaml:"PublicKey"` + PresharedKey string `yaml:"PresharedKey"` EndPoint string `yaml:"EndPoint"` AllowedIPs []string `yaml:"AllowedIPs"` KeepAliveSeconds int64 `yaml:"KeepAliveSeconds"` diff --git a/gold/head/packet.go b/gold/head/packet.go index 1b8bdd9..73dd79a 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -6,7 +6,6 @@ import ( "errors" "hash/crc64" "net" - "sync/atomic" "github.com/fumiama/WireGold/helper" blake2b "github.com/fumiama/blake2b-simd" @@ -17,7 +16,7 @@ import ( type Packet struct { // TeaTypeDataSZ len(Data) // 高 4 位指定加密所用 tea key - // 高 4-16 位是随机值 + // 高 4-16 位是递增值, 用于预共享密钥验证 additionalData // 不得超过 65507-head 字节 TeaTypeDataSZ uint32 // Proto 详见 head @@ -109,18 +108,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) { return } -var counter uint32 - // Marshal 将自身数据编码为 []byte // offset 必须为 8 的倍数,表示偏移的 8 位 -func (p *Packet) Marshal(src net.IP, teatype uint8, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) { +func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) { p.TTL-- if p.TTL == 0 { return nil, nil } if src != nil { - p.TeaTypeDataSZ = uint32(teatype)<<28 | (atomic.AddUint32(&counter, 1)<<16)&0x0fff0000 | datasz + p.TeaTypeDataSZ = uint32(teatype)<<28 | (uint32(additional&0x0fff) << 16) | datasz&0xffff p.Src = src offset &= 0x1fff if dontfrag { @@ -171,6 +168,11 @@ func (p *Packet) IsVaildHash() bool { return sum == p.Hash } +// AdditionalData 获得 packet 的 additionalData +func (p *Packet) AdditionalData() uint16 { + return uint16((p.TeaTypeDataSZ >> 16) & 0x0fff) +} + // Put 将自己放回池中 func (p *Packet) Put() { PutPacket(p) diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 9e0805e..299b7c5 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -1,5 +1,10 @@ package link +import ( + "crypto/rand" + "encoding/binary" +) + // Encode 使用 TEA 加密 func (l *Link) Encode(teatype uint8, b []byte) (eb []byte) { if b == nil || teatype >= 16 { @@ -29,3 +34,34 @@ func (l *Link) Decode(teatype uint8, b []byte) (db []byte) { db = l.key[teatype].Decrypt(b) return } + +// EncodePreshared 使用 chacha20poly1305 加密 +func (l *Link) EncodePreshared(additional uint16, b []byte) (eb []byte) { + nsz := l.aead.NonceSize() + // Select a random nonce, and leave capacity for the ciphertext. + nonce := make([]byte, nsz, nsz+len(b)+l.aead.Overhead()) + _, err := rand.Read(nonce) + if err != nil { + return + } + // Encrypt the message and append the ciphertext to the nonce. + var buf [2]byte + binary.LittleEndian.PutUint16(buf[:], additional) + eb = l.aead.Seal(nonce, nonce, b, buf[:]) + return +} + +// DecodePreshared 使用 chacha20poly1305 解密 +func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) { + nsz := l.aead.NonceSize() + if len(b) < nsz { // ciphertext too short + return + } + // Split nonce and ciphertext. + nonce, ciphertext := b[:nsz], b[nsz:] + // Decrypt the message and check it wasn't tampered with. + var buf [2]byte + binary.LittleEndian.PutUint16(buf[:], additional) + db, _ = l.aead.Open(nil, nonce, ciphertext, buf[:]) + return +} diff --git a/gold/link/link.go b/gold/link/link.go index 2269422..0a93e62 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -1,6 +1,7 @@ package link import ( + "crypto/cipher" "errors" "net" @@ -14,6 +15,8 @@ import ( type Link struct { // peer 的公钥 pubk *[32]byte + // 发包计数, 分片算一个 + sendcount uintptr // 收到的包的队列 // 没有下层 nic 时 // 包会分发到此 @@ -26,6 +29,8 @@ type Link struct { allowedips []*net.IPNet // 连接所用对称加密密钥 key []tea.TEA + // 连接所用预共享密钥 + aead cipher.AEAD // 本机信息 me *Me // 连接的状态,详见下方 const diff --git a/gold/link/listen.go b/gold/link/listen.go index f67b199..bb46ba4 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -43,40 +43,48 @@ func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) { sz := packet.TeaTypeDataSZ & 0x0000ffff r := int(sz) - len(packet.Data) if r > 0 { - logrus.Warnln("[link] packet from endpoint", addr, "is smaller than it declared: drop it") + logrus.Warnln("[listen] packet from endpoint", addr, "is smaller than it declared: drop it") packet.Put() continue } p, ok := m.IsInPeer(packet.Src.String()) - logrus.Debugln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) - // logrus.Debugln("[link] recv:", hex.EncodeToString(lbf)) + logrus.Debugln("[listen] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst) + // logrus.Debugln("[listen] recv:", hex.EncodeToString(lbf)) if !ok { - logrus.Warnln("[link] packet from", packet.Src, "to", packet.Dst, "is refused") + logrus.Warnln("[listen] packet from", packet.Src, "to", packet.Dst, "is refused") packet.Put() continue } if p.endpoint == nil || p.endpoint.String() != addr.String() { - logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String()) + logrus.Infoln("[listen] set endpoint of peer", p.peerip, "to", addr.String()) p.endpoint = addr } switch { case p.IsToMe(packet.Dst): packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data) if !packet.IsVaildHash() { - logrus.Debugln("[link] drop invalid packet") + logrus.Debugln("[listen] drop invalid hash packet") packet.Put() continue } + if p.aead != nil { + packet.Data = p.DecodePreshared(packet.AdditionalData(), packet.Data) + if packet.Data == nil { + logrus.Debugln("[listen] drop invalid additional data packet") + packet.Put() + continue + } + } switch packet.Proto { case head.ProtoHello: switch p.status { case LINK_STATUS_DOWN: n, err = p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false) if err == nil { - logrus.Debugln("[link] send", n, "bytes hello ack packet") + logrus.Debugln("[listen] send", n, "bytes hello ack packet") p.status = LINK_STATUS_HALFUP } else { - logrus.Errorln("[link] send hello ack packet error:", err) + logrus.Errorln("[listen] send hello ack packet error:", err) } case LINK_STATUS_HALFUP: p.status = LINK_STATUS_UP @@ -84,47 +92,47 @@ func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) { } packet.Put() case head.ProtoNotify: - logrus.Infoln("[link] recv notify from", packet.Src) + logrus.Infoln("[listen] recv notify from", packet.Src) go p.onNotify(packet.Data) packet.Put() case head.ProtoQuery: - logrus.Infoln("[link] recv query from", packet.Src) + logrus.Infoln("[listen] recv query from", packet.Src) go p.onQuery(packet.Data) packet.Put() case head.ProtoData: if p.pipe != nil { p.pipe <- packet - logrus.Debugln("[link] deliver to pipe of", p.peerip) + logrus.Debugln("[listen] deliver to pipe of", p.peerip) } else { m.nic.Write(packet.Data) - logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic") + logrus.Debugln("[listen] deliver", len(packet.Data), "bytes data to nic") packet.Put() } default: - logrus.Warnln("[link] recv unknown proto:", packet.Proto) + logrus.Warnln("[listen] recv unknown proto:", packet.Proto) packet.Put() } case p.Accept(packet.Dst): if !p.allowtrans { - logrus.Warnln("[link] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + logrus.Warnln("[listen] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) packet.Put() continue } // 转发 lnk := m.router.NextHop(packet.Dst.String()) if lnk == nil { - logrus.Warnln("[link] transfer drop packet: nil nexthop") + logrus.Warnln("[listen] transfer drop packet: nil nexthop") packet.Put() continue } n, err = lnk.WriteAndPut(packet, true) if err == nil { - logrus.Debugln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) + logrus.Debugln("[listen] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort))) } else { - logrus.Errorln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) + logrus.Errorln("[listen] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err) } default: - logrus.Warnln("[link] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") + logrus.Warnln("[listen] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers") packet.Put() } } diff --git a/gold/link/nat.go b/gold/link/nat.go index f51d2b9..7df186e 100644 --- a/gold/link/nat.go +++ b/gold/link/nat.go @@ -16,14 +16,14 @@ import ( // 以秒为单位,小于等于 0 不发送 func (l *Link) keepAlive(dur int64) { if dur > 0 { - logrus.Infoln("[link.nat] start to keep alive") + logrus.Infoln("[nat] start to keep alive") t := time.NewTicker(time.Second * time.Duration(dur)) for range t.C { n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false) if err == nil { - logrus.Infoln("[link] send", n, "bytes keep alive packet") + logrus.Infoln("[nat] send", n, "bytes keep alive packet") } else { - logrus.Errorln("[link] send keep alive packet error:", err) + logrus.Errorln("[nat] send keep alive packet error:", err) } } } @@ -37,7 +37,7 @@ func (l *Link) onNotify(packet []byte) { notify := make(head.Notify, 32) err := json.Unmarshal(packet, ¬ify) if err != nil { - logrus.Errorln("[notify] json unmarshal err:", err) + logrus.Errorln("[nat] notify json unmarshal err:", err) return } // 2. endpoint注册 @@ -50,12 +50,12 @@ func (l *Link) onNotify(packet []byte) { if ok { if p.endpoint.String() != ep { p.endpoint = addr - logrus.Infoln("[notify] set ep of peer", peer, "to", ep) + logrus.Infoln("[nat] notify set ep of peer", peer, "to", ep) } continue } } - logrus.Debugln("[notify] drop invalid peer:", peer, "ep:", ep) + logrus.Debugln("[nat] notify drop invalid peer:", peer, "ep:", ep) } } @@ -69,7 +69,7 @@ func (l *Link) onQuery(packet []byte) { var peers head.Query err := json.Unmarshal(packet, &peers) if err != nil { - logrus.Errorln("[qurey] json unmarshal err:", err) + logrus.Errorln("[nat] query json unmarshal err:", err) return } @@ -84,7 +84,7 @@ func (l *Link) onQuery(packet []byte) { } } if len(notify) > 0 { - logrus.Infoln("[query] wrap", len(notify), "notify") + logrus.Infoln("[nat] query wrap", len(notify), "notify") w := helper.SelectWriter() json.NewEncoder(w).Encode(¬ify) l.WriteAndPut(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false) @@ -103,10 +103,10 @@ func (l *Link) sendquery(tick time.Duration, peers ...string) { } t := time.NewTicker(tick) for range t.C { - logrus.Infoln("[query] send query to", l.peerip) + logrus.Infoln("[nat] query send query to", l.peerip) _, err = l.WriteAndPut(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false) if err != nil { - logrus.Errorln("[query] write err:", err) + logrus.Errorln("[nat] query write err:", err) } } } diff --git a/gold/link/peer.go b/gold/link/peer.go index a7aecca..9c1fbf3 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -8,6 +8,7 @@ import ( curve "github.com/fumiama/go-x25519" tea "github.com/fumiama/gofastTEA" "github.com/sirupsen/logrus" + "golang.org/x/crypto/chacha20poly1305" ) type PeerConfig struct { @@ -15,6 +16,7 @@ type PeerConfig struct { EndPoint string AllowedIPs, Querys []string PubicKey *[32]byte + PresharedKey *[32]byte KeepAliveDur, QueryTick int64 MTU uint16 AllowTrans, NoPipe bool @@ -52,6 +54,13 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) { } } } + if cfg.PresharedKey != nil { + var err error + l.aead, err = chacha20poly1305.NewX(cfg.PresharedKey[:]) + if err != nil { + panic(err) + } + } if cfg.EndPoint != "" { e, err := net.ResolveUDPAddr("udp", cfg.EndPoint) if err != nil { diff --git a/gold/link/send.go b/gold/link/send.go index 8c76305..57d80b8 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math/rand" + "sync/atomic" "github.com/fumiama/WireGold/gold/head" "github.com/sirupsen/logrus" @@ -12,16 +13,23 @@ import ( // WriteAndPut 向 peer 发包并将包放回缓存池 func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { teatype := uint8(rand.Intn(16)) + sndcnt := atomic.AddUintptr(&l.sendcount, 1) if len(p.Data) <= int(l.mtu) { if !istransfer { p.FillHash() + if l.aead != nil { + p.Data = l.EncodePreshared(uint16(sndcnt), p.Data) + } p.Data = l.Encode(teatype, p.Data) } defer p.Put() - return l.write(p, teatype, uint32(len(p.Data)), 0, istransfer, false) + return l.write(p, teatype, uint16(sndcnt), uint32(len(p.Data)), 0, istransfer, false) } if !istransfer { p.FillHash() + if l.aead != nil { + p.Data = l.EncodePreshared(uint16(sndcnt), p.Data) + } p.Data = l.Encode(teatype, p.Data) } data := p.Data @@ -31,9 +39,9 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { packet := head.SelectPacket() *packet = *p for ; int(totl)-i > int(l.mtu); i += int(l.mtu) { - logrus.Debugln("[link] split frag", i, ":", i+int(l.mtu), ", remain:", int(totl)-i-int(l.mtu)) + logrus.Debugln("[send] split frag", i, ":", i+int(l.mtu), ", remain:", int(totl)-i-int(l.mtu)) packet.Data = data[:int(l.mtu)] - cnt, err := l.write(packet, teatype, totl, uint16(i>>3), istransfer, true) + cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, true) n += cnt if err != nil { return n, err @@ -43,33 +51,33 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { } packet.Put() p.Data = data - cnt, err := l.write(p, teatype, totl, uint16(i>>3), istransfer, false) + cnt, err := l.write(p, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, false) p.Put() n += cnt return n, err } // write 向 peer 发一个包 -func (l *Link) write(p *head.Packet, teatype uint8, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) { +func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) { var d []byte var cl func() if istransfer { if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.mtu) { return len(p.Data), errors.New("drop dont fragmnet big trans packet") } - d, cl = p.Marshal(nil, teatype, 0, 0, false, false) + d, cl = p.Marshal(nil, teatype, additional, 0, 0, false, false) } else { - d, cl = p.Marshal(l.me.me, teatype, datasz, offset, false, hasmore) + d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore) } if d == nil { - return 0, errors.New("[link] ttl exceeded") + return 0, errors.New("[send] ttl exceeded") } if err == nil { peerep := l.endpoint if peerep == nil { - return 0, errors.New("[link] nil endpoint of " + p.Dst.String()) + return 0, errors.New("[send] nil endpoint of " + p.Dst.String()) } - logrus.Debugln("[link] write", len(d), "bytes data from ep", l.me.myep.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset)) + logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.myep.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset)) n, err = l.me.myep.WriteToUDP(d, peerep) cl() } diff --git a/main.go b/main.go index 7d4401a..7decd24 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "crypto/rand" "flag" "fmt" "os" @@ -19,6 +20,7 @@ import ( func main() { help := flag.Bool("h", false, "display this help") gen := flag.Bool("g", false, "generate key pair") + pshgen := flag.Bool("pg", false, "generate preshared key") showp := flag.Bool("p", false, "show my publickey") file := flag.String("c", "config.yaml", "specify conf file") debug := flag.Bool("d", false, "print debug logs") @@ -50,6 +52,19 @@ func main() { fmt.Println("PrivateKey:", helper.BytesToString(prvk[:57])) os.Exit(0) } + if *pshgen { + var buf [32]byte + _, err := rand.Read(buf[:]) + if err != nil { + panic(err) + } + pshk, err := base14.UTF16BE2UTF8(base14.Encode(buf[:])) + if err != nil { + panic(err) + } + fmt.Println("PresharedKey:", helper.BytesToString(pshk[:57])) + os.Exit(0) + } if *logfile != "-" { f, err := os.Create(*logfile) if err != nil { diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go index ca97302..0836a6e 100644 --- a/upper/services/wg/wg.go +++ b/upper/services/wg/wg.go @@ -107,7 +107,19 @@ func (wg *WG) init(srcport, dstport uint16) { } n := copy(peerkey[:], base14.Decode(k)) if n != 32 { - panic("peer public key length is not 32") + panic("peer public key length < 32") + } + var pshk *[32]byte + if peer.PresharedKey != "" { + k, err := base14.UTF82UTF16BE(helper.StringToBytes(peer.PresharedKey + suffix32)) + if err != nil { + panic(err) + } + pshk = &[32]byte{} + n := copy(pshk[:], base14.Decode(k)) + if n != 32 { + panic("peer preshared key length < 32") + } } wg.me.AddPeer(&link.PeerConfig{ PeerIP: peer.IP, @@ -115,6 +127,7 @@ func (wg *WG) init(srcport, dstport uint16) { AllowedIPs: peer.AllowedIPs, Querys: peer.QueryList, PubicKey: &peerkey, + PresharedKey: pshk, KeepAliveDur: peer.KeepAliveSeconds, QueryTick: peer.QuerySeconds, MTU: uint16(peer.MTU),