1
0
mirror of https://github.com/fumiama/aes-rsa-tcp-demo.git synced 2026-06-27 07:20:27 +08:00
This commit is contained in:
源文雨
2023-12-29 16:29:41 +09:00
commit f79b275f42
13 changed files with 827 additions and 0 deletions

94
utils/aes.go Normal file
View File

@@ -0,0 +1,94 @@
package utils
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
)
// NewAESPresharedKey typ AES-128, AES-192, or AES-256.
func NewAESPresharedKey(typ string) ([]byte, error) {
sz := 0
switch typ {
case "AES-128":
sz = 16
case "AES-192":
sz = 24
case "AES-256":
sz = 32
}
if sz <= 0 {
return nil, aes.KeySizeError(sz)
}
k := make([]byte, sz)
_, err := rand.Read(k)
if err != nil {
return nil, err
}
return k, nil
}
// EncryptAES ...
func EncryptAES(aescipher cipher.Block, data []byte) []byte {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
if total%blksz > 0 {
n++
}
encdat := make([]byte, blksz*n)
copy(encdat, data)
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Encrypt(encdat[a:b], encdat[a:b])
}
return encdat
}
// EncryptAESInplace ...
func EncryptAESInplace(aescipher cipher.Block, data []byte) error {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
if total%blksz > 0 {
n++
}
if len(data) < blksz*n {
return errors.New("data is too short")
}
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Encrypt(data[a:b], data[a:b])
}
return nil
}
// DecryptAES ...
func DecryptAES(aescipher cipher.Block, data []byte) []byte {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
decdat := make([]byte, blksz*n)
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Decrypt(decdat[a:b], data[a:b])
}
return decdat
}
// DecryptAESInplace ...
func DecryptAESInplace(aescipher cipher.Block, data []byte) error {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Decrypt(data[a:b], data[a:b])
}
return nil
}

60
utils/aes_test.go Normal file
View File

@@ -0,0 +1,60 @@
package utils
import (
"bytes"
"crypto/aes"
"crypto/rand"
"encoding/hex"
"testing"
)
func TestAES(t *testing.T) {
var buf [32]byte
var data [123]byte
_, err := rand.Read(buf[:])
if err != nil {
t.Fatal(err)
}
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
aescipher, err := aes.NewCipher(buf[:])
if err != nil {
t.Fatal(err)
}
encdat := EncryptAES(aescipher, data[:])
decdat := DecryptAES(aescipher, encdat)[:123]
if !bytes.Equal(data[:], decdat) {
t.Fatal("expected " + hex.EncodeToString(data[:]) + " but got " + hex.EncodeToString(decdat))
}
}
func TestAESInplace(t *testing.T) {
var buf [32]byte
var data [128]byte
_, err := rand.Read(buf[:])
if err != nil {
t.Fatal(err)
}
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
aescipher, err := aes.NewCipher(buf[:])
if err != nil {
t.Fatal(err)
}
org := data
err = EncryptAESInplace(aescipher, data[:])
if err != nil {
t.Fatal(err)
}
err = DecryptAESInplace(aescipher, data[:])
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data[:], org[:]) {
t.Fatal("expected " + hex.EncodeToString(org[:]) + " but got " + hex.EncodeToString(data[:]))
}
}

77
utils/packet.go Normal file
View File

@@ -0,0 +1,77 @@
package utils
import (
"encoding/binary"
"errors"
"io"
"math"
)
type PacketType uint8
const (
PacketTypeInit PacketType = iota // PacketTypeInit pass RSA pubkey by AES pre-shared key
PacketTypeComm // PacketTypeComm normal communication
PacketTypeTop // PacketTypeTop for valid checking
)
// Packet is the communicating proxy
type Packet struct {
Len uint16 // Len packet length LE
Typ PacketType // Typ packet type
Dat []byte // Dat payload
}
// ParsePacket from bytes
func ParsePacket(d []byte) (p Packet, err error) {
l := binary.LittleEndian.Uint16(d[:2])
if 2+int(l) != len(d) {
err = errors.New("invalid packet len")
return
}
p.Len = l
p.Typ = PacketType(d[2])
p.Dat = d[3:]
return
}
// ReadPacket from io.Reader
func ReadPacket(r io.Reader) (p Packet, err error) {
var buf [2]byte
_, err = io.ReadFull(r, buf[:])
if err != nil {
return
}
l := binary.LittleEndian.Uint16(buf[:])
data := make([]byte, l)
_, err = io.ReadFull(r, data)
if err != nil {
return
}
p.Len = l
p.Typ = PacketType(data[0])
p.Dat = data[1:]
return
}
// ToBytes marshal packet into bytes
func (p *Packet) ToBytes() ([]byte, error) {
l := 1 + len(p.Dat)
if l > math.MaxUint16 {
return nil, errors.New("packet data too large")
}
if p.Typ >= PacketTypeTop {
return nil, errors.New("invalid packet tpye")
}
p.Len = uint16(l)
return p.MustToBytes(), nil
}
// MustToBytes don't do any check
func (p *Packet) MustToBytes() []byte {
data := make([]byte, 2+1+len(p.Dat))
binary.LittleEndian.PutUint16(data[:2], p.Len)
data[2] = byte(p.Typ)
copy(data[3:], p.Dat)
return data
}

79
utils/packet_init.go Normal file
View File

@@ -0,0 +1,79 @@
package utils
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"hash/crc64"
)
type PacketInitType uint8
const (
PacketInitTypeReq PacketInitType = iota // PacketInitTypeReq request RSA pubkey (by client)
PacketInitTypeAck // PacketInitTypeAck give the key
PacketInitTypeTop
)
// NewPacketInit x509rsapubkey = nil for req
func NewPacketInit(aescipher cipher.Block, x509rsapubkey []byte) ([]byte, error) {
if len(x509rsapubkey) == 0 {
return []byte{byte(PacketInitTypeReq)}, nil
}
blksz := aescipher.BlockSize()
total := 2 + len(x509rsapubkey) + 8
n := total / blksz
if total%blksz > 0 {
n++
}
data := make([]byte, 1+blksz*n)
data[0] = byte(PacketInitTypeAck)
encdat := data[1:]
binary.LittleEndian.PutUint16(encdat[:2], uint16(len(x509rsapubkey)))
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err := h.Write(x509rsapubkey)
if err != nil {
return nil, err
}
_ = h.Sum(encdat[2 : 2 : 2+8])
copy(encdat[2+8:], x509rsapubkey)
err = EncryptAESInplace(aescipher, encdat)
if err != nil {
return nil, err
}
return data, nil
}
// ParsePacketInit parse a init packet
func ParsePacketInit(aescipher cipher.Block, d []byte) (x509rsapubkey []byte, err error) {
if len(d) == 0 {
err = errors.New("invalid init packet length")
return
}
if d[0] >= byte(PacketInitTypeTop) {
err = errors.New("invalid init packet type")
return
}
if d[0] == byte(PacketInitTypeReq) {
return
}
data := DecryptAES(aescipher, d[1:])
klen := binary.LittleEndian.Uint16(data[:2])
if int(klen) > len(data[10:]) {
err = errors.New("invalid init packet data length")
return
}
x509rsapubkey = data[10 : 10+klen]
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err = h.Write(x509rsapubkey)
if err != nil {
return nil, err
}
var buf [8]byte
if !bytes.Equal(data[2:2+8], h.Sum(buf[:0])) {
err = errors.New("invalid init packet data")
return
}
return
}

49
utils/packet_init_test.go Normal file
View File

@@ -0,0 +1,49 @@
package utils
import (
"bytes"
"crypto/aes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"testing"
)
func TestPacketInit(t *testing.T) {
var buf [32]byte
_, err := rand.Read(buf[:])
if err != nil {
t.Fatal(err)
}
k, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
t.Fatal(err)
}
aescipher, err := aes.NewCipher(buf[:])
if err != nil {
t.Fatal(err)
}
rsak := x509.MarshalPKCS1PublicKey(&k.PublicKey)
data, err := NewPacketInit(aescipher, rsak)
if err != nil {
t.Fatal(err)
}
parsedk, err := ParsePacketInit(aescipher, data)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(rsak, parsedk) {
t.Fatal("unexpected 1")
}
data, err = NewPacketInit(nil, nil)
if err != nil {
t.Fatal(err)
}
parsedk, err = ParsePacketInit(aescipher, data)
if err != nil {
t.Fatal(err)
}
if len(parsedk) > 0 {
t.Fatal("unexpected 2")
}
}

56
utils/rsa.go Normal file
View File

@@ -0,0 +1,56 @@
package utils
import (
"crypto/rsa"
"encoding/binary"
"errors"
"hash/crc64"
"math"
_ "unsafe"
)
//go:linkname encrypt crypto/rsa.encrypt
func encrypt(pub *rsa.PublicKey, plaintext []byte) ([]byte, error)
//go:linkname decrypt crypto/rsa.decrypt
func decrypt(priv *rsa.PrivateKey, ciphertext []byte, check bool) ([]byte, error)
// RSAPrivateKeyEncrypt use the method generally in sign
func RSAPrivateKeyEncrypt(priv *rsa.PrivateKey, plaintext []byte) ([]byte, error) {
if len(plaintext) > math.MaxUint16 {
return nil, errors.New("plaintext too large")
}
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err := h.Write(plaintext)
if err != nil {
return nil, err
}
data := make([]byte, len(plaintext)+8+2)
n := copy(data[:], plaintext)
binary.LittleEndian.PutUint64(data[n:n+8], h.Sum64())
binary.LittleEndian.PutUint16(data[n+8:n+8+2], uint16(len(plaintext)))
return decrypt(priv, data, false)
}
// RSAPublicKeyDecrypt use the method generally in sign
func RSAPublicKeyDecrypt(pub *rsa.PublicKey, ciphertext []byte) ([]byte, error) {
data, err := encrypt(pub, ciphertext)
if err != nil {
return nil, err
}
n := binary.LittleEndian.Uint16(data[len(data)-2:])
p := len(data) - int(n) - 8 - 2
if p < 0 || p > len(data) {
return nil, errors.New("invalid ciphertext length")
}
data = data[p:]
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err = h.Write(data[:len(data)-8-2])
if err != nil {
return nil, err
}
if h.Sum64() != binary.LittleEndian.Uint64(data[len(data)-8-2:len(data)-2]) {
return nil, errors.New("invalid ciphertext")
}
return data[:len(data)-8-2], nil
}

27
utils/rsa_test.go Normal file
View File

@@ -0,0 +1,27 @@
package utils
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"testing"
)
func TestRSA(t *testing.T) {
testtext := []byte("test RSAPrivateKeyEncrypt")
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
enc, err := RSAPrivateKeyEncrypt(priv, testtext)
if err != nil {
t.Fatal(err)
}
dec, err := RSAPublicKeyDecrypt(&priv.PublicKey, enc)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(testtext, dec) {
t.Fatal("expected", string(testtext), "but got", string(dec))
}
}