1
0
mirror of https://github.com/fumiama/WireGold.git synced 2026-06-30 07:40:25 +08:00

perf: remove tea encryption

This commit is contained in:
源文雨
2024-07-11 22:31:44 +09:00
parent c0bd86d1bb
commit e115098344
14 changed files with 243 additions and 117 deletions

View File

@@ -1,45 +1,89 @@
package link
import (
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"math/bits"
mrand "math/rand"
)
// Encode 使用 TEA 加密
func (l *Link) Encode(teatype uint8, b []byte) (eb []byte) {
if b == nil || teatype >= 16 {
return
func (l *Link) randkeyidx() uint8 {
if l.keys[1] == nil {
return 0
}
if l.key == nil {
eb = b
return
return uint8(mrand.Intn(32))
}
func mixkeys(k1, k2 []byte) []byte {
if len(k1) != 32 || len(k2) != 32 {
panic("unexpected key len")
}
k := make([]byte, 64)
for i := range k1 {
k1i, k2i := i, 31-i
k1v, k2v := k1[k1i], k2[k2i]
binary.LittleEndian.PutUint16(
k[i*2:(i+1)*2],
expandkeyunit(k1v, k2v),
)
}
return k
}
func expandkeyunit(v1, v2 byte) (v uint16) {
v1s, v2s := uint16(v1), uint16(bits.Reverse8(v2))
for i := 0; i < 8; i++ {
v |= v1s & (1 << (i * 2))
v1s <<= 1
}
for i := 0; i < 8; i++ {
v2s <<= 1
v |= v2s & (2 << (i * 2))
}
// 在此处填写加密逻辑密钥是l.key输入是b输出是eb
// 不用写return直接赋值给eb即可
eb = l.key[teatype].Encrypt(b)
return
}
// Decode 使用 TEA 解
func (l *Link) Decode(teatype uint8, b []byte) (db []byte) {
if b == nil || teatype >= 16 {
// Encode 使用 xchacha20poly1305 和密钥序列加
func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) {
if b == nil || teatype >= 32 {
return
}
if l.key == nil {
if l.keys[0] == nil {
eb = make([]byte, len(b))
copy(eb, b)
return
}
aead := l.keys[teatype]
if aead == nil {
return
}
eb = encode(aead, additional, b)
return
}
// Decode 使用 xchacha20poly1305 和密钥序列解密
func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) {
if b == nil || teatype >= 32 {
return
}
if l.keys[0] == nil {
db = b
return
}
// 在此处填写解密逻辑密钥是l.key输入是b输出是db
// 不用写return直接赋值给db即可
db = l.key[teatype].Decrypt(b)
aead := l.keys[teatype]
if aead == nil {
return
}
db = decode(aead, additional, b)
return
}
// EncodePreshared 使用 xchacha20poly1305 加密
func (l *Link) EncodePreshared(additional uint16, b []byte) (eb []byte) {
nsz := l.aead.NonceSize()
// encode 使用 xchacha20poly1305 加密
func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) {
nsz := aead.NonceSize()
// Select a random nonce, and leave capacity for the ciphertext.
nonce := make([]byte, nsz, nsz+len(b)+l.aead.Overhead())
nonce := make([]byte, nsz, nsz+len(b)+aead.Overhead())
_, err := rand.Read(nonce)
if err != nil {
return
@@ -47,13 +91,13 @@ func (l *Link) EncodePreshared(additional uint16, b []byte) (eb []byte) {
// Encrypt the message and append the ciphertext to the nonce.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
eb = l.aead.Seal(nonce, nonce, b, buf[:])
eb = aead.Seal(nonce, nonce, b, buf[:])
return
}
// DecodePreshared 使用 xchacha20poly1305 解密
func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) {
nsz := l.aead.NonceSize()
// decode 使用 xchacha20poly1305 解密
func decode(aead cipher.AEAD, additional uint16, b []byte) (db []byte) {
nsz := aead.NonceSize()
if len(b) < nsz { // ciphertext too short
return
}
@@ -62,7 +106,7 @@ func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) {
// Decrypt the message and check it wasn't tampered with.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
db, _ = l.aead.Open(nil, nonce, ciphertext, buf[:])
db, _ = aead.Open(nil, nonce, ciphertext, buf[:])
return
}

View File

@@ -3,6 +3,8 @@ package link
import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"io"
"testing"
@@ -32,20 +34,47 @@ func TestXOR(t *testing.T) {
}
func TestXChacha20(t *testing.T) {
l := Link{}
k := make([]byte, 32)
_, err := rand.Read(k)
if err != nil {
t.Fatal(err)
}
l.aead, err = chacha20poly1305.NewX(k)
aead, err := chacha20poly1305.NewX(k)
if err != nil {
t.Fatal(err)
}
data := []byte("12345678")
for i := uint64(0); i < 100000; i++ {
if !bytes.Equal(l.DecodePreshared(uint16(i), l.EncodePreshared(uint16(i), data)), data) {
if !bytes.Equal(decode(aead, uint16(i), encode(aead, uint16(i), data)), data) {
t.Fatal("unexpected preshared at", i, "addt", uint16(i))
}
}
}
func TestExpandKeyUnit(t *testing.T) {
k1 := byte(0b10001010)
k2 := byte(0b10111010) // rev 01011101
v := expandkeyunit(k1, k2) // x1x0x0x0x1x0x1x0 | 0x1x0x1x1x1x0x1x = 0110001011100110
if v != 0b0110001011100110 {
buf := [2]byte{}
binary.BigEndian.PutUint16(buf[:], v)
t.Fatal(hex.EncodeToString(buf[:]))
}
}
func TestMixKeys(t *testing.T) {
k1, _ := hex.DecodeString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
k2, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000000")
k := mixkeys(k1, k2)
kexp, _ := hex.DecodeString("55555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555")
if !bytes.Equal(k, kexp) {
t.Fatal(hex.EncodeToString(k))
}
k1, _ = hex.DecodeString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
k2, _ = hex.DecodeString("deadbeef1239876540deadbeef1239876540deadbeef1239876540abcdef4567")
k = mixkeys(k1, k2)
kexp, _ = hex.DecodeString("2ca9188d3ebb4a9f22e34d4479d857fca48390253ebbe23f22cbcf6e59507ddc06a9b08794316abfa26b67cedb7a5d542c8912adb493c0352aebe76e73dadf7e")
if !bytes.Equal(k, kexp) {
t.Fatal(hex.EncodeToString(k))
}
}

View File

@@ -4,11 +4,11 @@ import (
"crypto/cipher"
"errors"
"net"
"sync/atomic"
"github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/helper"
base14 "github.com/fumiama/go-base16384"
tea "github.com/fumiama/gofastTEA"
)
// Link 是本机到 peer 的连接抽象
@@ -27,10 +27,8 @@ type Link struct {
endpoint *net.UDPAddr
// 本机允许接收/发送的 ip 网段
allowedips []*net.IPNet
// 连接所用对称加密密钥
key []tea.TEA
// 连接所用预共享密钥
aead cipher.AEAD
// 连接所用对称加密密钥
keys [32]cipher.AEAD
// 本机信息
me *Me
// 连接的状态,详见下方 const
@@ -84,3 +82,7 @@ func (l *Link) String() (n string) {
}
return
}
func (l *Link) incgetsndcnt() uintptr {
return atomic.AddUintptr(&l.sendcount, 1)
}

View File

@@ -2,9 +2,11 @@ package link
import (
"bytes"
"errors"
"io"
"net"
"net/netip"
"os"
"runtime"
"strconv"
"sync"
@@ -48,7 +50,18 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
}
logrus.Debugln("[listen] lock index", i)
lbf := listenbuff[i*65536 : (i+1)*65536]
err = conn.SetDeadline(time.Now().Add(time.Second))
if err != nil {
logrus.Warnln("[listen] set ddl err:", err)
}
n, addr, err := conn.ReadFromUDP(lbf)
if m.loop == nil {
logrus.Warnln("[listen] quit listening")
return
}
if errors.Is(err, os.ErrDeadlineExceeded) {
err = nil
}
if err != nil {
logrus.Warnln("[listen] read from udp err, reconnect:", err)
conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.udpep.String())))
@@ -102,15 +115,12 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
}
switch {
case p.IsToMe(packet.Dst):
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data)
if p.aead != nil {
addt := packet.AdditionalData()
packet.Data = p.DecodePreshared(addt, packet.Data)
if packet.Data == nil {
logrus.Debugln("[listen] @", index, "drop invalid preshared packet, addt:", addt)
packet.Put()
return
}
addt := packet.AdditionalData()
packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
if packet.Data == nil {
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt)
packet.Put()
return
}
if p.usezstd {
dec, _ := zstd.NewReader(bytes.NewReader(packet.Data))

View File

@@ -121,9 +121,25 @@ func (m *Me) MTU() uint16 {
return m.mtu
}
func (m *Me) CloseNIC() error {
m.nic.Down()
return m.nic.Close()
func (m *Me) EndPoint() net.Addr {
return m.udpep
}
func (m *Me) Close() error {
m.loop = nil
m.connections = nil
_ = m.udpconn.Close()
m.udpconn = nil
m.router = nil
m.recving.Destroy()
m.recving = nil
m.recved.Destroy()
m.recved = nil
if m.nic != nil {
m.nic.Down()
return m.nic.Close()
}
return nil
}
func (m *Me) Write(packet []byte) (n int, err error) {

View File

@@ -6,7 +6,6 @@ import (
"github.com/fumiama/WireGold/gold/head"
curve "github.com/fumiama/go-x25519"
tea "github.com/fumiama/gofastTEA"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/chacha20poly1305"
)
@@ -48,21 +47,28 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) {
if !cfg.NoPipe {
l.pipe = make(chan *head.Packet, 32)
}
var k, p []byte
if cfg.PubicKey != nil {
c := curve.Get(m.privKey[:])
k, err := c.Shared(cfg.PubicKey)
if err == nil {
l.key = make([]tea.TEA, 16)
for i := range l.key {
l.key[i] = tea.NewTeaCipherLittleEndian(k[i : 16+i])
}
}
k, _ = curve.Get(m.privKey[:]).Shared(cfg.PubicKey)
}
if cfg.PresharedKey != nil {
p = cfg.PresharedKey[:]
}
if len(k) == 32 {
var err error
l.aead, err = chacha20poly1305.NewX(cfg.PresharedKey[:])
if err != nil {
panic(err)
if len(p) == 32 {
mixk := mixkeys(k, p)
for i := range k {
l.keys[i], err = chacha20poly1305.NewX(mixk[i : i+32])
if err != nil {
panic(err)
}
}
} else {
l.keys[0], err = chacha20poly1305.NewX(k)
if err != nil {
panic(err)
}
}
}
if cfg.EndPoint != "" {

View File

@@ -10,7 +10,7 @@ import (
)
type Router struct {
sync.RWMutex
mu sync.RWMutex
// map[cidr]*Link
table map[string]*Link
list []*net.IPNet
@@ -35,10 +35,10 @@ func (l *Link) IsToMe(ip net.IP) bool {
// SetDefault 设置默认网关
func (r *Router) SetDefault(l *Link) {
defnet := &net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
r.Lock()
r.mu.Lock()
r.list[len(r.list)-1] = defnet
r.table[defnet.String()] = l
r.Unlock()
r.mu.Unlock()
}
// NextHop 得到前往 ip 的下一跳的 link
@@ -56,8 +56,8 @@ func (r *Router) NextHop(ip string) (l *Link) {
// 遍历 r.table得到正确的下一跳
// 注意使用 r.mu 读写锁避免竞争
r.RLock()
defer r.RUnlock()
r.mu.RLock()
defer r.mu.RUnlock()
for _, c := range r.list {
if c.Contains(ipb) {
@@ -75,7 +75,7 @@ func (r *Router) NextHop(ip string) (l *Link) {
// SetItem 添加一条表项
func (r *Router) SetItem(ip *net.IPNet, l *Link) {
r.Lock()
r.mu.Lock()
// 从第一条表项开始匹配
for i := 0; i < len(r.list); i++ {
if r.list[i].Contains(ip.IP) {
@@ -94,7 +94,7 @@ func (r *Router) SetItem(ip *net.IPNet, l *Link) {
break
}
}
r.Unlock()
r.mu.Unlock()
}
func isSubnetBcast(ip net.IP, subnet *net.IPNet) bool {

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"math/rand"
"sync/atomic"
"github.com/fumiama/WireGold/gold/head"
"github.com/fumiama/WireGold/helper"
@@ -18,15 +17,15 @@ import (
// WriteAndPut 向 peer 发包并将包放回缓存池
func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
defer p.Put()
teatype := uint8(rand.Intn(16))
sndcnt := atomic.AddUintptr(&l.sendcount, 1)
teatype := l.randkeyidx()
sndcnt := uint16(l.incgetsndcnt())
mtu := l.mtu
if l.mturandomrange > 0 {
mtu -= uint16(rand.Intn(int(l.mturandomrange)))
}
logrus.Debugln("[send] mtu:", mtu, ", count:", sndcnt, ", additional data:", uint16(sndcnt)&0x0fff)
logrus.Debugln("[send] mtu:", mtu, ", addt:", uint16(sndcnt)&0x0fff, ", key index:", teatype)
if !istransfer {
l.encrypt(p, uint16(sndcnt), teatype)
l.encrypt(p, sndcnt, teatype)
}
delta := (int(mtu) - 60) & 0x0000fff8
if delta <= 0 {
@@ -34,7 +33,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
delta = 8
}
if len(p.Data) <= delta {
return l.write(p, teatype, uint16(sndcnt), uint32(len(p.Data)), 0, istransfer, false)
return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false)
}
if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta {
return 0, errors.New("drop don't fragmnet big trans packet")
@@ -48,7 +47,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
for ; int(totl)-pos > delta; pos += delta {
logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", int(totl)-pos-delta)
packet.Data = data[:delta]
cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, true)
cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true)
n += cnt
if err != nil {
return n, err
@@ -60,7 +59,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
if len(data) > 0 {
p.Data = data
cnt := 0
cnt, err = l.write(p, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, false)
cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false)
n += cnt
}
return n, err
@@ -78,12 +77,8 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) {
p.Data = w.Bytes()
logrus.Debugln("[send] data len after zstd:", len(p.Data))
}
if l.aead != nil {
p.Data = l.EncodePreshared(sndcnt&0x0fff, p.Data)
logrus.Debugln("[send] data len after xchacha20:", len(p.Data))
}
p.Data = l.Encode(teatype, p.Data)
logrus.Debugln("[send] data len after tea:", len(p.Data))
p.Data = l.Encode(teatype, sndcnt&0x07ff, p.Data)
logrus.Debugln("[send] data len after xchacha20:", len(p.Data), "addt:", sndcnt)
}
// write 向 peer 发一个包