1
0
mirror of https://github.com/fumiama/go-registry.git synced 2026-06-07 17:00:27 +08:00

优化速度

This commit is contained in:
fumiama
2022-02-15 00:38:43 +08:00
parent 0bb4b0bd81
commit e7c05d8bc8
5 changed files with 97 additions and 40 deletions

60
cmd.go
View File

@@ -2,6 +2,7 @@ package registry
import (
"crypto/md5"
"errors"
"unsafe"
tea "github.com/fumiama/gofastTEA"
@@ -18,22 +19,36 @@ const (
CMDDAT
)
var (
ErrMd5Mismatch = errors.New("cmdpacket.decrypt: md5 mismatch")
)
type CmdPacket struct {
cmd uint8
md5 [16]byte
t tea.TEA
data []byte
rawCmdPacket
}
type rawCmdPacket struct {
cmd uint8
len uint8
md5 [16]byte
raw [255]byte // raw will expand to len
}
//go:nosplit
func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket {
return &CmdPacket{
cmd: cmd,
md5: md5.Sum(data),
t: *t,
data: data,
rawCmdPacket: rawCmdPacket{
cmd: cmd,
md5: md5.Sum(data),
},
}
}
//go:nosplit
func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
if len(data) < 1+1+16 {
return nil
@@ -41,34 +56,41 @@ func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
if len(data)-1-1-16 < int(data[1]) {
return nil
}
var md5 [16]byte
copy(md5[:], data[2:18])
return &CmdPacket{
cmd: data[0],
md5: md5,
t: *t,
data: data[18 : data[1]+18],
r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data)))
c := &CmdPacket{
t: *t,
rawCmdPacket: rawCmdPacket{
cmd: r.cmd,
len: r.len,
md5: r.md5,
},
}
copy(c.raw[:], data[1+1+16:])
return c
}
//go:nosplit
func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
setseq(&c.t, seq)
d := c.t.EncryptLittleEndian(c.data, sumtable)
raw = append(raw, c.cmd, uint8(len(d)))
raw = append(raw, c.md5[:]...)
raw = append(raw, d...)
c.len = uint8(c.t.EncryptLittleEndianTo(c.data, sumtable, c.raw[:]))
(*slice)(unsafe.Pointer(&raw)).Data = unsafe.Pointer(&c.rawCmdPacket)
(*slice)(unsafe.Pointer(&raw)).Len = 1 + 1 + 16 + int(c.len)
(*slice)(unsafe.Pointer(&raw)).Cap = 1 + 1 + 16 + 255
return
}
func (c *CmdPacket) Decrypt(seq uint8) []byte {
//go:nosplit
func (c *CmdPacket) Decrypt(seq uint8) error {
setseq(&c.t, seq)
d := c.t.DecryptLittleEndian(c.data, sumtable)
d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable)
if d != nil && c.md5 == md5.Sum(d) {
return d
c.data = d
return nil
}
return nil
return ErrMd5Mismatch
}
//go:nosplit
func setseq(t *tea.TEA, seq uint8) {
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq
}

View File

@@ -18,7 +18,12 @@ func TestCmdPacket(t *testing.T) {
p := NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq))
seq++
a := string(ack(t, conn, &tp).Decrypt(seq))
ackp := ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a := string(ackp.data)
t.Log(a)
if a != "null" {
t.Fail()
@@ -27,7 +32,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDSET, []byte("test"), &ts)
conn.Write(p.Encrypt(seq))
seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "data" {
t.Fail()
@@ -36,7 +47,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDDAT, []byte("测试"), &ts)
conn.Write(p.Encrypt(seq))
seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "succ" {
t.Fail()
@@ -45,7 +62,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq))
seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "测试" {
t.Fail()
@@ -54,7 +77,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDDEL, []byte("test"), &ts)
conn.Write(p.Encrypt(seq))
seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "succ" {
t.Fail()
@@ -63,7 +92,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq))
seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "null" {
t.Fail()

2
go.mod
View File

@@ -2,4 +2,4 @@ module github.com/fumiama/go-registry
go 1.17
require github.com/fumiama/gofastTEA v0.0.7
require github.com/fumiama/gofastTEA v0.0.9

4
go.sum
View File

@@ -1,2 +1,2 @@
github.com/fumiama/gofastTEA v0.0.7 h1:Dbce+22jNm+7jpmAeju0C+esIAAnIrq75f5TsCYprS8=
github.com/fumiama/gofastTEA v0.0.7/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk=
github.com/fumiama/gofastTEA v0.0.9 h1:adaWz+014vMShnLUNWIHLBs0Yv6JNUohcaXZNtct5J0=
github.com/fumiama/gofastTEA v0.0.9/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk=

24
reg.go
View File

@@ -68,11 +68,11 @@ func (r *Regedit) Get(key string) (string, error) {
if err != nil {
return "", err
}
ackbytes := ack.Decrypt(r.seq)
if ackbytes == nil {
err = ack.Decrypt(r.seq)
if err != nil {
return "", errors.New("decrypt ack error")
}
a := BytesToString(ackbytes)
a := BytesToString(ack.data)
r.seq++
if a == "erro" {
return "", errors.New("server ack error")
@@ -100,11 +100,11 @@ func (r *Regedit) Set(key, value string) error {
if err != nil {
return err
}
ackbytes := ack.Decrypt(r.seq)
if ackbytes == nil {
err = ack.Decrypt(r.seq)
if err != nil {
return errors.New("decrypt ack error")
}
a := BytesToString(ackbytes)
a := BytesToString(ack.data)
r.seq++
if a == "erro" {
return errors.New("server ack error")
@@ -119,11 +119,11 @@ func (r *Regedit) Set(key, value string) error {
if err != nil {
return err
}
ackbytes = ack.Decrypt(r.seq)
if ackbytes == nil {
err = ack.Decrypt(r.seq)
if err != nil {
return errors.New("decrypt ack error")
}
a = BytesToString(ackbytes)
a = BytesToString(ack.data)
r.seq++
if a == "erro" {
return errors.New("server ack error")
@@ -148,11 +148,11 @@ func (r *Regedit) Del(key string) error {
if err != nil {
return err
}
ackbytes := ack.Decrypt(r.seq)
if ackbytes == nil {
err = ack.Decrypt(r.seq)
if err != nil {
return errors.New("decrypt ack error")
}
a := BytesToString(ackbytes)
a := BytesToString(ack.data)
r.seq++
if a == "erro" {
return errors.New("server ack error")