1
0
mirror of https://github.com/fumiama/aes-rsa-tcp-demo.git synced 2026-06-05 01:20:24 +08:00
This commit is contained in:
源文雨
2023-12-29 16:29:41 +09:00
commit f79b275f42
13 changed files with 827 additions and 0 deletions

63
README.md Normal file
View File

@@ -0,0 +1,63 @@
# C/S demo with RSA and AES Algorithm
This program is my homework for the IoT course, demonstrating the use of symmetric encryption AES and asymmetric encryption RSA.
## Communication Process
### Server
1. Create a socket and listen for TCP connections.
2. Upon receiving a connection, perform corresponding operations based on the client's request:
- `utils.PacketTypeInit`: Encrypt its own RSA public key with the pre-shared AES key and send it to the client.
- `utils.PacketTypeComm`: Decrypt and print the message sent by the client using `RSA-OAEP` with the private key, then encrypt its replyment with a **self-made** signature algorithm using the RSA private key and send it.
3. When the client actively disconnects, one processing ends.
### Client
1. Establish a TCP connection to the server address.
2. Send a `utils.PacketTypeInit` packet to the server to request the RSA key.
3. Continuously send `Hello` to the server using `EncryptOAEP` while decrypting and printing the message sent by server with its public key using the **self-made** signature algorithm.
4. When the user manually terminates, close the connection to the server.
## Packet Protocol
The overall encapsulation is defined in [utils/packet.go](utils/packet.go).
```
0 15 23
┌─────────┬────┬──────────────┐
│ len │type│ ... data ... │
└─────────┴────┴──────────────┘
```
### len
The length of the packet without itself.
### type
Defined in [utils/packet.go](utils/packet.go).
```go
const (
PacketTypeInit PacketType = iota // PacketTypeInit pass RSA pubkey by AES pre-shared key
PacketTypeComm // PacketTypeComm normal communication
PacketTypeTop // PacketTypeTop for valid checking
)
```
### data
The payload, whose type is described by the `type` field.
#### PacketTypeInit
Defined in [utils/packet_init.go](utils/packet_init.go).
```
0 7 23 87
┌────┬────────┬─────────────────┬──────────────────────┐
│type│ length │ pub key crc64 │ x509rsapubkey │
└────┴────────┴─────────────────┴──────────────────────┘
```
- **type**: `PacketInitTypeReq` or `PacketInitTypeAck`
- **length**: length of `x509rsapubkey`
#### PacketTypeComm
The whole data field is encrypted by RSA and can fill in with any data, which is plain text in this demo.
## Interesting Points
### The Implementation of [Base16384](https://github.com/fumiama/base16384)
Base16384 is a base64-like algorithm designed by me. It can encode binary file to printable utf16be, and vice versa.
In this demo, the [RSA Private Key](rsa_2048_private_x509.b14) and AES key is saved and passed by base16384 format.
### The Usage of Raw RSA Encrypting Method
In the file [utils/rsa.go](utils/rsa.go), I use `go:linkname` to hook the private function of `crypto/rsa` library and realized a **self-made** signature algorithm that can get the decoding result but not just verify whether it is valid (unlike the official method `rsa.VerifyPKCS1v15`).
## Demo
See the video below.

119
client/main.go Normal file
View File

@@ -0,0 +1,119 @@
package main
import (
"crypto/aes"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"flag"
"fmt"
"net"
"net/netip"
"socket/utils"
"time"
base14 "github.com/fumiama/go-base16384"
)
func main() {
server := flag.String("s", "0.0.0.0:12345", "server host:port")
aespsh := flag.String("aes", "矲晊泞柯碍嘖另蚡蔼帀㴂", "the AES preshard key of base16384 format")
flag.Parse()
if *aespsh == "" {
fmt.Println("must give parameter -aes")
return
}
aescipher, err := aes.NewCipher(base14.DecodeFromString(*aespsh))
if err != nil {
fmt.Println("parse AES preshared key error:", err)
return
}
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*server)))
if err != nil {
fmt.Println("Connect to server error:", err)
return
}
defer conn.Close()
fmt.Println("Ronnected to server", conn.RemoteAddr())
// request rsa pubkey
packet := utils.Packet{}
packet.Typ = utils.PacketTypeInit
packet.Dat, err = utils.NewPacketInit(nil, nil)
if err != nil {
fmt.Println("Wrap init packet err:", err)
return
}
data, err := packet.ToBytes()
if err != nil {
fmt.Println("Wrap packet err:", err)
return
}
_, err = conn.Write(data)
if err != nil {
fmt.Println("Write to server", conn.RemoteAddr(), "error:", err)
return
}
packet, err = utils.ReadPacket(conn)
if err != nil {
fmt.Println("Read packet from server", conn.RemoteAddr(), "error:", err)
return
}
if packet.Typ != utils.PacketTypeInit {
fmt.Println("Unexpected packet type from server")
return
}
x509rsapubkey, err := utils.ParsePacketInit(aescipher, packet.Dat)
if err != nil {
fmt.Println("Parse init packet error:", err)
return
}
rsapubk, err := x509.ParsePKCS1PublicKey(x509rsapubkey)
if err != nil {
fmt.Println("Parse x509rsapubkey error:", err)
return
}
fmt.Println("Get x509rsapubkey successfully")
t := time.NewTicker(time.Second)
defer t.Stop()
count := 0
for range t.C {
count++
packet.Typ = utils.PacketTypeComm
data, err = rsa.EncryptOAEP(
md5.New(), rand.Reader, rsapubk,
base14.StringToBytes(fmt.Sprintf("Hello! This is my No.%d communication.", count)),
nil,
)
if err != nil {
fmt.Println("EncryptOAEP err:", err)
return
}
packet.Dat = data
data, err := packet.ToBytes()
if err != nil {
fmt.Println("Wrap packet err:", err)
return
}
_, err = conn.Write(data)
if err != nil {
fmt.Println("Write to server", conn.RemoteAddr(), "error:", err)
return
}
packet, err = utils.ReadPacket(conn)
if err != nil {
fmt.Println("Read packet from server", conn.RemoteAddr(), "error:", err)
return
}
if err != nil {
fmt.Println("Receive from server", conn.RemoteAddr(), "error:", err)
continue
}
data, err = utils.RSAPublicKeyDecrypt(rsapubk, packet.Dat)
if err != nil {
fmt.Println("RSAPublicKeyDecrypt from server", conn.RemoteAddr(), "error:", err)
continue
}
fmt.Println("Receive from server", conn.RemoteAddr(), ":", base14.BytesToString(data))
}
}

7
go.mod Normal file
View File

@@ -0,0 +1,7 @@
module socket
go 1.17
require github.com/fumiama/go-base16384 v1.7.0
require golang.org/x/text v0.3.7 // indirect

15
go.sum Normal file
View File

@@ -0,0 +1,15 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fumiama/go-base16384 v1.7.0 h1:6fep7XPQWxRlh4Hu+KsdH+6+YdUp+w6CwRXtMWSsXCA=
github.com/fumiama/go-base16384 v1.7.0/go.mod h1:OEn+947GV5gsbTAnyuUW/SrfxJYUdYupSIQXOuGOcXM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1 @@
娠湊戈伀亠渐刂裭甯諊劜帱偒傡睔愲撗厼豚樥瞛茪囧蟑圞憟剔堬吁穲箂讄妒拯洍吡寰忡单汪皮寁藉襂竒砹帐嵗诎笵哠茝傉愇奣賆繽帹罔团聥艥廭矑灱炁璨瓌稲剽汄侫斌漺穣灈廞姅换坣諒讻蚜势詈妙萁襴敇蒌狫嵞堨櫣苫姜盾洽暄繕蚩櫞蚚病絼恍灪厴涪丵洔澋谛帥樋变傛撮仺媹趥层窠枚瀅侂挃痁吐蜧甦寶蚊捾眐杝笜慤悀赧缭祘嶷蛜膉歅呸慏瘘灤蘺禓玝縠娄丁亠渐刂僑牵硏攼矄篤膦綵嘟丈仍脶罿盹讫崿弦楶惑悰诏尃揤統夥泘脉哊璜膬恴夠苨汼筹卧咫璘勏徒擼旾厣撆簂噯嗱灄呔繇砊挋枒坭岛国摂嗨痲嵻羂朸浖茙謝璌旹弤蛫烵壣务瀕懘掠怩屆脕焁嫅詇肷洽皈啃晻摹墆竒杰牨諝卛罣猟茂嫔匳岱诩受潙芲誋皪臜徑倁慘庿擈唅廩綵仛簟蒱盍視喖瓡历僅云妙庂兘渷乕御咍职碟蚕濭綶竧瀎寋楌吴嚾苄謷杶赒帨各仧幵噼搯豐竷珳敇暝批豛曂囔穚獙厵舫懱椣跴滭瑼妃臱沚蝣屗缼氌燍褿眎濓翡堬蟾敜敟蓭诮挲牮蒮裚窱摯拉倱搏俀湫炕民喋因貂掳眴捡圫団嵐秶巹褬寘穐覆讐俖袒差矞蓀瘘刃峭璯昏耚罖桅蚑撂犺拨嚋卬蝉緭囊茔牥忞獘呏箷瑺譡翗帠众抶浢緍楼咨暏捽啴茜暛揓羜诠瀼懜哜岫纅詝煑臥傳婲塟挼廎社刦勥瀻諨蠥檍腓瀈挧纞趝嶥媔趮穝举傁彺旇膕亠昁絠葑薯厷峘犙愋硲獈桜諔汫洵滣悫瑢掳瑽覻喉名岓稛偪緧玼嗾瞬傉濯疮蜿瑻荞橁奞啸舠芜殾昧樊坄渻穫萋偀囌云苟专眐巢漞絳嵑班嗥搌嗯蓕覚護創垲櫌蛗孎猦瞄筆蔓璼傁湀圇嚐葬篶痢漞秺婜祠碬毀腫捄绫栙衒淪嘑啍娺穦嘞訬變嵶岈癍薺諺肌菓螕荔紕砢俉勍欰嫍腐褲挺聿衖衪煝栈剸嫜杍旽何寁僌櫞櫲臉瀣焦叨愳氕袏幏刬灩欭垭猄藹嚗絚胁刊侁临怺墖拏燕趝誤譙惫娎犄叚卄帩私蜠蘀娪紨搎咇佪熸蒉蜶淎刼蜟哆捧株溄祭婓坐莛嚐嵝碳睒蔯僕栣芪厵皴袤覵祭朰潭幄櫸藶蔮萄賁簉併苺窩暽潒畩勓耆栀嗅怳曩赎竨痦穠㴃

180
server/main.go Normal file
View File

@@ -0,0 +1,180 @@
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"flag"
"fmt"
"io"
"net"
"net/netip"
"os"
"socket/utils"
"strconv"
"sync/atomic"
base14 "github.com/fumiama/go-base16384"
)
var countnum = uintptr(0)
func main() {
listen := flag.String("l", "0.0.0.0:12345", "listening host:port")
genaespsh := flag.String("gaes", "", "generate a new AES preshard key (AES-128, AES-192, or AES-256)")
genrsakey := flag.Uint("grsa", 0, "generate a new RSA key pair of bits and save it to rsa_bits_private_x509.b14")
rsapkfile := flag.String("rsaf", "rsa_2048_private_x509.b14", "specify the path of the RSA private key")
aespsh := flag.String("aes", "矲晊泞柯碍嘖另蚡蔼帀㴂", "the AES preshard key of base16384 format")
flag.Parse()
if *genaespsh != "" {
k, err := utils.NewAESPresharedKey(*genaespsh)
if err != nil {
fmt.Println("Generate new AES preshard key error:", err)
return
}
fmt.Println("The generated AES preshared key is:", base14.EncodeToString(k))
return
}
if *genrsakey != 0 {
k, err := rsa.GenerateKey(rand.Reader, int(*genrsakey))
if err != nil {
fmt.Println("Generate new RSA key pair error:", err)
return
}
f, err := os.Create(fmt.Sprintf("rsa_%d_private_x509.b14", *genrsakey))
if err != nil {
fmt.Println("Create new RSA private key error:", err)
return
}
defer f.Close()
_, err = f.WriteString(base14.EncodeToString(x509.MarshalPKCS1PrivateKey(k)))
if err != nil {
fmt.Println("Save new RSA private key error:", err)
return
}
fmt.Println("The generated RSA key pair is:")
fmt.Println("Private key: saved into", f.Name())
fmt.Println("Public key:", base14.EncodeToString(x509.MarshalPKCS1PublicKey(&k.PublicKey)))
return
}
rsapkeydata, err := os.ReadFile(*rsapkfile)
if err != nil {
fmt.Println("Read RSA private key error:", err)
return
}
rsaprivkey, err := x509.ParsePKCS1PrivateKey(base14.DecodeFromString(base14.BytesToString(rsapkeydata)))
if err != nil {
fmt.Println("X509 parse RSA private key error:", err)
return
}
x509rsapubkey := x509.MarshalPKCS1PublicKey(&rsaprivkey.PublicKey)
if *aespsh == "" {
fmt.Println("must give parameter -aes")
return
}
aescipher, err := aes.NewCipher(base14.DecodeFromString(*aespsh))
if err != nil {
fmt.Println("parse AES preshared key error:", err)
return
}
listener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*listen)))
if err != nil {
fmt.Println("ListenTCP error:", err)
return
}
defer listener.Close()
fmt.Println("Server bind and listen on", listener.Addr())
for {
conn, err := listener.AcceptTCP()
if err != nil {
fmt.Println("Accept error:", err)
break
}
fmt.Println("Server accept connection from", conn.RemoteAddr())
go handleclient(conn, aescipher, x509rsapubkey, rsaprivkey)
}
}
func handleclient(conn *net.TCPConn, aescipher cipher.Block, x509rsapubkey []byte, rsaprivkey *rsa.PrivateKey) {
var packet utils.Packet
var err error
defer conn.Close()
for {
packet, err = utils.ReadPacket(conn)
if err == io.EOF {
fmt.Println("Client", conn.RemoteAddr(), "closed connection")
return
}
if err != nil {
fmt.Println("Read packet from client", conn.RemoteAddr(), "error:", err)
return
}
switch packet.Typ {
case utils.PacketTypeInit:
rsakeydata, err := utils.ParsePacketInit(aescipher, packet.Dat)
if err != nil {
fmt.Println("Parse init packet from client", conn.RemoteAddr(), "error:", err)
return
}
if len(rsakeydata) > 0 { // unexpected situation
return
}
// send RSA public key by AES encryption
packet.Typ = utils.PacketTypeInit
packet.Dat, err = utils.NewPacketInit(aescipher, x509rsapubkey)
if err != nil {
fmt.Println("Wrap RSA public key init packet to client", conn.RemoteAddr(), "error:", err)
return
}
data, err := packet.ToBytes()
if err != nil {
fmt.Println("Wrap packet to client", conn.RemoteAddr(), "error:", err)
return
}
_, err = conn.Write(data)
if err != nil {
fmt.Println("Send RSA public key to client", conn.RemoteAddr(), "error:", err)
return
}
continue
case utils.PacketTypeComm:
data, err := rsa.DecryptOAEP(md5.New(), rand.Reader, rsaprivkey, packet.Dat, nil)
if err != nil {
fmt.Println("DecryptOAEP from client", conn.RemoteAddr(), "error:", err)
return
}
fmt.Println("Recv data from client", conn.RemoteAddr(), ":", base14.BytesToString(data))
packet.Typ = utils.PacketTypeComm
data, err = utils.RSAPrivateKeyEncrypt(
rsaprivkey, base14.StringToBytes(
"Thank you for connecting! The data is "+
strconv.Itoa(int(atomic.AddUintptr(&countnum, 1))),
),
)
if err != nil {
fmt.Println("RSAPrivateKeyEncrypt to client", conn.RemoteAddr(), "error:", err)
return
}
packet.Dat = data
data, err = packet.ToBytes()
if err != nil {
fmt.Println("Wrap packet to client", conn.RemoteAddr(), "error:", err)
continue
}
_, err = conn.Write(data)
if err != nil {
fmt.Println("Send to client", conn.RemoteAddr(), "error:", err)
continue
}
continue
default:
fmt.Println("Recv unknown packet type from client")
return
}
}
}

94
utils/aes.go Normal file
View File

@@ -0,0 +1,94 @@
package utils
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
)
// NewAESPresharedKey typ AES-128, AES-192, or AES-256.
func NewAESPresharedKey(typ string) ([]byte, error) {
sz := 0
switch typ {
case "AES-128":
sz = 16
case "AES-192":
sz = 24
case "AES-256":
sz = 32
}
if sz <= 0 {
return nil, aes.KeySizeError(sz)
}
k := make([]byte, sz)
_, err := rand.Read(k)
if err != nil {
return nil, err
}
return k, nil
}
// EncryptAES ...
func EncryptAES(aescipher cipher.Block, data []byte) []byte {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
if total%blksz > 0 {
n++
}
encdat := make([]byte, blksz*n)
copy(encdat, data)
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Encrypt(encdat[a:b], encdat[a:b])
}
return encdat
}
// EncryptAESInplace ...
func EncryptAESInplace(aescipher cipher.Block, data []byte) error {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
if total%blksz > 0 {
n++
}
if len(data) < blksz*n {
return errors.New("data is too short")
}
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Encrypt(data[a:b], data[a:b])
}
return nil
}
// DecryptAES ...
func DecryptAES(aescipher cipher.Block, data []byte) []byte {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
decdat := make([]byte, blksz*n)
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Decrypt(decdat[a:b], data[a:b])
}
return decdat
}
// DecryptAESInplace ...
func DecryptAESInplace(aescipher cipher.Block, data []byte) error {
blksz := aescipher.BlockSize()
total := len(data)
n := total / blksz
for i := 0; i < n; i++ {
a := i * blksz
b := (i + 1) * blksz
aescipher.Decrypt(data[a:b], data[a:b])
}
return nil
}

60
utils/aes_test.go Normal file
View File

@@ -0,0 +1,60 @@
package utils
import (
"bytes"
"crypto/aes"
"crypto/rand"
"encoding/hex"
"testing"
)
func TestAES(t *testing.T) {
var buf [32]byte
var data [123]byte
_, err := rand.Read(buf[:])
if err != nil {
t.Fatal(err)
}
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
aescipher, err := aes.NewCipher(buf[:])
if err != nil {
t.Fatal(err)
}
encdat := EncryptAES(aescipher, data[:])
decdat := DecryptAES(aescipher, encdat)[:123]
if !bytes.Equal(data[:], decdat) {
t.Fatal("expected " + hex.EncodeToString(data[:]) + " but got " + hex.EncodeToString(decdat))
}
}
func TestAESInplace(t *testing.T) {
var buf [32]byte
var data [128]byte
_, err := rand.Read(buf[:])
if err != nil {
t.Fatal(err)
}
_, err = rand.Read(data[:])
if err != nil {
t.Fatal(err)
}
aescipher, err := aes.NewCipher(buf[:])
if err != nil {
t.Fatal(err)
}
org := data
err = EncryptAESInplace(aescipher, data[:])
if err != nil {
t.Fatal(err)
}
err = DecryptAESInplace(aescipher, data[:])
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data[:], org[:]) {
t.Fatal("expected " + hex.EncodeToString(org[:]) + " but got " + hex.EncodeToString(data[:]))
}
}

77
utils/packet.go Normal file
View File

@@ -0,0 +1,77 @@
package utils
import (
"encoding/binary"
"errors"
"io"
"math"
)
type PacketType uint8
const (
PacketTypeInit PacketType = iota // PacketTypeInit pass RSA pubkey by AES pre-shared key
PacketTypeComm // PacketTypeComm normal communication
PacketTypeTop // PacketTypeTop for valid checking
)
// Packet is the communicating proxy
type Packet struct {
Len uint16 // Len packet length LE
Typ PacketType // Typ packet type
Dat []byte // Dat payload
}
// ParsePacket from bytes
func ParsePacket(d []byte) (p Packet, err error) {
l := binary.LittleEndian.Uint16(d[:2])
if 2+int(l) != len(d) {
err = errors.New("invalid packet len")
return
}
p.Len = l
p.Typ = PacketType(d[2])
p.Dat = d[3:]
return
}
// ReadPacket from io.Reader
func ReadPacket(r io.Reader) (p Packet, err error) {
var buf [2]byte
_, err = io.ReadFull(r, buf[:])
if err != nil {
return
}
l := binary.LittleEndian.Uint16(buf[:])
data := make([]byte, l)
_, err = io.ReadFull(r, data)
if err != nil {
return
}
p.Len = l
p.Typ = PacketType(data[0])
p.Dat = data[1:]
return
}
// ToBytes marshal packet into bytes
func (p *Packet) ToBytes() ([]byte, error) {
l := 1 + len(p.Dat)
if l > math.MaxUint16 {
return nil, errors.New("packet data too large")
}
if p.Typ >= PacketTypeTop {
return nil, errors.New("invalid packet tpye")
}
p.Len = uint16(l)
return p.MustToBytes(), nil
}
// MustToBytes don't do any check
func (p *Packet) MustToBytes() []byte {
data := make([]byte, 2+1+len(p.Dat))
binary.LittleEndian.PutUint16(data[:2], p.Len)
data[2] = byte(p.Typ)
copy(data[3:], p.Dat)
return data
}

79
utils/packet_init.go Normal file
View File

@@ -0,0 +1,79 @@
package utils
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"hash/crc64"
)
type PacketInitType uint8
const (
PacketInitTypeReq PacketInitType = iota // PacketInitTypeReq request RSA pubkey (by client)
PacketInitTypeAck // PacketInitTypeAck give the key
PacketInitTypeTop
)
// NewPacketInit x509rsapubkey = nil for req
func NewPacketInit(aescipher cipher.Block, x509rsapubkey []byte) ([]byte, error) {
if len(x509rsapubkey) == 0 {
return []byte{byte(PacketInitTypeReq)}, nil
}
blksz := aescipher.BlockSize()
total := 2 + len(x509rsapubkey) + 8
n := total / blksz
if total%blksz > 0 {
n++
}
data := make([]byte, 1+blksz*n)
data[0] = byte(PacketInitTypeAck)
encdat := data[1:]
binary.LittleEndian.PutUint16(encdat[:2], uint16(len(x509rsapubkey)))
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err := h.Write(x509rsapubkey)
if err != nil {
return nil, err
}
_ = h.Sum(encdat[2 : 2 : 2+8])
copy(encdat[2+8:], x509rsapubkey)
err = EncryptAESInplace(aescipher, encdat)
if err != nil {
return nil, err
}
return data, nil
}
// ParsePacketInit parse a init packet
func ParsePacketInit(aescipher cipher.Block, d []byte) (x509rsapubkey []byte, err error) {
if len(d) == 0 {
err = errors.New("invalid init packet length")
return
}
if d[0] >= byte(PacketInitTypeTop) {
err = errors.New("invalid init packet type")
return
}
if d[0] == byte(PacketInitTypeReq) {
return
}
data := DecryptAES(aescipher, d[1:])
klen := binary.LittleEndian.Uint16(data[:2])
if int(klen) > len(data[10:]) {
err = errors.New("invalid init packet data length")
return
}
x509rsapubkey = data[10 : 10+klen]
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err = h.Write(x509rsapubkey)
if err != nil {
return nil, err
}
var buf [8]byte
if !bytes.Equal(data[2:2+8], h.Sum(buf[:0])) {
err = errors.New("invalid init packet data")
return
}
return
}

49
utils/packet_init_test.go Normal file
View File

@@ -0,0 +1,49 @@
package utils
import (
"bytes"
"crypto/aes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"testing"
)
func TestPacketInit(t *testing.T) {
var buf [32]byte
_, err := rand.Read(buf[:])
if err != nil {
t.Fatal(err)
}
k, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
t.Fatal(err)
}
aescipher, err := aes.NewCipher(buf[:])
if err != nil {
t.Fatal(err)
}
rsak := x509.MarshalPKCS1PublicKey(&k.PublicKey)
data, err := NewPacketInit(aescipher, rsak)
if err != nil {
t.Fatal(err)
}
parsedk, err := ParsePacketInit(aescipher, data)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(rsak, parsedk) {
t.Fatal("unexpected 1")
}
data, err = NewPacketInit(nil, nil)
if err != nil {
t.Fatal(err)
}
parsedk, err = ParsePacketInit(aescipher, data)
if err != nil {
t.Fatal(err)
}
if len(parsedk) > 0 {
t.Fatal("unexpected 2")
}
}

56
utils/rsa.go Normal file
View File

@@ -0,0 +1,56 @@
package utils
import (
"crypto/rsa"
"encoding/binary"
"errors"
"hash/crc64"
"math"
_ "unsafe"
)
//go:linkname encrypt crypto/rsa.encrypt
func encrypt(pub *rsa.PublicKey, plaintext []byte) ([]byte, error)
//go:linkname decrypt crypto/rsa.decrypt
func decrypt(priv *rsa.PrivateKey, ciphertext []byte, check bool) ([]byte, error)
// RSAPrivateKeyEncrypt use the method generally in sign
func RSAPrivateKeyEncrypt(priv *rsa.PrivateKey, plaintext []byte) ([]byte, error) {
if len(plaintext) > math.MaxUint16 {
return nil, errors.New("plaintext too large")
}
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err := h.Write(plaintext)
if err != nil {
return nil, err
}
data := make([]byte, len(plaintext)+8+2)
n := copy(data[:], plaintext)
binary.LittleEndian.PutUint64(data[n:n+8], h.Sum64())
binary.LittleEndian.PutUint16(data[n+8:n+8+2], uint16(len(plaintext)))
return decrypt(priv, data, false)
}
// RSAPublicKeyDecrypt use the method generally in sign
func RSAPublicKeyDecrypt(pub *rsa.PublicKey, ciphertext []byte) ([]byte, error) {
data, err := encrypt(pub, ciphertext)
if err != nil {
return nil, err
}
n := binary.LittleEndian.Uint16(data[len(data)-2:])
p := len(data) - int(n) - 8 - 2
if p < 0 || p > len(data) {
return nil, errors.New("invalid ciphertext length")
}
data = data[p:]
h := crc64.New(crc64.MakeTable(crc64.ECMA))
_, err = h.Write(data[:len(data)-8-2])
if err != nil {
return nil, err
}
if h.Sum64() != binary.LittleEndian.Uint64(data[len(data)-8-2:len(data)-2]) {
return nil, errors.New("invalid ciphertext")
}
return data[:len(data)-8-2], nil
}

27
utils/rsa_test.go Normal file
View File

@@ -0,0 +1,27 @@
package utils
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"testing"
)
func TestRSA(t *testing.T) {
testtext := []byte("test RSAPrivateKeyEncrypt")
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
enc, err := RSAPrivateKeyEncrypt(priv, testtext)
if err != nil {
t.Fatal(err)
}
dec, err := RSAPublicKeyDecrypt(&priv.PublicKey, enc)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(testtext, dec) {
t.Fatal("expected", string(testtext), "but got", string(dec))
}
}