mirror of
https://github.com/fumiama/aes-rsa-tcp-demo.git
synced 2026-06-27 07:20:27 +08:00
init
This commit is contained in:
94
utils/aes.go
Normal file
94
utils/aes.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// NewAESPresharedKey typ AES-128, AES-192, or AES-256.
|
||||
func NewAESPresharedKey(typ string) ([]byte, error) {
|
||||
sz := 0
|
||||
switch typ {
|
||||
case "AES-128":
|
||||
sz = 16
|
||||
case "AES-192":
|
||||
sz = 24
|
||||
case "AES-256":
|
||||
sz = 32
|
||||
}
|
||||
if sz <= 0 {
|
||||
return nil, aes.KeySizeError(sz)
|
||||
}
|
||||
k := make([]byte, sz)
|
||||
_, err := rand.Read(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// EncryptAES ...
|
||||
func EncryptAES(aescipher cipher.Block, data []byte) []byte {
|
||||
blksz := aescipher.BlockSize()
|
||||
total := len(data)
|
||||
n := total / blksz
|
||||
if total%blksz > 0 {
|
||||
n++
|
||||
}
|
||||
encdat := make([]byte, blksz*n)
|
||||
copy(encdat, data)
|
||||
for i := 0; i < n; i++ {
|
||||
a := i * blksz
|
||||
b := (i + 1) * blksz
|
||||
aescipher.Encrypt(encdat[a:b], encdat[a:b])
|
||||
}
|
||||
return encdat
|
||||
}
|
||||
|
||||
// EncryptAESInplace ...
|
||||
func EncryptAESInplace(aescipher cipher.Block, data []byte) error {
|
||||
blksz := aescipher.BlockSize()
|
||||
total := len(data)
|
||||
n := total / blksz
|
||||
if total%blksz > 0 {
|
||||
n++
|
||||
}
|
||||
if len(data) < blksz*n {
|
||||
return errors.New("data is too short")
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
a := i * blksz
|
||||
b := (i + 1) * blksz
|
||||
aescipher.Encrypt(data[a:b], data[a:b])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptAES ...
|
||||
func DecryptAES(aescipher cipher.Block, data []byte) []byte {
|
||||
blksz := aescipher.BlockSize()
|
||||
total := len(data)
|
||||
n := total / blksz
|
||||
decdat := make([]byte, blksz*n)
|
||||
for i := 0; i < n; i++ {
|
||||
a := i * blksz
|
||||
b := (i + 1) * blksz
|
||||
aescipher.Decrypt(decdat[a:b], data[a:b])
|
||||
}
|
||||
return decdat
|
||||
}
|
||||
|
||||
// DecryptAESInplace ...
|
||||
func DecryptAESInplace(aescipher cipher.Block, data []byte) error {
|
||||
blksz := aescipher.BlockSize()
|
||||
total := len(data)
|
||||
n := total / blksz
|
||||
for i := 0; i < n; i++ {
|
||||
a := i * blksz
|
||||
b := (i + 1) * blksz
|
||||
aescipher.Decrypt(data[a:b], data[a:b])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
60
utils/aes_test.go
Normal file
60
utils/aes_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAES(t *testing.T) {
|
||||
var buf [32]byte
|
||||
var data [123]byte
|
||||
_, err := rand.Read(buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = rand.Read(data[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aescipher, err := aes.NewCipher(buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encdat := EncryptAES(aescipher, data[:])
|
||||
decdat := DecryptAES(aescipher, encdat)[:123]
|
||||
if !bytes.Equal(data[:], decdat) {
|
||||
t.Fatal("expected " + hex.EncodeToString(data[:]) + " but got " + hex.EncodeToString(decdat))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESInplace(t *testing.T) {
|
||||
var buf [32]byte
|
||||
var data [128]byte
|
||||
_, err := rand.Read(buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = rand.Read(data[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aescipher, err := aes.NewCipher(buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
org := data
|
||||
err = EncryptAESInplace(aescipher, data[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = DecryptAESInplace(aescipher, data[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data[:], org[:]) {
|
||||
t.Fatal("expected " + hex.EncodeToString(org[:]) + " but got " + hex.EncodeToString(data[:]))
|
||||
}
|
||||
}
|
||||
77
utils/packet.go
Normal file
77
utils/packet.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
)
|
||||
|
||||
type PacketType uint8
|
||||
|
||||
const (
|
||||
PacketTypeInit PacketType = iota // PacketTypeInit pass RSA pubkey by AES pre-shared key
|
||||
PacketTypeComm // PacketTypeComm normal communication
|
||||
PacketTypeTop // PacketTypeTop for valid checking
|
||||
)
|
||||
|
||||
// Packet is the communicating proxy
|
||||
type Packet struct {
|
||||
Len uint16 // Len packet length LE
|
||||
Typ PacketType // Typ packet type
|
||||
Dat []byte // Dat payload
|
||||
}
|
||||
|
||||
// ParsePacket from bytes
|
||||
func ParsePacket(d []byte) (p Packet, err error) {
|
||||
l := binary.LittleEndian.Uint16(d[:2])
|
||||
if 2+int(l) != len(d) {
|
||||
err = errors.New("invalid packet len")
|
||||
return
|
||||
}
|
||||
p.Len = l
|
||||
p.Typ = PacketType(d[2])
|
||||
p.Dat = d[3:]
|
||||
return
|
||||
}
|
||||
|
||||
// ReadPacket from io.Reader
|
||||
func ReadPacket(r io.Reader) (p Packet, err error) {
|
||||
var buf [2]byte
|
||||
_, err = io.ReadFull(r, buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
l := binary.LittleEndian.Uint16(buf[:])
|
||||
data := make([]byte, l)
|
||||
_, err = io.ReadFull(r, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p.Len = l
|
||||
p.Typ = PacketType(data[0])
|
||||
p.Dat = data[1:]
|
||||
return
|
||||
}
|
||||
|
||||
// ToBytes marshal packet into bytes
|
||||
func (p *Packet) ToBytes() ([]byte, error) {
|
||||
l := 1 + len(p.Dat)
|
||||
if l > math.MaxUint16 {
|
||||
return nil, errors.New("packet data too large")
|
||||
}
|
||||
if p.Typ >= PacketTypeTop {
|
||||
return nil, errors.New("invalid packet tpye")
|
||||
}
|
||||
p.Len = uint16(l)
|
||||
return p.MustToBytes(), nil
|
||||
}
|
||||
|
||||
// MustToBytes don't do any check
|
||||
func (p *Packet) MustToBytes() []byte {
|
||||
data := make([]byte, 2+1+len(p.Dat))
|
||||
binary.LittleEndian.PutUint16(data[:2], p.Len)
|
||||
data[2] = byte(p.Typ)
|
||||
copy(data[3:], p.Dat)
|
||||
return data
|
||||
}
|
||||
79
utils/packet_init.go
Normal file
79
utils/packet_init.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/crc64"
|
||||
)
|
||||
|
||||
type PacketInitType uint8
|
||||
|
||||
const (
|
||||
PacketInitTypeReq PacketInitType = iota // PacketInitTypeReq request RSA pubkey (by client)
|
||||
PacketInitTypeAck // PacketInitTypeAck give the key
|
||||
PacketInitTypeTop
|
||||
)
|
||||
|
||||
// NewPacketInit x509rsapubkey = nil for req
|
||||
func NewPacketInit(aescipher cipher.Block, x509rsapubkey []byte) ([]byte, error) {
|
||||
if len(x509rsapubkey) == 0 {
|
||||
return []byte{byte(PacketInitTypeReq)}, nil
|
||||
}
|
||||
blksz := aescipher.BlockSize()
|
||||
total := 2 + len(x509rsapubkey) + 8
|
||||
n := total / blksz
|
||||
if total%blksz > 0 {
|
||||
n++
|
||||
}
|
||||
data := make([]byte, 1+blksz*n)
|
||||
data[0] = byte(PacketInitTypeAck)
|
||||
encdat := data[1:]
|
||||
binary.LittleEndian.PutUint16(encdat[:2], uint16(len(x509rsapubkey)))
|
||||
h := crc64.New(crc64.MakeTable(crc64.ECMA))
|
||||
_, err := h.Write(x509rsapubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = h.Sum(encdat[2 : 2 : 2+8])
|
||||
copy(encdat[2+8:], x509rsapubkey)
|
||||
err = EncryptAESInplace(aescipher, encdat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ParsePacketInit parse a init packet
|
||||
func ParsePacketInit(aescipher cipher.Block, d []byte) (x509rsapubkey []byte, err error) {
|
||||
if len(d) == 0 {
|
||||
err = errors.New("invalid init packet length")
|
||||
return
|
||||
}
|
||||
if d[0] >= byte(PacketInitTypeTop) {
|
||||
err = errors.New("invalid init packet type")
|
||||
return
|
||||
}
|
||||
if d[0] == byte(PacketInitTypeReq) {
|
||||
return
|
||||
}
|
||||
data := DecryptAES(aescipher, d[1:])
|
||||
klen := binary.LittleEndian.Uint16(data[:2])
|
||||
if int(klen) > len(data[10:]) {
|
||||
err = errors.New("invalid init packet data length")
|
||||
return
|
||||
}
|
||||
x509rsapubkey = data[10 : 10+klen]
|
||||
h := crc64.New(crc64.MakeTable(crc64.ECMA))
|
||||
_, err = h.Write(x509rsapubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var buf [8]byte
|
||||
if !bytes.Equal(data[2:2+8], h.Sum(buf[:0])) {
|
||||
err = errors.New("invalid init packet data")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
49
utils/packet_init_test.go
Normal file
49
utils/packet_init_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPacketInit(t *testing.T) {
|
||||
var buf [32]byte
|
||||
_, err := rand.Read(buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
k, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
aescipher, err := aes.NewCipher(buf[:])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rsak := x509.MarshalPKCS1PublicKey(&k.PublicKey)
|
||||
data, err := NewPacketInit(aescipher, rsak)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedk, err := ParsePacketInit(aescipher, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(rsak, parsedk) {
|
||||
t.Fatal("unexpected 1")
|
||||
}
|
||||
data, err = NewPacketInit(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedk, err = ParsePacketInit(aescipher, data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(parsedk) > 0 {
|
||||
t.Fatal("unexpected 2")
|
||||
}
|
||||
}
|
||||
56
utils/rsa.go
Normal file
56
utils/rsa.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/crc64"
|
||||
"math"
|
||||
_ "unsafe"
|
||||
)
|
||||
|
||||
//go:linkname encrypt crypto/rsa.encrypt
|
||||
func encrypt(pub *rsa.PublicKey, plaintext []byte) ([]byte, error)
|
||||
|
||||
//go:linkname decrypt crypto/rsa.decrypt
|
||||
func decrypt(priv *rsa.PrivateKey, ciphertext []byte, check bool) ([]byte, error)
|
||||
|
||||
// RSAPrivateKeyEncrypt use the method generally in sign
|
||||
func RSAPrivateKeyEncrypt(priv *rsa.PrivateKey, plaintext []byte) ([]byte, error) {
|
||||
if len(plaintext) > math.MaxUint16 {
|
||||
return nil, errors.New("plaintext too large")
|
||||
}
|
||||
h := crc64.New(crc64.MakeTable(crc64.ECMA))
|
||||
_, err := h.Write(plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]byte, len(plaintext)+8+2)
|
||||
n := copy(data[:], plaintext)
|
||||
binary.LittleEndian.PutUint64(data[n:n+8], h.Sum64())
|
||||
binary.LittleEndian.PutUint16(data[n+8:n+8+2], uint16(len(plaintext)))
|
||||
return decrypt(priv, data, false)
|
||||
}
|
||||
|
||||
// RSAPublicKeyDecrypt use the method generally in sign
|
||||
func RSAPublicKeyDecrypt(pub *rsa.PublicKey, ciphertext []byte) ([]byte, error) {
|
||||
data, err := encrypt(pub, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint16(data[len(data)-2:])
|
||||
p := len(data) - int(n) - 8 - 2
|
||||
if p < 0 || p > len(data) {
|
||||
return nil, errors.New("invalid ciphertext length")
|
||||
}
|
||||
data = data[p:]
|
||||
h := crc64.New(crc64.MakeTable(crc64.ECMA))
|
||||
_, err = h.Write(data[:len(data)-8-2])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if h.Sum64() != binary.LittleEndian.Uint64(data[len(data)-8-2:len(data)-2]) {
|
||||
return nil, errors.New("invalid ciphertext")
|
||||
}
|
||||
return data[:len(data)-8-2], nil
|
||||
}
|
||||
27
utils/rsa_test.go
Normal file
27
utils/rsa_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRSA(t *testing.T) {
|
||||
testtext := []byte("test RSAPrivateKeyEncrypt")
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
enc, err := RSAPrivateKeyEncrypt(priv, testtext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dec, err := RSAPublicKeyDecrypt(&priv.PublicKey, enc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(testtext, dec) {
|
||||
t.Fatal("expected", string(testtext), "but got", string(dec))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user