diff --git a/cmd.go b/cmd.go index 478e212..b07db3e 100644 --- a/cmd.go +++ b/cmd.go @@ -2,6 +2,7 @@ package registry import ( "crypto/md5" + "errors" "unsafe" tea "github.com/fumiama/gofastTEA" @@ -18,22 +19,36 @@ const ( CMDDAT ) +var ( + ErrMd5Mismatch = errors.New("cmdpacket.decrypt: md5 mismatch") +) + type CmdPacket struct { - cmd uint8 - md5 [16]byte t tea.TEA data []byte + rawCmdPacket } +type rawCmdPacket struct { + cmd uint8 + len uint8 + md5 [16]byte + raw [255]byte // raw will expand to len +} + +//go:nosplit func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket { return &CmdPacket{ - cmd: cmd, - md5: md5.Sum(data), t: *t, data: data, + rawCmdPacket: rawCmdPacket{ + cmd: cmd, + md5: md5.Sum(data), + }, } } +//go:nosplit func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket { if len(data) < 1+1+16 { return nil @@ -41,34 +56,41 @@ func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket { if len(data)-1-1-16 < int(data[1]) { return nil } - var md5 [16]byte - copy(md5[:], data[2:18]) - return &CmdPacket{ - cmd: data[0], - md5: md5, - t: *t, - data: data[18 : data[1]+18], + r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data))) + c := &CmdPacket{ + t: *t, + rawCmdPacket: rawCmdPacket{ + cmd: r.cmd, + len: r.len, + md5: r.md5, + }, } + copy(c.raw[:], data[1+1+16:]) + return c } +//go:nosplit func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) { setseq(&c.t, seq) - d := c.t.EncryptLittleEndian(c.data, sumtable) - raw = append(raw, c.cmd, uint8(len(d))) - raw = append(raw, c.md5[:]...) - raw = append(raw, d...) + c.len = uint8(c.t.EncryptLittleEndianTo(c.data, sumtable, c.raw[:])) + (*slice)(unsafe.Pointer(&raw)).Data = unsafe.Pointer(&c.rawCmdPacket) + (*slice)(unsafe.Pointer(&raw)).Len = 1 + 1 + 16 + int(c.len) + (*slice)(unsafe.Pointer(&raw)).Cap = 1 + 1 + 16 + 255 return } -func (c *CmdPacket) Decrypt(seq uint8) []byte { +//go:nosplit +func (c *CmdPacket) Decrypt(seq uint8) error { setseq(&c.t, seq) - d := c.t.DecryptLittleEndian(c.data, sumtable) + d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable) if d != nil && c.md5 == md5.Sum(d) { - return d + c.data = d + return nil } - return nil + return ErrMd5Mismatch } +//go:nosplit func setseq(t *tea.TEA, seq uint8) { *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq } diff --git a/cmd_test.go b/cmd_test.go index f3c27bd..b33ce59 100644 --- a/cmd_test.go +++ b/cmd_test.go @@ -18,7 +18,12 @@ func TestCmdPacket(t *testing.T) { p := NewCmdPacket(CMDGET, []byte("test"), &tp) conn.Write(p.Encrypt(seq)) seq++ - a := string(ack(t, conn, &tp).Decrypt(seq)) + ackp := ack(t, conn, &tp) + err = ackp.Decrypt(seq) + if err != nil { + t.Fatal(err) + } + a := string(ackp.data) t.Log(a) if a != "null" { t.Fail() @@ -27,7 +32,13 @@ func TestCmdPacket(t *testing.T) { p = NewCmdPacket(CMDSET, []byte("test"), &ts) conn.Write(p.Encrypt(seq)) seq++ - a = string(ack(t, conn, &tp).Decrypt(seq)) + + ackp = ack(t, conn, &tp) + err = ackp.Decrypt(seq) + if err != nil { + t.Fatal(err) + } + a = string(ackp.data) t.Log(a) if a != "data" { t.Fail() @@ -36,7 +47,13 @@ func TestCmdPacket(t *testing.T) { p = NewCmdPacket(CMDDAT, []byte("测试"), &ts) conn.Write(p.Encrypt(seq)) seq++ - a = string(ack(t, conn, &tp).Decrypt(seq)) + + ackp = ack(t, conn, &tp) + err = ackp.Decrypt(seq) + if err != nil { + t.Fatal(err) + } + a = string(ackp.data) t.Log(a) if a != "succ" { t.Fail() @@ -45,7 +62,13 @@ func TestCmdPacket(t *testing.T) { p = NewCmdPacket(CMDGET, []byte("test"), &tp) conn.Write(p.Encrypt(seq)) seq++ - a = string(ack(t, conn, &tp).Decrypt(seq)) + + ackp = ack(t, conn, &tp) + err = ackp.Decrypt(seq) + if err != nil { + t.Fatal(err) + } + a = string(ackp.data) t.Log(a) if a != "测试" { t.Fail() @@ -54,7 +77,13 @@ func TestCmdPacket(t *testing.T) { p = NewCmdPacket(CMDDEL, []byte("test"), &ts) conn.Write(p.Encrypt(seq)) seq++ - a = string(ack(t, conn, &tp).Decrypt(seq)) + + ackp = ack(t, conn, &tp) + err = ackp.Decrypt(seq) + if err != nil { + t.Fatal(err) + } + a = string(ackp.data) t.Log(a) if a != "succ" { t.Fail() @@ -63,7 +92,13 @@ func TestCmdPacket(t *testing.T) { p = NewCmdPacket(CMDGET, []byte("test"), &tp) conn.Write(p.Encrypt(seq)) seq++ - a = string(ack(t, conn, &tp).Decrypt(seq)) + + ackp = ack(t, conn, &tp) + err = ackp.Decrypt(seq) + if err != nil { + t.Fatal(err) + } + a = string(ackp.data) t.Log(a) if a != "null" { t.Fail() diff --git a/go.mod b/go.mod index 0e53132..fef6818 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/fumiama/go-registry go 1.17 -require github.com/fumiama/gofastTEA v0.0.7 +require github.com/fumiama/gofastTEA v0.0.9 diff --git a/go.sum b/go.sum index 848be59..4a0f386 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/fumiama/gofastTEA v0.0.7 h1:Dbce+22jNm+7jpmAeju0C+esIAAnIrq75f5TsCYprS8= -github.com/fumiama/gofastTEA v0.0.7/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk= +github.com/fumiama/gofastTEA v0.0.9 h1:adaWz+014vMShnLUNWIHLBs0Yv6JNUohcaXZNtct5J0= +github.com/fumiama/gofastTEA v0.0.9/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk= diff --git a/reg.go b/reg.go index 14c8177..df92338 100644 --- a/reg.go +++ b/reg.go @@ -68,11 +68,11 @@ func (r *Regedit) Get(key string) (string, error) { if err != nil { return "", err } - ackbytes := ack.Decrypt(r.seq) - if ackbytes == nil { + err = ack.Decrypt(r.seq) + if err != nil { return "", errors.New("decrypt ack error") } - a := BytesToString(ackbytes) + a := BytesToString(ack.data) r.seq++ if a == "erro" { return "", errors.New("server ack error") @@ -100,11 +100,11 @@ func (r *Regedit) Set(key, value string) error { if err != nil { return err } - ackbytes := ack.Decrypt(r.seq) - if ackbytes == nil { + err = ack.Decrypt(r.seq) + if err != nil { return errors.New("decrypt ack error") } - a := BytesToString(ackbytes) + a := BytesToString(ack.data) r.seq++ if a == "erro" { return errors.New("server ack error") @@ -119,11 +119,11 @@ func (r *Regedit) Set(key, value string) error { if err != nil { return err } - ackbytes = ack.Decrypt(r.seq) - if ackbytes == nil { + err = ack.Decrypt(r.seq) + if err != nil { return errors.New("decrypt ack error") } - a = BytesToString(ackbytes) + a = BytesToString(ack.data) r.seq++ if a == "erro" { return errors.New("server ack error") @@ -148,11 +148,11 @@ func (r *Regedit) Del(key string) error { if err != nil { return err } - ackbytes := ack.Decrypt(r.seq) - if ackbytes == nil { + err = ack.Decrypt(r.seq) + if err != nil { return errors.New("decrypt ack error") } - a := BytesToString(ackbytes) + a := BytesToString(ack.data) r.seq++ if a == "erro" { return errors.New("server ack error")