diff --git a/gold/link/crypto.go b/gold/link/crypto.go index 2dcbe5e..070faf9 100644 --- a/gold/link/crypto.go +++ b/gold/link/crypto.go @@ -84,19 +84,20 @@ func (l *Link) Decode(teatype uint8, additional uint16, b []byte) (db []byte, er } // encode 使用 xchacha20poly1305 加密 -func encode(aead cipher.AEAD, additional uint16, b []byte) (eb []byte) { +func encode(aead cipher.AEAD, additional uint16, b []byte) []byte { nsz := aead.NonceSize() - // Select a random nonce, and leave capacity for the ciphertext. - nonce := make([]byte, nsz, nsz+len(b)+aead.Overhead()) + // Accocate capacity for all the stuffs. + buf := make([]byte, 2+nsz+len(b)+aead.Overhead()) + binary.LittleEndian.PutUint16(buf[:2], additional) + nonce := buf[2 : 2+nsz] + // Select a random nonce _, err := rand.Read(nonce) if err != nil { - return + panic(err) } // Encrypt the message and append the ciphertext to the nonce. - var buf [2]byte - binary.LittleEndian.PutUint16(buf[:], additional) - eb = aead.Seal(nonce, nonce, b, buf[:]) - return + eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf[:2]) + return nonce[:nsz+len(eb)] } // decode 使用 xchacha20poly1305 解密 @@ -107,6 +108,9 @@ func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { } // Split nonce and ciphertext. nonce, ciphertext := b[:nsz], b[nsz:] + if len(ciphertext) == 0 { + return nil, nil + } // Decrypt the message and check it wasn't tampered with. var buf [2]byte binary.LittleEndian.PutUint16(buf[:], additional) diff --git a/gold/link/crypto_test.go b/gold/link/crypto_test.go index c3b45a8..4287f0d 100644 --- a/gold/link/crypto_test.go +++ b/gold/link/crypto_test.go @@ -43,14 +43,18 @@ func TestXChacha20(t *testing.T) { if err != nil { t.Fatal(err) } - data := []byte("12345678") - for i := uint64(0); i < 100000; i++ { - db, err := decode(aead, uint16(i), encode(aead, uint16(i), data)) + data := make([]byte, 4096) + _, err = rand.Read(data) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 4096; i++ { + db, err := decode(aead, uint16(i), encode(aead, uint16(i), data[:i])) if err != nil { t.Fatal(err) } - if !bytes.Equal(db, data) { - t.Fatal("unexpected preshared at", i, "addt", uint16(i)) + if !bytes.Equal(db, data[:i]) { + t.Fatal("unexpected preshared at idx(len)", i, "addt", uint16(i)) } } }