diff --git a/go.mod b/go.mod index 1bf8e52..6106dd2 100644 --- a/go.mod +++ b/go.mod @@ -7,16 +7,15 @@ require ( github.com/fumiama/blake2b-simd v0.0.0-20220412110131-4481822068bb github.com/fumiama/go-base16384 v1.7.0 github.com/fumiama/go-x25519 v1.0.0 - github.com/fumiama/gofastTEA v0.0.10 github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac - github.com/klauspost/compress v1.16.7 + github.com/klauspost/compress v1.17.9 github.com/sirupsen/logrus v1.9.3 - golang.org/x/crypto v0.11.1-0.20230731181441-edc325d13aa9 + golang.org/x/crypto v0.25.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/fumiama/wintun v0.0.0-20211229152851-8bc97c8034c0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index bcb05cc..339497e 100644 --- a/go.sum +++ b/go.sum @@ -9,14 +9,12 @@ github.com/fumiama/go-base16384 v1.7.0 h1:6fep7XPQWxRlh4Hu+KsdH+6+YdUp+w6CwRXtMW github.com/fumiama/go-base16384 v1.7.0/go.mod h1:OEn+947GV5gsbTAnyuUW/SrfxJYUdYupSIQXOuGOcXM= github.com/fumiama/go-x25519 v1.0.0 h1:hiGg9EhseVmGCc8T1jECVkj8Keu/aJ1ZK05RM8Vuavo= github.com/fumiama/go-x25519 v1.0.0/go.mod h1:8VOhfyGZzw4IUs4nCjQFqW9cA3V/QpSCtP3fo2dLNg4= -github.com/fumiama/gofastTEA v0.0.10 h1:JJJ+brWD4kie+mmK2TkspDXKzqq0IjXm89aGYfoGhhQ= -github.com/fumiama/gofastTEA v0.0.10/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk= github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac h1:A/5A0rODsg+EQHH61Ew5mMUtDpRXaSNqHhPvW+fN4C4= github.com/fumiama/water v0.0.0-20211231134027-da391938d6ac/go.mod h1:BBnNY9PwK+UUn4trAU+H0qsMEypm7+3Bj1bVFuJItlo= github.com/fumiama/wintun v0.0.0-20211229152851-8bc97c8034c0 h1:WfrSFlIlCAtg6Rt2IGna0HhJYSDE45YVHiYqO4wwsEw= github.com/fumiama/wintun v0.0.0-20211229152851-8bc97c8034c0/go.mod h1:dPOG7Af/ArO62RgBz2JJTNFByBn/IXWLo/1kZKcLSe8= -github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= -github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -27,19 +25,19 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.11.1-0.20230731181441-edc325d13aa9 h1:fD3wQV2d6EmGo8zyauOQlRGUy3CwlspiqHJbiEO0nrc= -golang.org/x/crypto v0.11.1-0.20230731181441-edc325d13aa9/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/gold/head/packet.go b/gold/head/packet.go index 68617e9..47adb72 100644 --- a/gold/head/packet.go +++ b/gold/head/packet.go @@ -15,8 +15,8 @@ import ( // Packet 是发送和接收的最小单位 type Packet struct { // TeaTypeDataSZ len(Data) - // 高 4 位指定加密所用 tea key - // 高 4-16 位是递增值, 用于预共享密钥验证 additionalData + // 高 5 位指定加密所用 key index + // 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData // 不得超过 65507-head 字节 TeaTypeDataSZ uint32 // Proto 详见 head @@ -118,7 +118,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui } if src != nil { - p.TeaTypeDataSZ = uint32(teatype)<<28 | (uint32(additional&0x0fff) << 16) | datasz&0xffff + p.TeaTypeDataSZ = uint32(teatype)<<27 | (uint32(additional&0x07ff) << 16) | datasz&0xffff p.Src = src offset &= 0x1fff if dontfrag { @@ -171,7 +171,7 @@ func (p *Packet) IsVaildHash() bool { // AdditionalData 获得 packet 的 additionalData func (p *Packet) AdditionalData() uint16 { - return uint16((p.TeaTypeDataSZ >> 16) & 0x0fff) + return uint16((p.TeaTypeDataSZ >> 16) & 0x07ff) } // Put 将自己放回池中 diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 9cb6843..303ff40 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -1,45 +1,89 @@ package link import ( + "crypto/cipher" "crypto/rand" "encoding/binary" + "math/bits" + mrand "math/rand" ) -// Encode 使用 TEA 加密 -func (l *Link) Encode(teatype uint8, b []byte) (eb []byte) { - if b == nil || teatype >= 16 { - return +func (l *Link) randkeyidx() uint8 { + if l.keys[1] == nil { + return 0 } - if l.key == nil { - eb = b - return + return uint8(mrand.Intn(32)) +} + +func mixkeys(k1, k2 []byte) []byte { + if len(k1) != 32 || len(k2) != 32 { + panic("unexpected key len") + } + k := make([]byte, 64) + for i := range k1 { + k1i, k2i := i, 31-i + k1v, k2v := k1[k1i], k2[k2i] + binary.LittleEndian.PutUint16( + k[i*2:(i+1)*2], + expandkeyunit(k1v, k2v), + ) + } + return k +} + +func expandkeyunit(v1, v2 byte) (v uint16) { + v1s, v2s := uint16(v1), uint16(bits.Reverse8(v2)) + for i := 0; i < 8; i++ { + v |= v1s & (1 << (i * 2)) + v1s <<= 1 + } + for i := 0; i < 8; i++ { + v2s <<= 1 + v |= v2s & (2 << (i * 2)) } - // 在此处填写加密逻辑,密钥是l.key,输入是b,输出是eb - // 不用写return,直接赋值给eb即可 - eb = l.key[teatype].Encrypt(b) return } -// Decode 使用 TEA 解密 -func (l *Link) Decode(teatype uint8, b []byte) (db []byte) { - if b == nil || teatype >= 16 { +// Encode 使用 xchacha20poly1305 和密钥序列加密 +func (l *Link) Encode(teatype uint8, additional uint16, b []byte) (eb []byte) { + if b == nil || teatype >= 32 { return } - if l.key == nil { + if l.keys[0] == nil { + eb = make([]byte, len(b)) + copy(eb, b) + return + } + aead := l.keys[teatype] + if aead == nil { + return + } + eb = encode(aead, additional, b) + return +} + +// Decode 使用 xchacha20poly1305 和密钥序列解密 +func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte) { + if b == nil || teatype >= 32 { + return + } + if l.keys[0] == nil { db = b return } - // 在此处填写解密逻辑,密钥是l.key,输入是b,输出是db - // 不用写return,直接赋值给db即可 - db = l.key[teatype].Decrypt(b) + aead := l.keys[teatype] + if aead == nil { + return + } + db = decode(aead, additional, b) return } -// EncodePreshared 使用 xchacha20poly1305 加密 -func (l *Link) EncodePreshared(additional uint16, b []byte) (eb []byte) { - nsz := l.aead.NonceSize() +// encode 使用 xchacha20poly1305 加密 +func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) { + nsz := aead.NonceSize() // Select a random nonce, and leave capacity for the ciphertext. - nonce := make([]byte, nsz, nsz+len(b)+l.aead.Overhead()) + nonce := make([]byte, nsz, nsz+len(b)+aead.Overhead()) _, err := rand.Read(nonce) if err != nil { return @@ -47,13 +91,13 @@ func (l *Link) EncodePreshared(additional uint16, b []byte) (eb []byte) { // Encrypt the message and append the ciphertext to the nonce. var buf [2]byte binary.LittleEndian.PutUint16(buf[:], additional) - eb = l.aead.Seal(nonce, nonce, b, buf[:]) + eb = aead.Seal(nonce, nonce, b, buf[:]) return } -// DecodePreshared 使用 xchacha20poly1305 解密 -func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) { - nsz := l.aead.NonceSize() +// decode 使用 xchacha20poly1305 解密 +func decode(aead cipher.AEAD, additional uint16, b []byte) (db []byte) { + nsz := aead.NonceSize() if len(b) < nsz { // ciphertext too short return } @@ -62,7 +106,7 @@ func (l *Link) DecodePreshared(additional uint16, b []byte) (db []byte) { // Decrypt the message and check it wasn't tampered with. var buf [2]byte binary.LittleEndian.PutUint16(buf[:], additional) - db, _ = l.aead.Open(nil, nonce, ciphertext, buf[:]) + db, _ = aead.Open(nil, nonce, ciphertext, buf[:]) return } diff --git a/gold/link/crypto_test.go b/gold/link/crypto_test.go index 22b22be..0cb46a5 100644 --- a/gold/link/crypto_test.go +++ b/gold/link/crypto_test.go @@ -3,6 +3,8 @@ package link import ( "bytes" "crypto/rand" + "encoding/binary" + "encoding/hex" "io" "testing" @@ -32,20 +34,47 @@ func TestXOR(t *testing.T) { } func TestXChacha20(t *testing.T) { - l := Link{} k := make([]byte, 32) _, err := rand.Read(k) if err != nil { t.Fatal(err) } - l.aead, err = chacha20poly1305.NewX(k) + aead, err := chacha20poly1305.NewX(k) if err != nil { t.Fatal(err) } data := []byte("12345678") for i := uint64(0); i < 100000; i++ { - if !bytes.Equal(l.DecodePreshared(uint16(i), l.EncodePreshared(uint16(i), data)), data) { + if !bytes.Equal(decode(aead, uint16(i), encode(aead, uint16(i), data)), data) { t.Fatal("unexpected preshared at", i, "addt", uint16(i)) } } } + +func TestExpandKeyUnit(t *testing.T) { + k1 := byte(0b10001010) + k2 := byte(0b10111010) // rev 01011101 + v := expandkeyunit(k1, k2) // x1x0x0x0x1x0x1x0 | 0x1x0x1x1x1x0x1x = 0110001011100110 + if v != 0b0110001011100110 { + buf := [2]byte{} + binary.BigEndian.PutUint16(buf[:], v) + t.Fatal(hex.EncodeToString(buf[:])) + } +} + +func TestMixKeys(t *testing.T) { + k1, _ := hex.DecodeString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + k2, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000000") + k := mixkeys(k1, k2) + kexp, _ := hex.DecodeString("55555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555") + if !bytes.Equal(k, kexp) { + t.Fatal(hex.EncodeToString(k)) + } + k1, _ = hex.DecodeString("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + k2, _ = hex.DecodeString("deadbeef1239876540deadbeef1239876540deadbeef1239876540abcdef4567") + k = mixkeys(k1, k2) + kexp, _ = hex.DecodeString("2ca9188d3ebb4a9f22e34d4479d857fca48390253ebbe23f22cbcf6e59507ddc06a9b08794316abfa26b67cedb7a5d542c8912adb493c0352aebe76e73dadf7e") + if !bytes.Equal(k, kexp) { + t.Fatal(hex.EncodeToString(k)) + } +} diff --git a/gold/link/link.go b/gold/link/link.go index 0221bd2..bc0613c 100644 --- a/gold/link/link.go +++ b/gold/link/link.go @@ -4,11 +4,11 @@ import ( "crypto/cipher" "errors" "net" + "sync/atomic" "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/helper" base14 "github.com/fumiama/go-base16384" - tea "github.com/fumiama/gofastTEA" ) // Link 是本机到 peer 的连接抽象 @@ -27,10 +27,8 @@ type Link struct { endpoint *net.UDPAddr // 本机允许接收/发送的 ip 网段 allowedips []*net.IPNet - // 连接所用对称加密密钥 - key []tea.TEA - // 连接所用预共享密钥 - aead cipher.AEAD + // 连接所用对称加密密钥集 + keys [32]cipher.AEAD // 本机信息 me *Me // 连接的状态,详见下方 const @@ -84,3 +82,7 @@ func (l *Link) String() (n string) { } return } + +func (l *Link) incgetsndcnt() uintptr { + return atomic.AddUintptr(&l.sendcount, 1) +} diff --git a/gold/link/listen.go b/gold/link/listen.go index 2d1d0c6..86f7fb3 100644 --- a/gold/link/listen.go +++ b/gold/link/listen.go @@ -2,9 +2,11 @@ package link import ( "bytes" + "errors" "io" "net" "net/netip" + "os" "runtime" "strconv" "sync" @@ -48,7 +50,18 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) { } logrus.Debugln("[listen] lock index", i) lbf := listenbuff[i*65536 : (i+1)*65536] + err = conn.SetDeadline(time.Now().Add(time.Second)) + if err != nil { + logrus.Warnln("[listen] set ddl err:", err) + } n, addr, err := conn.ReadFromUDP(lbf) + if m.loop == nil { + logrus.Warnln("[listen] quit listening") + return + } + if errors.Is(err, os.ErrDeadlineExceeded) { + err = nil + } if err != nil { logrus.Warnln("[listen] read from udp err, reconnect:", err) conn, err = net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(m.udpep.String()))) @@ -102,15 +115,12 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin } switch { case p.IsToMe(packet.Dst): - packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>28), packet.Data) - if p.aead != nil { - addt := packet.AdditionalData() - packet.Data = p.DecodePreshared(addt, packet.Data) - if packet.Data == nil { - logrus.Debugln("[listen] @", index, "drop invalid preshared packet, addt:", addt) - packet.Put() - return - } + addt := packet.AdditionalData() + packet.Data = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data) + if packet.Data == nil { + logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt) + packet.Put() + return } if p.usezstd { dec, _ := zstd.NewReader(bytes.NewReader(packet.Data)) diff --git a/gold/link/me.go b/gold/link/me.go index c89d6fa..86dd7aa 100644 --- a/gold/link/me.go +++ b/gold/link/me.go @@ -121,9 +121,25 @@ func (m *Me) MTU() uint16 { return m.mtu } -func (m *Me) CloseNIC() error { - m.nic.Down() - return m.nic.Close() +func (m *Me) EndPoint() net.Addr { + return m.udpep +} + +func (m *Me) Close() error { + m.loop = nil + m.connections = nil + _ = m.udpconn.Close() + m.udpconn = nil + m.router = nil + m.recving.Destroy() + m.recving = nil + m.recved.Destroy() + m.recved = nil + if m.nic != nil { + m.nic.Down() + return m.nic.Close() + } + return nil } func (m *Me) Write(packet []byte) (n int, err error) { diff --git a/gold/link/peer.go b/gold/link/peer.go index 631c047..98505cb 100644 --- a/gold/link/peer.go +++ b/gold/link/peer.go @@ -6,7 +6,6 @@ import ( "github.com/fumiama/WireGold/gold/head" curve "github.com/fumiama/go-x25519" - tea "github.com/fumiama/gofastTEA" "github.com/sirupsen/logrus" "golang.org/x/crypto/chacha20poly1305" ) @@ -48,21 +47,28 @@ func (m *Me) AddPeer(cfg *PeerConfig) (l *Link) { if !cfg.NoPipe { l.pipe = make(chan *head.Packet, 32) } + var k, p []byte if cfg.PubicKey != nil { - c := curve.Get(m.privKey[:]) - k, err := c.Shared(cfg.PubicKey) - if err == nil { - l.key = make([]tea.TEA, 16) - for i := range l.key { - l.key[i] = tea.NewTeaCipherLittleEndian(k[i : 16+i]) - } - } + k, _ = curve.Get(m.privKey[:]).Shared(cfg.PubicKey) } if cfg.PresharedKey != nil { + p = cfg.PresharedKey[:] + } + if len(k) == 32 { var err error - l.aead, err = chacha20poly1305.NewX(cfg.PresharedKey[:]) - if err != nil { - panic(err) + if len(p) == 32 { + mixk := mixkeys(k, p) + for i := range k { + l.keys[i], err = chacha20poly1305.NewX(mixk[i : i+32]) + if err != nil { + panic(err) + } + } + } else { + l.keys[0], err = chacha20poly1305.NewX(k) + if err != nil { + panic(err) + } } } if cfg.EndPoint != "" { diff --git a/gold/link/router.go b/gold/link/router.go index dc9297a..452152d 100644 --- a/gold/link/router.go +++ b/gold/link/router.go @@ -10,7 +10,7 @@ import ( ) type Router struct { - sync.RWMutex + mu sync.RWMutex // map[cidr]*Link table map[string]*Link list []*net.IPNet @@ -35,10 +35,10 @@ func (l *Link) IsToMe(ip net.IP) bool { // SetDefault 设置默认网关 func (r *Router) SetDefault(l *Link) { defnet := &net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} - r.Lock() + r.mu.Lock() r.list[len(r.list)-1] = defnet r.table[defnet.String()] = l - r.Unlock() + r.mu.Unlock() } // NextHop 得到前往 ip 的下一跳的 link @@ -56,8 +56,8 @@ func (r *Router) NextHop(ip string) (l *Link) { // 遍历 r.table,得到正确的下一跳 // 注意使用 r.mu 读写锁避免竞争 - r.RLock() - defer r.RUnlock() + r.mu.RLock() + defer r.mu.RUnlock() for _, c := range r.list { if c.Contains(ipb) { @@ -75,7 +75,7 @@ func (r *Router) NextHop(ip string) (l *Link) { // SetItem 添加一条表项 func (r *Router) SetItem(ip *net.IPNet, l *Link) { - r.Lock() + r.mu.Lock() // 从第一条表项开始匹配 for i := 0; i < len(r.list); i++ { if r.list[i].Contains(ip.IP) { @@ -94,7 +94,7 @@ func (r *Router) SetItem(ip *net.IPNet, l *Link) { break } } - r.Unlock() + r.mu.Unlock() } func isSubnetBcast(ip net.IP, subnet *net.IPNet) bool { diff --git a/gold/link/send.go b/gold/link/send.go index d10d721..229479c 100644 --- a/gold/link/send.go +++ b/gold/link/send.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "math/rand" - "sync/atomic" "github.com/fumiama/WireGold/gold/head" "github.com/fumiama/WireGold/helper" @@ -18,15 +17,15 @@ import ( // WriteAndPut 向 peer 发包并将包放回缓存池 func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { defer p.Put() - teatype := uint8(rand.Intn(16)) - sndcnt := atomic.AddUintptr(&l.sendcount, 1) + teatype := l.randkeyidx() + sndcnt := uint16(l.incgetsndcnt()) mtu := l.mtu if l.mturandomrange > 0 { mtu -= uint16(rand.Intn(int(l.mturandomrange))) } - logrus.Debugln("[send] mtu:", mtu, ", count:", sndcnt, ", additional data:", uint16(sndcnt)&0x0fff) + logrus.Debugln("[send] mtu:", mtu, ", addt:", uint16(sndcnt)&0x0fff, ", key index:", teatype) if !istransfer { - l.encrypt(p, uint16(sndcnt), teatype) + l.encrypt(p, sndcnt, teatype) } delta := (int(mtu) - 60) & 0x0000fff8 if delta <= 0 { @@ -34,7 +33,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { delta = 8 } if len(p.Data) <= delta { - return l.write(p, teatype, uint16(sndcnt), uint32(len(p.Data)), 0, istransfer, false) + return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false) } if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta { return 0, errors.New("drop don't fragmnet big trans packet") @@ -48,7 +47,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { for ; int(totl)-pos > delta; pos += delta { logrus.Debugln("[send] split frag [", pos, "~", pos+delta, "], remain:", int(totl)-pos-delta) packet.Data = data[:delta] - cnt, err := l.write(packet, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, true) + cnt, err := l.write(packet, teatype, sndcnt, totl, uint16(pos>>3), istransfer, true) n += cnt if err != nil { return n, err @@ -60,7 +59,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) { if len(data) > 0 { p.Data = data cnt := 0 - cnt, err = l.write(p, teatype, uint16(sndcnt), totl, uint16(pos>>3), istransfer, false) + cnt, err = l.write(p, teatype, sndcnt, totl, uint16(pos>>3), istransfer, false) n += cnt } return n, err @@ -78,12 +77,8 @@ func (l *Link) encrypt(p *head.Packet, sndcnt uint16, teatype uint8) { p.Data = w.Bytes() logrus.Debugln("[send] data len after zstd:", len(p.Data)) } - if l.aead != nil { - p.Data = l.EncodePreshared(sndcnt&0x0fff, p.Data) - logrus.Debugln("[send] data len after xchacha20:", len(p.Data)) - } - p.Data = l.Encode(teatype, p.Data) - logrus.Debugln("[send] data len after tea:", len(p.Data)) + p.Data = l.Encode(teatype, sndcnt&0x07ff, p.Data) + logrus.Debugln("[send] data len after xchacha20:", len(p.Data), "addt:", sndcnt) } // write 向 peer 发一个包 diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index a67ccb3..0a4bed1 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -146,6 +146,7 @@ func (s *Tunnel) handleRead() { delete(seqmap, seq) seq++ s.out <- p + continue } p := s.l.Read() if p == nil { diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 486f7f1..780f780 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -16,10 +16,7 @@ import ( "github.com/fumiama/WireGold/helper" ) -func TestTunnel(t *testing.T) { - logrus.SetLevel(logrus.DebugLevel) - logrus.SetFormatter(&logFormat{enableColor: false}) - +func testTunnel(t *testing.T, isplain bool, pshk *[32]byte) { selfpk, err := curve.New(nil) if err != nil { panic(err) @@ -35,34 +32,47 @@ func TestTunnel(t *testing.T) { m := link.NewMe(&link.MyConfig{ MyIPwithMask: "192.168.1.2/32", - MyEndpoint: "127.0.0.1:21246", + MyEndpoint: "127.0.0.1:0", PrivateKey: selfpk.Private(), SrcPort: 1, DstPort: 1, MTU: 4096, }) - m.AddPeer(&link.PeerConfig{ - PeerIP: "192.168.1.3", - EndPoint: "127.0.0.1:21247", - AllowedIPs: []string{"192.168.1.3/32"}, - PubicKey: peerpk.Public(), - MTU: 4096, - MTURandomRange: 1024, - UseZstd: true, - }) + defer m.Close() + p := link.NewMe(&link.MyConfig{ MyIPwithMask: "192.168.1.3/32", - MyEndpoint: "127.0.0.1:21247", + MyEndpoint: "127.0.0.1:0", PrivateKey: peerpk.Private(), SrcPort: 1, DstPort: 1, MTU: 4096, }) + defer p.Close() + + ppp := peerpk.Public() + spp := selfpk.Public() + if isplain { + ppp = nil + spp = nil + } + + m.AddPeer(&link.PeerConfig{ + PeerIP: "192.168.1.3", + EndPoint: p.EndPoint().String(), + AllowedIPs: []string{"192.168.1.3/32"}, + PubicKey: ppp, + PresharedKey: pshk, + MTU: 4096, + MTURandomRange: 1024, + UseZstd: true, + }) p.AddPeer(&link.PeerConfig{ PeerIP: "192.168.1.2", - EndPoint: "127.0.0.1:21246", + EndPoint: m.EndPoint().String(), AllowedIPs: []string{"192.168.1.2/32"}, - PubicKey: selfpk.Public(), + PubicKey: spp, + PresharedKey: pshk, MTU: 4096, MTURandomRange: 1024, UseZstd: true, @@ -121,7 +131,7 @@ func TestTunnel(t *testing.T) { tunnme.Write(sendb) rd := bytes.NewBuffer(nil) - tm := time.AfterFunc(time.Second*5, func() { + tm := time.AfterFunc(time.Second*2, func() { tunnme.Stop() tunnpeer.Stop() }) @@ -136,6 +146,22 @@ func TestTunnel(t *testing.T) { } } +func TestTunnel(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + logrus.SetFormatter(&logFormat{enableColor: false}) + + testTunnel(t, true, nil) // test plain text + + testTunnel(t, false, nil) // test normal + + var buf [32]byte + _, err := rand.Read(buf[:]) + if err != nil { + panic(err) + } + testTunnel(t, false, &buf) // test preshared +} + // logFormat specialize for go-cqhttp type logFormat struct { enableColor bool diff --git a/upper/services/wg/wg.go b/upper/services/wg/wg.go index cc8dba8..4d3e1a1 100644 --- a/upper/services/wg/wg.go +++ b/upper/services/wg/wg.go @@ -61,7 +61,7 @@ func (wg *WG) Run(srcport, destport uint16) { } func (wg *WG) Stop() { - _ = wg.me.CloseNIC() + _ = wg.me.Close() } func (wg *WG) init(srcport, dstport uint16) {