commit f79b275f420e12582e09a946da3f48b80284cbcd Author: 源文雨 <41315874+fumiama@users.noreply.github.com> Date: Fri Dec 29 16:29:41 2023 +0900 init diff --git a/README.md b/README.md new file mode 100644 index 0000000..4480206 --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# C/S demo with RSA and AES Algorithm + +This program is my homework for the IoT course, demonstrating the use of symmetric encryption AES and asymmetric encryption RSA. + +## Communication Process +### Server +1. Create a socket and listen for TCP connections. +2. Upon receiving a connection, perform corresponding operations based on the client's request: + - `utils.PacketTypeInit`: Encrypt its own RSA public key with the pre-shared AES key and send it to the client. + - `utils.PacketTypeComm`: Decrypt and print the message sent by the client using `RSA-OAEP` with the private key, then encrypt its replyment with a **self-made** signature algorithm using the RSA private key and send it. +3. When the client actively disconnects, one processing ends. + +### Client +1. Establish a TCP connection to the server address. +2. Send a `utils.PacketTypeInit` packet to the server to request the RSA key. +3. Continuously send `Hello` to the server using `EncryptOAEP` while decrypting and printing the message sent by server with its public key using the **self-made** signature algorithm. +4. When the user manually terminates, close the connection to the server. + +## Packet Protocol + +The overall encapsulation is defined in [utils/packet.go](utils/packet.go). +``` +0 15 23 +┌─────────┬────┬──────────────┐ +│ len │type│ ... data ... │ +└─────────┴────┴──────────────┘ +``` + +### len +The length of the packet without itself. +### type +Defined in [utils/packet.go](utils/packet.go). +```go +const ( + PacketTypeInit PacketType = iota // PacketTypeInit pass RSA pubkey by AES pre-shared key + PacketTypeComm // PacketTypeComm normal communication + PacketTypeTop // PacketTypeTop for valid checking +) +``` +### data +The payload, whose type is described by the `type` field. +#### PacketTypeInit +Defined in [utils/packet_init.go](utils/packet_init.go). +``` +0 7 23 87 +┌────┬────────┬─────────────────┬──────────────────────┐ +│type│ length │ pub key crc64 │ x509rsapubkey │ +└────┴────────┴─────────────────┴──────────────────────┘ +``` +- **type**: `PacketInitTypeReq` or `PacketInitTypeAck` +- **length**: length of `x509rsapubkey` +#### PacketTypeComm +The whole data field is encrypted by RSA and can fill in with any data, which is plain text in this demo. + +## Interesting Points +### The Implementation of [Base16384](https://github.com/fumiama/base16384) +Base16384 is a base64-like algorithm designed by me. It can encode binary file to printable utf16be, and vice versa. + +In this demo, the [RSA Private Key](rsa_2048_private_x509.b14) and AES key is saved and passed by base16384 format. +### The Usage of Raw RSA Encrypting Method +In the file [utils/rsa.go](utils/rsa.go), I use `go:linkname` to hook the private function of `crypto/rsa` library and realized a **self-made** signature algorithm that can get the decoding result but not just verify whether it is valid (unlike the official method `rsa.VerifyPKCS1v15`). +## Demo +See the video below. diff --git a/client/main.go b/client/main.go new file mode 100644 index 0000000..159272b --- /dev/null +++ b/client/main.go @@ -0,0 +1,119 @@ +package main + +import ( + "crypto/aes" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "flag" + "fmt" + "net" + "net/netip" + "socket/utils" + "time" + + base14 "github.com/fumiama/go-base16384" +) + +func main() { + server := flag.String("s", "0.0.0.0:12345", "server host:port") + aespsh := flag.String("aes", "矲晊泞柯碍嘖另蚡蔼帀㴂", "the AES preshard key of base16384 format") + flag.Parse() + if *aespsh == "" { + fmt.Println("must give parameter -aes") + return + } + aescipher, err := aes.NewCipher(base14.DecodeFromString(*aespsh)) + if err != nil { + fmt.Println("parse AES preshared key error:", err) + return + } + conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*server))) + if err != nil { + fmt.Println("Connect to server error:", err) + return + } + defer conn.Close() + fmt.Println("Ronnected to server", conn.RemoteAddr()) + // request rsa pubkey + packet := utils.Packet{} + packet.Typ = utils.PacketTypeInit + packet.Dat, err = utils.NewPacketInit(nil, nil) + if err != nil { + fmt.Println("Wrap init packet err:", err) + return + } + data, err := packet.ToBytes() + if err != nil { + fmt.Println("Wrap packet err:", err) + return + } + _, err = conn.Write(data) + if err != nil { + fmt.Println("Write to server", conn.RemoteAddr(), "error:", err) + return + } + packet, err = utils.ReadPacket(conn) + if err != nil { + fmt.Println("Read packet from server", conn.RemoteAddr(), "error:", err) + return + } + if packet.Typ != utils.PacketTypeInit { + fmt.Println("Unexpected packet type from server") + return + } + x509rsapubkey, err := utils.ParsePacketInit(aescipher, packet.Dat) + if err != nil { + fmt.Println("Parse init packet error:", err) + return + } + rsapubk, err := x509.ParsePKCS1PublicKey(x509rsapubkey) + if err != nil { + fmt.Println("Parse x509rsapubkey error:", err) + return + } + fmt.Println("Get x509rsapubkey successfully") + t := time.NewTicker(time.Second) + defer t.Stop() + count := 0 + for range t.C { + count++ + packet.Typ = utils.PacketTypeComm + data, err = rsa.EncryptOAEP( + md5.New(), rand.Reader, rsapubk, + base14.StringToBytes(fmt.Sprintf("Hello! This is my No.%d communication.", count)), + nil, + ) + if err != nil { + fmt.Println("EncryptOAEP err:", err) + return + } + packet.Dat = data + data, err := packet.ToBytes() + if err != nil { + fmt.Println("Wrap packet err:", err) + return + } + _, err = conn.Write(data) + if err != nil { + fmt.Println("Write to server", conn.RemoteAddr(), "error:", err) + return + } + packet, err = utils.ReadPacket(conn) + if err != nil { + fmt.Println("Read packet from server", conn.RemoteAddr(), "error:", err) + return + } + if err != nil { + fmt.Println("Receive from server", conn.RemoteAddr(), "error:", err) + continue + } + data, err = utils.RSAPublicKeyDecrypt(rsapubk, packet.Dat) + if err != nil { + fmt.Println("RSAPublicKeyDecrypt from server", conn.RemoteAddr(), "error:", err) + continue + } + fmt.Println("Receive from server", conn.RemoteAddr(), ":", base14.BytesToString(data)) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3f2a149 --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module socket + +go 1.17 + +require github.com/fumiama/go-base16384 v1.7.0 + +require golang.org/x/text v0.3.7 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f79a7b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,15 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fumiama/go-base16384 v1.7.0 h1:6fep7XPQWxRlh4Hu+KsdH+6+YdUp+w6CwRXtMWSsXCA= +github.com/fumiama/go-base16384 v1.7.0/go.mod h1:OEn+947GV5gsbTAnyuUW/SrfxJYUdYupSIQXOuGOcXM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rsa_2048_private_x509.b14 b/rsa_2048_private_x509.b14 new file mode 100644 index 0000000..944695d --- /dev/null +++ b/rsa_2048_private_x509.b14 @@ -0,0 +1 @@ +娠湊戈伀亠渐刂裭甯諊劜帱偒傡睔愲撗厼豚樥瞛茪囧蟑圞憟剔堬吁穲箂讄妒拯洍吡寰忡单汪皮寁藉襂竒砹帐嵗诎笵哠茝傉愇奣賆繽帹罔团聥艥廭矑灱炁璨瓌稲剽汄侫斌漺穣灈廞姅换坣諒讻蚜势詈妙萁襴敇蒌狫嵞堨櫣苫姜盾洽暄繕蚩櫞蚚病絼恍灪厴涪丵洔澋谛帥樋变傛撮仺媹趥层窠枚瀅侂挃痁吐蜧甦寶蚊捾眐杝笜慤悀赧缭祘嶷蛜膉歅呸慏瘘灤蘺禓玝縠娄丁亠渐刂僑牵硏攼矄篤膦綵嘟丈仍脶罿盹讫崿弦楶惑悰诏尃揤統夥泘脉哊璜膬恴夠苨汼筹卧咫璘勏徒擼旾厣撆簂噯嗱灄呔繇砊挋枒坭岛国摂嗨痲嵻羂朸浖茙謝璌旹弤蛫烵壣务瀕懘掠怩屆脕焁嫅詇肷洽皈啃晻摹墆竒杰牨諝卛罣猟茂嫔匳岱诩受潙芲誋皪臜徑倁慘庿擈唅廩綵仛簟蒱盍視喖瓡历僅云妙庂兘渷乕御咍职碟蚕濭綶竧瀎寋楌吴嚾苄謷杶赒帨各仧幵噼搯豐竷珳敇暝批豛曂囔穚獙厵舫懱椣跴滭瑼妃臱沚蝣屗缼氌燍褿眎濓翡堬蟾敜敟蓭诮挲牮蒮裚窱摯拉倱搏俀湫炕民喋因貂掳眴捡圫団嵐秶巹褬寘穐覆讐俖袒差矞蓀瘘刃峭璯昏耚罖桅蚑撂犺拨嚋卬蝉緭囊茔牥忞獘呏箷瑺譡翗帠众抶浢緍楼咨暏捽啴茜暛揓羜诠瀼懜哜岫纅詝煑臥傳婲塟挼廎社刦勥瀻諨蠥檍腓瀈挧纞趝嶥媔趮穝举傁彺旇膕亠昁絠葑薯厷峘犙愋硲獈桜諔汫洵滣悫瑢掳瑽覻喉名岓稛偪緧玼嗾瞬傉濯疮蜿瑻荞橁奞啸舠芜殾昧樊坄渻穫萋偀囌云苟专眐巢漞絳嵑班嗥搌嗯蓕覚護創垲櫌蛗孎猦瞄筆蔓璼傁湀圇嚐葬篶痢漞秺婜祠碬毀腫捄绫栙衒淪嘑啍娺穦嘞訬變嵶岈癍薺諺肌菓螕荔紕砢俉勍欰嫍腐褲挺聿衖衪煝栈剸嫜杍旽何寁僌櫞櫲臉瀣焦叨愳氕袏幏刬灩欭垭猄藹嚗絚胁刊侁临怺墖拏燕趝誤譙惫娎犄叚卄帩私蜠蘀娪紨搎咇佪熸蒉蜶淎刼蜟哆捧株溄祭婓坐莛嚐嵝碳睒蔯僕栣芪厵皴袤覵祭朰潭幄櫸藶蔮萄賁簉併苺窩暽潒畩勓耆栀嗅怳曩赎竨痦穠㴃 \ No newline at end of file diff --git a/server/main.go b/server/main.go new file mode 100644 index 0000000..889e7a4 --- /dev/null +++ b/server/main.go @@ -0,0 +1,180 @@ +package main + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "flag" + "fmt" + "io" + "net" + "net/netip" + "os" + "socket/utils" + "strconv" + "sync/atomic" + + base14 "github.com/fumiama/go-base16384" +) + +var countnum = uintptr(0) + +func main() { + listen := flag.String("l", "0.0.0.0:12345", "listening host:port") + genaespsh := flag.String("gaes", "", "generate a new AES preshard key (AES-128, AES-192, or AES-256)") + genrsakey := flag.Uint("grsa", 0, "generate a new RSA key pair of bits and save it to rsa_bits_private_x509.b14") + rsapkfile := flag.String("rsaf", "rsa_2048_private_x509.b14", "specify the path of the RSA private key") + aespsh := flag.String("aes", "矲晊泞柯碍嘖另蚡蔼帀㴂", "the AES preshard key of base16384 format") + flag.Parse() + if *genaespsh != "" { + k, err := utils.NewAESPresharedKey(*genaespsh) + if err != nil { + fmt.Println("Generate new AES preshard key error:", err) + return + } + fmt.Println("The generated AES preshared key is:", base14.EncodeToString(k)) + return + } + if *genrsakey != 0 { + k, err := rsa.GenerateKey(rand.Reader, int(*genrsakey)) + if err != nil { + fmt.Println("Generate new RSA key pair error:", err) + return + } + f, err := os.Create(fmt.Sprintf("rsa_%d_private_x509.b14", *genrsakey)) + if err != nil { + fmt.Println("Create new RSA private key error:", err) + return + } + defer f.Close() + _, err = f.WriteString(base14.EncodeToString(x509.MarshalPKCS1PrivateKey(k))) + if err != nil { + fmt.Println("Save new RSA private key error:", err) + return + } + fmt.Println("The generated RSA key pair is:") + fmt.Println("Private key: saved into", f.Name()) + fmt.Println("Public key:", base14.EncodeToString(x509.MarshalPKCS1PublicKey(&k.PublicKey))) + return + } + rsapkeydata, err := os.ReadFile(*rsapkfile) + if err != nil { + fmt.Println("Read RSA private key error:", err) + return + } + rsaprivkey, err := x509.ParsePKCS1PrivateKey(base14.DecodeFromString(base14.BytesToString(rsapkeydata))) + if err != nil { + fmt.Println("X509 parse RSA private key error:", err) + return + } + x509rsapubkey := x509.MarshalPKCS1PublicKey(&rsaprivkey.PublicKey) + if *aespsh == "" { + fmt.Println("must give parameter -aes") + return + } + aescipher, err := aes.NewCipher(base14.DecodeFromString(*aespsh)) + if err != nil { + fmt.Println("parse AES preshared key error:", err) + return + } + + listener, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(netip.MustParseAddrPort(*listen))) + if err != nil { + fmt.Println("ListenTCP error:", err) + return + } + defer listener.Close() + fmt.Println("Server bind and listen on", listener.Addr()) + + for { + conn, err := listener.AcceptTCP() + if err != nil { + fmt.Println("Accept error:", err) + break + } + fmt.Println("Server accept connection from", conn.RemoteAddr()) + go handleclient(conn, aescipher, x509rsapubkey, rsaprivkey) + } +} + +func handleclient(conn *net.TCPConn, aescipher cipher.Block, x509rsapubkey []byte, rsaprivkey *rsa.PrivateKey) { + var packet utils.Packet + var err error + defer conn.Close() + for { + packet, err = utils.ReadPacket(conn) + if err == io.EOF { + fmt.Println("Client", conn.RemoteAddr(), "closed connection") + return + } + if err != nil { + fmt.Println("Read packet from client", conn.RemoteAddr(), "error:", err) + return + } + switch packet.Typ { + case utils.PacketTypeInit: + rsakeydata, err := utils.ParsePacketInit(aescipher, packet.Dat) + if err != nil { + fmt.Println("Parse init packet from client", conn.RemoteAddr(), "error:", err) + return + } + if len(rsakeydata) > 0 { // unexpected situation + return + } + // send RSA public key by AES encryption + packet.Typ = utils.PacketTypeInit + packet.Dat, err = utils.NewPacketInit(aescipher, x509rsapubkey) + if err != nil { + fmt.Println("Wrap RSA public key init packet to client", conn.RemoteAddr(), "error:", err) + return + } + data, err := packet.ToBytes() + if err != nil { + fmt.Println("Wrap packet to client", conn.RemoteAddr(), "error:", err) + return + } + _, err = conn.Write(data) + if err != nil { + fmt.Println("Send RSA public key to client", conn.RemoteAddr(), "error:", err) + return + } + continue + case utils.PacketTypeComm: + data, err := rsa.DecryptOAEP(md5.New(), rand.Reader, rsaprivkey, packet.Dat, nil) + if err != nil { + fmt.Println("DecryptOAEP from client", conn.RemoteAddr(), "error:", err) + return + } + fmt.Println("Recv data from client", conn.RemoteAddr(), ":", base14.BytesToString(data)) + packet.Typ = utils.PacketTypeComm + data, err = utils.RSAPrivateKeyEncrypt( + rsaprivkey, base14.StringToBytes( + "Thank you for connecting! The data is "+ + strconv.Itoa(int(atomic.AddUintptr(&countnum, 1))), + ), + ) + if err != nil { + fmt.Println("RSAPrivateKeyEncrypt to client", conn.RemoteAddr(), "error:", err) + return + } + packet.Dat = data + data, err = packet.ToBytes() + if err != nil { + fmt.Println("Wrap packet to client", conn.RemoteAddr(), "error:", err) + continue + } + _, err = conn.Write(data) + if err != nil { + fmt.Println("Send to client", conn.RemoteAddr(), "error:", err) + continue + } + continue + default: + fmt.Println("Recv unknown packet type from client") + return + } + } +} diff --git a/utils/aes.go b/utils/aes.go new file mode 100644 index 0000000..ebdb4a6 --- /dev/null +++ b/utils/aes.go @@ -0,0 +1,94 @@ +package utils + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" +) + +// NewAESPresharedKey typ AES-128, AES-192, or AES-256. +func NewAESPresharedKey(typ string) ([]byte, error) { + sz := 0 + switch typ { + case "AES-128": + sz = 16 + case "AES-192": + sz = 24 + case "AES-256": + sz = 32 + } + if sz <= 0 { + return nil, aes.KeySizeError(sz) + } + k := make([]byte, sz) + _, err := rand.Read(k) + if err != nil { + return nil, err + } + return k, nil +} + +// EncryptAES ... +func EncryptAES(aescipher cipher.Block, data []byte) []byte { + blksz := aescipher.BlockSize() + total := len(data) + n := total / blksz + if total%blksz > 0 { + n++ + } + encdat := make([]byte, blksz*n) + copy(encdat, data) + for i := 0; i < n; i++ { + a := i * blksz + b := (i + 1) * blksz + aescipher.Encrypt(encdat[a:b], encdat[a:b]) + } + return encdat +} + +// EncryptAESInplace ... +func EncryptAESInplace(aescipher cipher.Block, data []byte) error { + blksz := aescipher.BlockSize() + total := len(data) + n := total / blksz + if total%blksz > 0 { + n++ + } + if len(data) < blksz*n { + return errors.New("data is too short") + } + for i := 0; i < n; i++ { + a := i * blksz + b := (i + 1) * blksz + aescipher.Encrypt(data[a:b], data[a:b]) + } + return nil +} + +// DecryptAES ... +func DecryptAES(aescipher cipher.Block, data []byte) []byte { + blksz := aescipher.BlockSize() + total := len(data) + n := total / blksz + decdat := make([]byte, blksz*n) + for i := 0; i < n; i++ { + a := i * blksz + b := (i + 1) * blksz + aescipher.Decrypt(decdat[a:b], data[a:b]) + } + return decdat +} + +// DecryptAESInplace ... +func DecryptAESInplace(aescipher cipher.Block, data []byte) error { + blksz := aescipher.BlockSize() + total := len(data) + n := total / blksz + for i := 0; i < n; i++ { + a := i * blksz + b := (i + 1) * blksz + aescipher.Decrypt(data[a:b], data[a:b]) + } + return nil +} diff --git a/utils/aes_test.go b/utils/aes_test.go new file mode 100644 index 0000000..6638401 --- /dev/null +++ b/utils/aes_test.go @@ -0,0 +1,60 @@ +package utils + +import ( + "bytes" + "crypto/aes" + "crypto/rand" + "encoding/hex" + "testing" +) + +func TestAES(t *testing.T) { + var buf [32]byte + var data [123]byte + _, err := rand.Read(buf[:]) + if err != nil { + t.Fatal(err) + } + _, err = rand.Read(data[:]) + if err != nil { + t.Fatal(err) + } + aescipher, err := aes.NewCipher(buf[:]) + if err != nil { + t.Fatal(err) + } + encdat := EncryptAES(aescipher, data[:]) + decdat := DecryptAES(aescipher, encdat)[:123] + if !bytes.Equal(data[:], decdat) { + t.Fatal("expected " + hex.EncodeToString(data[:]) + " but got " + hex.EncodeToString(decdat)) + } +} + +func TestAESInplace(t *testing.T) { + var buf [32]byte + var data [128]byte + _, err := rand.Read(buf[:]) + if err != nil { + t.Fatal(err) + } + _, err = rand.Read(data[:]) + if err != nil { + t.Fatal(err) + } + aescipher, err := aes.NewCipher(buf[:]) + if err != nil { + t.Fatal(err) + } + org := data + err = EncryptAESInplace(aescipher, data[:]) + if err != nil { + t.Fatal(err) + } + err = DecryptAESInplace(aescipher, data[:]) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data[:], org[:]) { + t.Fatal("expected " + hex.EncodeToString(org[:]) + " but got " + hex.EncodeToString(data[:])) + } +} diff --git a/utils/packet.go b/utils/packet.go new file mode 100644 index 0000000..ec857eb --- /dev/null +++ b/utils/packet.go @@ -0,0 +1,77 @@ +package utils + +import ( + "encoding/binary" + "errors" + "io" + "math" +) + +type PacketType uint8 + +const ( + PacketTypeInit PacketType = iota // PacketTypeInit pass RSA pubkey by AES pre-shared key + PacketTypeComm // PacketTypeComm normal communication + PacketTypeTop // PacketTypeTop for valid checking +) + +// Packet is the communicating proxy +type Packet struct { + Len uint16 // Len packet length LE + Typ PacketType // Typ packet type + Dat []byte // Dat payload +} + +// ParsePacket from bytes +func ParsePacket(d []byte) (p Packet, err error) { + l := binary.LittleEndian.Uint16(d[:2]) + if 2+int(l) != len(d) { + err = errors.New("invalid packet len") + return + } + p.Len = l + p.Typ = PacketType(d[2]) + p.Dat = d[3:] + return +} + +// ReadPacket from io.Reader +func ReadPacket(r io.Reader) (p Packet, err error) { + var buf [2]byte + _, err = io.ReadFull(r, buf[:]) + if err != nil { + return + } + l := binary.LittleEndian.Uint16(buf[:]) + data := make([]byte, l) + _, err = io.ReadFull(r, data) + if err != nil { + return + } + p.Len = l + p.Typ = PacketType(data[0]) + p.Dat = data[1:] + return +} + +// ToBytes marshal packet into bytes +func (p *Packet) ToBytes() ([]byte, error) { + l := 1 + len(p.Dat) + if l > math.MaxUint16 { + return nil, errors.New("packet data too large") + } + if p.Typ >= PacketTypeTop { + return nil, errors.New("invalid packet tpye") + } + p.Len = uint16(l) + return p.MustToBytes(), nil +} + +// MustToBytes don't do any check +func (p *Packet) MustToBytes() []byte { + data := make([]byte, 2+1+len(p.Dat)) + binary.LittleEndian.PutUint16(data[:2], p.Len) + data[2] = byte(p.Typ) + copy(data[3:], p.Dat) + return data +} diff --git a/utils/packet_init.go b/utils/packet_init.go new file mode 100644 index 0000000..841ff2f --- /dev/null +++ b/utils/packet_init.go @@ -0,0 +1,79 @@ +package utils + +import ( + "bytes" + "crypto/cipher" + "encoding/binary" + "errors" + "hash/crc64" +) + +type PacketInitType uint8 + +const ( + PacketInitTypeReq PacketInitType = iota // PacketInitTypeReq request RSA pubkey (by client) + PacketInitTypeAck // PacketInitTypeAck give the key + PacketInitTypeTop +) + +// NewPacketInit x509rsapubkey = nil for req +func NewPacketInit(aescipher cipher.Block, x509rsapubkey []byte) ([]byte, error) { + if len(x509rsapubkey) == 0 { + return []byte{byte(PacketInitTypeReq)}, nil + } + blksz := aescipher.BlockSize() + total := 2 + len(x509rsapubkey) + 8 + n := total / blksz + if total%blksz > 0 { + n++ + } + data := make([]byte, 1+blksz*n) + data[0] = byte(PacketInitTypeAck) + encdat := data[1:] + binary.LittleEndian.PutUint16(encdat[:2], uint16(len(x509rsapubkey))) + h := crc64.New(crc64.MakeTable(crc64.ECMA)) + _, err := h.Write(x509rsapubkey) + if err != nil { + return nil, err + } + _ = h.Sum(encdat[2 : 2 : 2+8]) + copy(encdat[2+8:], x509rsapubkey) + err = EncryptAESInplace(aescipher, encdat) + if err != nil { + return nil, err + } + return data, nil +} + +// ParsePacketInit parse a init packet +func ParsePacketInit(aescipher cipher.Block, d []byte) (x509rsapubkey []byte, err error) { + if len(d) == 0 { + err = errors.New("invalid init packet length") + return + } + if d[0] >= byte(PacketInitTypeTop) { + err = errors.New("invalid init packet type") + return + } + if d[0] == byte(PacketInitTypeReq) { + return + } + data := DecryptAES(aescipher, d[1:]) + klen := binary.LittleEndian.Uint16(data[:2]) + if int(klen) > len(data[10:]) { + err = errors.New("invalid init packet data length") + return + } + x509rsapubkey = data[10 : 10+klen] + h := crc64.New(crc64.MakeTable(crc64.ECMA)) + _, err = h.Write(x509rsapubkey) + if err != nil { + return nil, err + } + var buf [8]byte + if !bytes.Equal(data[2:2+8], h.Sum(buf[:0])) { + err = errors.New("invalid init packet data") + return + } + return +} diff --git a/utils/packet_init_test.go b/utils/packet_init_test.go new file mode 100644 index 0000000..ac06739 --- /dev/null +++ b/utils/packet_init_test.go @@ -0,0 +1,49 @@ +package utils + +import ( + "bytes" + "crypto/aes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "testing" +) + +func TestPacketInit(t *testing.T) { + var buf [32]byte + _, err := rand.Read(buf[:]) + if err != nil { + t.Fatal(err) + } + k, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatal(err) + } + aescipher, err := aes.NewCipher(buf[:]) + if err != nil { + t.Fatal(err) + } + rsak := x509.MarshalPKCS1PublicKey(&k.PublicKey) + data, err := NewPacketInit(aescipher, rsak) + if err != nil { + t.Fatal(err) + } + parsedk, err := ParsePacketInit(aescipher, data) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(rsak, parsedk) { + t.Fatal("unexpected 1") + } + data, err = NewPacketInit(nil, nil) + if err != nil { + t.Fatal(err) + } + parsedk, err = ParsePacketInit(aescipher, data) + if err != nil { + t.Fatal(err) + } + if len(parsedk) > 0 { + t.Fatal("unexpected 2") + } +} diff --git a/utils/rsa.go b/utils/rsa.go new file mode 100644 index 0000000..153a324 --- /dev/null +++ b/utils/rsa.go @@ -0,0 +1,56 @@ +package utils + +import ( + "crypto/rsa" + "encoding/binary" + "errors" + "hash/crc64" + "math" + _ "unsafe" +) + +//go:linkname encrypt crypto/rsa.encrypt +func encrypt(pub *rsa.PublicKey, plaintext []byte) ([]byte, error) + +//go:linkname decrypt crypto/rsa.decrypt +func decrypt(priv *rsa.PrivateKey, ciphertext []byte, check bool) ([]byte, error) + +// RSAPrivateKeyEncrypt use the method generally in sign +func RSAPrivateKeyEncrypt(priv *rsa.PrivateKey, plaintext []byte) ([]byte, error) { + if len(plaintext) > math.MaxUint16 { + return nil, errors.New("plaintext too large") + } + h := crc64.New(crc64.MakeTable(crc64.ECMA)) + _, err := h.Write(plaintext) + if err != nil { + return nil, err + } + data := make([]byte, len(plaintext)+8+2) + n := copy(data[:], plaintext) + binary.LittleEndian.PutUint64(data[n:n+8], h.Sum64()) + binary.LittleEndian.PutUint16(data[n+8:n+8+2], uint16(len(plaintext))) + return decrypt(priv, data, false) +} + +// RSAPublicKeyDecrypt use the method generally in sign +func RSAPublicKeyDecrypt(pub *rsa.PublicKey, ciphertext []byte) ([]byte, error) { + data, err := encrypt(pub, ciphertext) + if err != nil { + return nil, err + } + n := binary.LittleEndian.Uint16(data[len(data)-2:]) + p := len(data) - int(n) - 8 - 2 + if p < 0 || p > len(data) { + return nil, errors.New("invalid ciphertext length") + } + data = data[p:] + h := crc64.New(crc64.MakeTable(crc64.ECMA)) + _, err = h.Write(data[:len(data)-8-2]) + if err != nil { + return nil, err + } + if h.Sum64() != binary.LittleEndian.Uint64(data[len(data)-8-2:len(data)-2]) { + return nil, errors.New("invalid ciphertext") + } + return data[:len(data)-8-2], nil +} diff --git a/utils/rsa_test.go b/utils/rsa_test.go new file mode 100644 index 0000000..8eacfe4 --- /dev/null +++ b/utils/rsa_test.go @@ -0,0 +1,27 @@ +package utils + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "testing" +) + +func TestRSA(t *testing.T) { + testtext := []byte("test RSAPrivateKeyEncrypt") + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + enc, err := RSAPrivateKeyEncrypt(priv, testtext) + if err != nil { + t.Fatal(err) + } + dec, err := RSAPublicKeyDecrypt(&priv.PublicKey, enc) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(testtext, dec) { + t.Fatal("expected", string(testtext), "but got", string(dec)) + } +}