1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-04 23:40:26 +08:00

feat: add preshared key

This commit is contained in:
源文雨
2023-08-03 18:01:48 +08:00
parent fda910cb4b
commit b35d2787ad
11 changed files with 145 additions and 46 deletions

View File

@@ -42,6 +42,7 @@ Peers:
IP: "192.168.233.2"
SubNet: 192.168.233.0/24
PublicKey: 徯萃嵾爻燸攗窍褃冔蒔犡緇袿屿組待族砇嘀
PresharedKey: 瀸敀爅崾嘊嵜紼樴稍毯攣矐訷蟷扛嬋庩崛昀
EndPoint: 1.2.3.4:56789
AllowedIPs: ["192.168.233.2/32"]
KeepAliveSeconds: 0
@@ -53,9 +54,10 @@ Peers:
IP: "192.168.233.3"
SubNet: 192.168.233.0/24
PublicKey: 牢喨粷詸衭譛浾蘹櫠砙杹蟫瑳叩刋橋経挵蘀
PresharedKey: 竅琚喫従痸告烈兇厕趭萨假蔛瀇譄施烸蝫瘀
EndPoint: ""
AllowedIPs: ["192.168.233.3/32"]
MTU: 752
KeepAliveSeconds: 0
AllowTrans: false
```
```

View File

@@ -23,6 +23,7 @@ type Peer struct {
IP string `yaml:"IP"`
SubNet string `yaml:"SubNet"`
PublicKey string `yaml:"PublicKey"`
PresharedKey string `yaml:"PresharedKey"`
EndPoint string `yaml:"EndPoint"`
AllowedIPs []string `yaml:"AllowedIPs"`
KeepAliveSeconds int64 `yaml:"KeepAliveSeconds"`

View File

@@ -6,7 +6,6 @@ import (
"errors"
"hash/crc64"
"net"
"sync/atomic"
"github.com/fumiama/WireGold/helper"
blake2b "github.com/fumiama/blake2b-simd"
@@ -17,7 +16,7 @@ import (
type Packet struct {
// TeaTypeDataSZ len(Data)
// 高 4 位指定加密所用 tea key
// 高 4-16 位是随机值
// 高 4-16 位是递增值, 用于预共享密钥验证 additionalData
// 不得超过 65507-head 字节
TeaTypeDataSZ uint32
// Proto 详见 head
@@ -109,18 +108,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
return
}
var counter uint32
// Marshal 将自身数据编码为 []byte
// offset 必须为 8 的倍数,表示偏移的 8 位
func (p *Packet) Marshal(src net.IP, teatype uint8, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) {
func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz uint32, offset uint16, dontfrag, hasmore bool) ([]byte, func()) {
p.TTL--
if p.TTL == 0 {
return nil, nil
}
if src != nil {
p.TeaTypeDataSZ = uint32(teatype)<<28 | (atomic.AddUint32(&counter, 1)<<16)&0x0fff0000 | datasz
p.TeaTypeDataSZ = uint32(teatype)<<28 | (uint32(additional&0x0fff) << 16) | datasz&0xffff
p.Src = src
offset &= 0x1fff
if dontfrag {
@@ -171,6 +168,11 @@ func (p *Packet) IsVaildHash() bool {
return sum == p.Hash
}
// AdditionalData 获得 packet 的 additionalData
func (p *Packet) AdditionalData() uint16 {
return uint16((p.TeaTypeDataSZ >> 16) & 0x0fff)
}
// Put 将自己放回池中
func (p *Packet) Put() {
PutPacket(p)

View File

@@ -1,5 +1,10 @@
package link
import (
"crypto/rand"
"encoding/binary"
)
// Encode 使用 TEA 加密
func (l *Link) Encode(teatype uint8, b []byte) (eb []byte) {
if b == nil || teatype >= 16 {
@@ -29,3 +34,34 @@ func (l *Link) Decode(teatype uint8, b []byte) (db []byte) {
db = l.key[teatype].Decrypt(b)
return
}
// EncodePreshared 使用 chacha20poly1305 加密
func (l *Link) EncodePreshared(additional uint16, b []byte) (eb []byte) {
nsz := l.aead.NonceSize()
// Select a random nonce, and leave capacity for the ciphertext.
nonce := make([]byte, nsz, nsz+len(b)+l.aead.Overhead())
_, err := rand.Read(nonce)
if err != nil {
return
}
// Encrypt the message and append the ciphertext to the nonce.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
eb = l.aead.Seal(nonce, nonce, b, buf[:])
return
}
// DecodePreshared 使用 chacha20poly1305 解密
func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) {
nsz := l.aead.NonceSize()
if len(b) < nsz { // ciphertext too short
return
}
// Split nonce and ciphertext.
nonce, ciphertext := b[:nsz], b[nsz:]
// Decrypt the message and check it wasn't tampered with.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
db, _ = l.aead.Open(nil, nonce, ciphertext, buf[:])
return
}

View File

@@ -1,6 +1,7 @@
package link
import (
"crypto/cipher"
"errors"
"net"
@@ -14,6 +15,8 @@ import (
type Link struct {
// peer 的公钥
pubk *[32]byte
// 发包计数, 分片算一个
sendcount uintptr
// 收到的包的队列
// 没有下层 nic 时
// 包会分发到此
@@ -26,6 +29,8 @@ type Link struct {
allowedips []*net.IPNet
// 连接所用对称加密密钥
key []tea.TEA
// 连接所用预共享密钥
aead cipher.AEAD
// 本机信息
me *Me
// 连接的状态,详见下方 const

View File

@@ -43,40 +43,48 @@ func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) {
sz := packet.TeaTypeDataSZ & 0x0000ffff
r := int(sz) - len(packet.Data)
if r > 0 {
logrus.Warnln("[link] packet from endpoint", addr, "is smaller than it declared: drop it")
logrus.Warnln("[listen] packet from endpoint", addr, "is smaller than it declared: drop it")
packet.Put()
continue
}
p, ok := m.IsInPeer(packet.Src.String())
logrus.Debugln("[link] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
// logrus.Debugln("[link] recv:", hex.EncodeToString(lbf))
logrus.Debugln("[listen] recv from endpoint", addr, "src", packet.Src, "dst", packet.Dst)
// logrus.Debugln("[listen] recv:", hex.EncodeToString(lbf))
if !ok {
logrus.Warnln("[link] packet from", packet.Src, "to", packet.Dst, "is refused")
logrus.Warnln("[listen] packet from", packet.Src, "to", packet.Dst, "is refused")
packet.Put()
continue
}
if p.endpoint == nil || p.endpoint.String() != addr.String() {
logrus.Infoln("[link] set endpoint of peer", p.peerip, "to", addr.String())
logrus.Infoln("[listen] set endpoint of peer", p.peerip, "to", addr.String())
p.endpoint = addr
}
switch {
case p.IsToMe(packet.Dst):
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data)
if !packet.IsVaildHash() {
logrus.Debugln("[link] drop invalid packet")
logrus.Debugln("[listen] drop invalid hash packet")
packet.Put()
continue
}
if p.aead != nil {
packet.Data = p.DecodePreshared(packet.AdditionalData(), packet.Data)
if packet.Data == nil {
logrus.Debugln("[listen] drop invalid additional data packet")
packet.Put()
continue
}
}
switch packet.Proto {
case head.ProtoHello:
switch p.status {
case LINK_STATUS_DOWN:
n, err = p.WriteAndPut(head.NewPacket(head.ProtoHello, m.SrcPort(), p.peerip, m.DstPort(), nil), false)
if err == nil {
logrus.Debugln("[link] send", n, "bytes hello ack packet")
logrus.Debugln("[listen] send", n, "bytes hello ack packet")
p.status = LINK_STATUS_HALFUP
} else {
logrus.Errorln("[link] send hello ack packet error:", err)
logrus.Errorln("[listen] send hello ack packet error:", err)
}
case LINK_STATUS_HALFUP:
p.status = LINK_STATUS_UP
@@ -84,47 +92,47 @@ func (m *Me) listenthread(conn *net.UDPConn, mu *sync.Mutex) {
}
packet.Put()
case head.ProtoNotify:
logrus.Infoln("[link] recv notify from", packet.Src)
logrus.Infoln("[listen] recv notify from", packet.Src)
go p.onNotify(packet.Data)
packet.Put()
case head.ProtoQuery:
logrus.Infoln("[link] recv query from", packet.Src)
logrus.Infoln("[listen] recv query from", packet.Src)
go p.onQuery(packet.Data)
packet.Put()
case head.ProtoData:
if p.pipe != nil {
p.pipe <- packet
logrus.Debugln("[link] deliver to pipe of", p.peerip)
logrus.Debugln("[listen] deliver to pipe of", p.peerip)
} else {
m.nic.Write(packet.Data)
logrus.Debugln("[link] deliver", len(packet.Data), "bytes data to nic")
logrus.Debugln("[listen] deliver", len(packet.Data), "bytes data to nic")
packet.Put()
}
default:
logrus.Warnln("[link] recv unknown proto:", packet.Proto)
logrus.Warnln("[listen] recv unknown proto:", packet.Proto)
packet.Put()
}
case p.Accept(packet.Dst):
if !p.allowtrans {
logrus.Warnln("[link] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
logrus.Warnln("[listen] refused to trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
packet.Put()
continue
}
// 转发
lnk := m.router.NextHop(packet.Dst.String())
if lnk == nil {
logrus.Warnln("[link] transfer drop packet: nil nexthop")
logrus.Warnln("[listen] transfer drop packet: nil nexthop")
packet.Put()
continue
}
n, err = lnk.WriteAndPut(packet, true)
if err == nil {
logrus.Debugln("[link] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
logrus.Debugln("[listen] trans", n, "bytes packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)))
} else {
logrus.Errorln("[link] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
logrus.Errorln("[listen] trans packet to", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "err:", err)
}
default:
logrus.Warnln("[link] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
logrus.Warnln("[listen] packet dst", packet.Dst.String()+":"+strconv.Itoa(int(packet.DstPort)), "is not in peers")
packet.Put()
}
}

View File

@@ -16,14 +16,14 @@ import (
// 以秒为单位,小于等于 0 不发送
func (l *Link) keepAlive(dur int64) {
if dur > 0 {
logrus.Infoln("[link.nat] start to keep alive")
logrus.Infoln("[nat] start to keep alive")
t := time.NewTicker(time.Second * time.Duration(dur))
for range t.C {
n, err := l.WriteAndPut(head.NewPacket(head.ProtoHello, l.me.srcport, l.peerip, l.me.dstport, nil), false)
if err == nil {
logrus.Infoln("[link] send", n, "bytes keep alive packet")
logrus.Infoln("[nat] send", n, "bytes keep alive packet")
} else {
logrus.Errorln("[link] send keep alive packet error:", err)
logrus.Errorln("[nat] send keep alive packet error:", err)
}
}
}
@@ -37,7 +37,7 @@ func (l *Link) onNotify(packet []byte) {
notify := make(head.Notify, 32)
err := json.Unmarshal(packet, &notify)
if err != nil {
logrus.Errorln("[notify] json unmarshal err:", err)
logrus.Errorln("[nat] notify json unmarshal err:", err)
return
}
// 2. endpoint注册
@@ -50,12 +50,12 @@ func (l *Link) onNotify(packet []byte) {
if ok {
if p.endpoint.String() != ep {
p.endpoint = addr
logrus.Infoln("[notify] set ep of peer", peer, "to", ep)
logrus.Infoln("[nat] notify set ep of peer", peer, "to", ep)
}
continue
}
}
logrus.Debugln("[notify] drop invalid peer:", peer, "ep:", ep)
logrus.Debugln("[nat] notify drop invalid peer:", peer, "ep:", ep)
}
}
@@ -69,7 +69,7 @@ func (l *Link) onQuery(packet []byte) {
var peers head.Query
err := json.Unmarshal(packet, &peers)
if err != nil {
logrus.Errorln("[qurey] json unmarshal err:", err)
logrus.Errorln("[nat] query json unmarshal err:", err)
return
}
@@ -84,7 +84,7 @@ func (l *Link) onQuery(packet []byte) {
}
}
if len(notify) > 0 {
logrus.Infoln("[query] wrap", len(notify), "notify")
logrus.Infoln("[nat] query wrap", len(notify), "notify")
w := helper.SelectWriter()
json.NewEncoder(w).Encode(&notify)
l.WriteAndPut(head.NewPacket(head.ProtoNotify, l.me.srcport, l.peerip, l.me.dstport, w.Bytes()), false)
@@ -103,10 +103,10 @@ func (l *Link) sendquery(tick time.Duration, peers ...string) {
}
t := time.NewTicker(tick)
for range t.C {
logrus.Infoln("[query] send query to", l.peerip)
logrus.Infoln("[nat] query send query to", l.peerip)
_, err = l.WriteAndPut(head.NewPacket(head.ProtoQuery, l.me.srcport, l.peerip, l.me.dstport, data), false)
if err != nil {
logrus.Errorln("[query] write err:", err)
logrus.Errorln("[nat] query write err:", err)
}
}
}

View File

@@ -8,6 +8,7 @@ import (
curve "github.com/fumiama/go-x25519"
tea "github.com/fumiama/gofastTEA"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/chacha20poly1305"
)
type PeerConfig struct {
@@ -15,6 +16,7 @@ type PeerConfig struct {
EndPoint string
AllowedIPs, Querys []string
PubicKey *[32]byte
PresharedKey *[32]byte
KeepAliveDur, QueryTick int64
MTU uint16
AllowTrans, NoPipe bool
@@ -52,6 +54,13 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) {
}
}
}
if cfg.PresharedKey != nil {
var err error
l.aead, err = chacha20poly1305.NewX(cfg.PresharedKey[:])
if err != nil {
panic(err)
}
}
if cfg.EndPoint != "" {
e, err := net.ResolveUDPAddr("udp", cfg.EndPoint)
if err != nil {

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math/rand"
"sync/atomic"
"github.com/fumiama/WireGold/gold/head"
"github.com/sirupsen/logrus"
@@ -12,16 +13,23 @@ import (
// WriteAndPut 向 peer 发包并将包放回缓存池
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
teatype := uint8(rand.Intn(16))
sndcnt := atomic.AddUintptr(&l.sendcount, 1)
if len(p.Data) <= int(l.mtu) {
if !istransfer {
p.FillHash()
if l.aead != nil {
p.Data = l.EncodePreshared(uint16(sndcnt), p.Data)
}
p.Data = l.Encode(teatype, p.Data)
}
defer p.Put()
return l.write(p, teatype, uint32(len(p.Data)), 0, istransfer, false)
return l.write(p, teatype, uint16(sndcnt), uint32(len(p.Data)), 0, istransfer, false)
}
if !istransfer {
p.FillHash()
if l.aead != nil {
p.Data = l.EncodePreshared(uint16(sndcnt), p.Data)
}
p.Data = l.Encode(teatype, p.Data)
}
data := p.Data
@@ -31,9 +39,9 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
packet := head.SelectPacket()
*packet = *p
for ; int(totl)-i > int(l.mtu); i += int(l.mtu) {
logrus.Debugln("[link] split frag", i, ":", i+int(l.mtu), ", remain:", int(totl)-i-int(l.mtu))
logrus.Debugln("[send] split frag", i, ":", i+int(l.mtu), ", remain:", int(totl)-i-int(l.mtu))
packet.Data = data[:int(l.mtu)]
cnt, err := l.write(packet, teatype, totl, uint16(i>>3), istransfer, true)
cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, true)
n += cnt
if err != nil {
return n, err
@@ -43,33 +51,33 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
}
packet.Put()
p.Data = data
cnt, err := l.write(p, teatype, totl, uint16(i>>3), istransfer, false)
cnt, err := l.write(p, teatype, uint16(sndcnt), totl, uint16(i>>3), istransfer, false)
p.Put()
n += cnt
return n, err
}
// write 向 peer 发一个包
func (l *Link) write(p *head.Packet, teatype uint8, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) {
func (l *Link) write(p *head.Packet, teatype uint8, additional uint16, datasz uint32, offset uint16, istransfer, hasmore bool) (n int, err error) {
var d []byte
var cl func()
if istransfer {
if p.Flags&0x4000 == 0x4000 && len(p.Data) > int(l.mtu) {
return len(p.Data), errors.New("drop dont fragmnet big trans packet")
}
d, cl = p.Marshal(nil, teatype, 0, 0, false, false)
d, cl = p.Marshal(nil, teatype, additional, 0, 0, false, false)
} else {
d, cl = p.Marshal(l.me.me, teatype, datasz, offset, false, hasmore)
d, cl = p.Marshal(l.me.me, teatype, additional, datasz, offset, false, hasmore)
}
if d == nil {
return 0, errors.New("[link] ttl exceeded")
return 0, errors.New("[send] ttl exceeded")
}
if err == nil {
peerep := l.endpoint
if peerep == nil {
return 0, errors.New("[link] nil endpoint of " + p.Dst.String())
return 0, errors.New("[send] nil endpoint of " + p.Dst.String())
}
logrus.Debugln("[link] write", len(d), "bytes data from ep", l.me.myep.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset))
logrus.Debugln("[send] write", len(d), "bytes data from ep", l.me.myep.LocalAddr(), "to", peerep, "offset:", fmt.Sprintf("%04x", offset))
n, err = l.me.myep.WriteToUDP(d, peerep)
cl()
}

15
main.go
View File

@@ -2,6 +2,7 @@ package main
import (
"bytes"
"crypto/rand"
"flag"
"fmt"
"os"
@@ -19,6 +20,7 @@ import (
func main() {
help := flag.Bool("h", false, "display this help")
gen := flag.Bool("g", false, "generate key pair")
pshgen := flag.Bool("pg", false, "generate preshared key")
showp := flag.Bool("p", false, "show my publickey")
file := flag.String("c", "config.yaml", "specify conf file")
debug := flag.Bool("d", false, "print debug logs")
@@ -50,6 +52,19 @@ func main() {
fmt.Println("PrivateKey:", helper.BytesToString(prvk[:57]))
os.Exit(0)
}
if *pshgen {
var buf [32]byte
_, err := rand.Read(buf[:])
if err != nil {
panic(err)
}
pshk, err := base14.UTF16BE2UTF8(base14.Encode(buf[:]))
if err != nil {
panic(err)
}
fmt.Println("PresharedKey:", helper.BytesToString(pshk[:57]))
os.Exit(0)
}
if *logfile != "-" {
f, err := os.Create(*logfile)
if err != nil {

View File

@@ -107,7 +107,19 @@ func (wg *WG) init(srcport, dstport uint16) {
}
n := copy(peerkey[:], base14.Decode(k))
if n != 32 {
panic("peer public key length is not 32")
panic("peer public key length < 32")
}
var pshk *[32]byte
if peer.PresharedKey != "" {
k, err := base14.UTF82UTF16BE(helper.StringToBytes(peer.PresharedKey + suffix32))
if err != nil {
panic(err)
}
pshk = &[32]byte{}
n := copy(pshk[:], base14.Decode(k))
if n != 32 {
panic("peer preshared key length < 32")
}
}
wg.me.AddPeer(&link.PeerConfig{
PeerIP: peer.IP,
@@ -115,6 +127,7 @@ func (wg *WG) init(srcport, dstport uint16) {
AllowedIPs: peer.AllowedIPs,
Querys: peer.QueryList,
PubicKey: &peerkey,
PresharedKey: pshk,
KeepAliveDur: peer.KeepAliveSeconds,
QueryTick: peer.QuerySeconds,
MTU: uint16(peer.MTU),