1
0
mirror of https://github.com/fumiama/go-registry.git synced 2026-06-24 04:30:28 +08:00

优化速度

This commit is contained in:
fumiama
2022-02-15 00:38:43 +08:00
parent 0bb4b0bd81
commit e7c05d8bc8
5 changed files with 97 additions and 40 deletions

60
cmd.go
View File

@@ -2,6 +2,7 @@ package registry
import ( import (
"crypto/md5" "crypto/md5"
"errors"
"unsafe" "unsafe"
tea "github.com/fumiama/gofastTEA" tea "github.com/fumiama/gofastTEA"
@@ -18,22 +19,36 @@ const (
CMDDAT CMDDAT
) )
var (
ErrMd5Mismatch = errors.New("cmdpacket.decrypt: md5 mismatch")
)
type CmdPacket struct { type CmdPacket struct {
cmd uint8
md5 [16]byte
t tea.TEA t tea.TEA
data []byte data []byte
rawCmdPacket
} }
type rawCmdPacket struct {
cmd uint8
len uint8
md5 [16]byte
raw [255]byte // raw will expand to len
}
//go:nosplit
func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket { func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket {
return &CmdPacket{ return &CmdPacket{
cmd: cmd,
md5: md5.Sum(data),
t: *t, t: *t,
data: data, data: data,
rawCmdPacket: rawCmdPacket{
cmd: cmd,
md5: md5.Sum(data),
},
} }
} }
//go:nosplit
func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket { func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
if len(data) < 1+1+16 { if len(data) < 1+1+16 {
return nil return nil
@@ -41,34 +56,41 @@ func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
if len(data)-1-1-16 < int(data[1]) { if len(data)-1-1-16 < int(data[1]) {
return nil return nil
} }
var md5 [16]byte r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data)))
copy(md5[:], data[2:18]) c := &CmdPacket{
return &CmdPacket{ t: *t,
cmd: data[0], rawCmdPacket: rawCmdPacket{
md5: md5, cmd: r.cmd,
t: *t, len: r.len,
data: data[18 : data[1]+18], md5: r.md5,
},
} }
copy(c.raw[:], data[1+1+16:])
return c
} }
//go:nosplit
func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) { func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
setseq(&c.t, seq) setseq(&c.t, seq)
d := c.t.EncryptLittleEndian(c.data, sumtable) c.len = uint8(c.t.EncryptLittleEndianTo(c.data, sumtable, c.raw[:]))
raw = append(raw, c.cmd, uint8(len(d))) (*slice)(unsafe.Pointer(&raw)).Data = unsafe.Pointer(&c.rawCmdPacket)
raw = append(raw, c.md5[:]...) (*slice)(unsafe.Pointer(&raw)).Len = 1 + 1 + 16 + int(c.len)
raw = append(raw, d...) (*slice)(unsafe.Pointer(&raw)).Cap = 1 + 1 + 16 + 255
return return
} }
func (c *CmdPacket) Decrypt(seq uint8) []byte { //go:nosplit
func (c *CmdPacket) Decrypt(seq uint8) error {
setseq(&c.t, seq) setseq(&c.t, seq)
d := c.t.DecryptLittleEndian(c.data, sumtable) d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable)
if d != nil && c.md5 == md5.Sum(d) { if d != nil && c.md5 == md5.Sum(d) {
return d c.data = d
return nil
} }
return nil return ErrMd5Mismatch
} }
//go:nosplit
func setseq(t *tea.TEA, seq uint8) { func setseq(t *tea.TEA, seq uint8) {
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq
} }

View File

@@ -18,7 +18,12 @@ func TestCmdPacket(t *testing.T) {
p := NewCmdPacket(CMDGET, []byte("test"), &tp) p := NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq)) conn.Write(p.Encrypt(seq))
seq++ seq++
a := string(ack(t, conn, &tp).Decrypt(seq)) ackp := ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a := string(ackp.data)
t.Log(a) t.Log(a)
if a != "null" { if a != "null" {
t.Fail() t.Fail()
@@ -27,7 +32,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDSET, []byte("test"), &ts) p = NewCmdPacket(CMDSET, []byte("test"), &ts)
conn.Write(p.Encrypt(seq)) conn.Write(p.Encrypt(seq))
seq++ seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a) t.Log(a)
if a != "data" { if a != "data" {
t.Fail() t.Fail()
@@ -36,7 +47,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDDAT, []byte("测试"), &ts) p = NewCmdPacket(CMDDAT, []byte("测试"), &ts)
conn.Write(p.Encrypt(seq)) conn.Write(p.Encrypt(seq))
seq++ seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a) t.Log(a)
if a != "succ" { if a != "succ" {
t.Fail() t.Fail()
@@ -45,7 +62,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDGET, []byte("test"), &tp) p = NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq)) conn.Write(p.Encrypt(seq))
seq++ seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a) t.Log(a)
if a != "测试" { if a != "测试" {
t.Fail() t.Fail()
@@ -54,7 +77,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDDEL, []byte("test"), &ts) p = NewCmdPacket(CMDDEL, []byte("test"), &ts)
conn.Write(p.Encrypt(seq)) conn.Write(p.Encrypt(seq))
seq++ seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a) t.Log(a)
if a != "succ" { if a != "succ" {
t.Fail() t.Fail()
@@ -63,7 +92,13 @@ func TestCmdPacket(t *testing.T) {
p = NewCmdPacket(CMDGET, []byte("test"), &tp) p = NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq)) conn.Write(p.Encrypt(seq))
seq++ seq++
a = string(ack(t, conn, &tp).Decrypt(seq))
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a) t.Log(a)
if a != "null" { if a != "null" {
t.Fail() t.Fail()

2
go.mod
View File

@@ -2,4 +2,4 @@ module github.com/fumiama/go-registry
go 1.17 go 1.17
require github.com/fumiama/gofastTEA v0.0.7 require github.com/fumiama/gofastTEA v0.0.9

4
go.sum
View File

@@ -1,2 +1,2 @@
github.com/fumiama/gofastTEA v0.0.7 h1:Dbce+22jNm+7jpmAeju0C+esIAAnIrq75f5TsCYprS8= github.com/fumiama/gofastTEA v0.0.9 h1:adaWz+014vMShnLUNWIHLBs0Yv6JNUohcaXZNtct5J0=
github.com/fumiama/gofastTEA v0.0.7/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk= github.com/fumiama/gofastTEA v0.0.9/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk=

24
reg.go
View File

@@ -68,11 +68,11 @@ func (r *Regedit) Get(key string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
ackbytes := ack.Decrypt(r.seq) err = ack.Decrypt(r.seq)
if ackbytes == nil { if err != nil {
return "", errors.New("decrypt ack error") return "", errors.New("decrypt ack error")
} }
a := BytesToString(ackbytes) a := BytesToString(ack.data)
r.seq++ r.seq++
if a == "erro" { if a == "erro" {
return "", errors.New("server ack error") return "", errors.New("server ack error")
@@ -100,11 +100,11 @@ func (r *Regedit) Set(key, value string) error {
if err != nil { if err != nil {
return err return err
} }
ackbytes := ack.Decrypt(r.seq) err = ack.Decrypt(r.seq)
if ackbytes == nil { if err != nil {
return errors.New("decrypt ack error") return errors.New("decrypt ack error")
} }
a := BytesToString(ackbytes) a := BytesToString(ack.data)
r.seq++ r.seq++
if a == "erro" { if a == "erro" {
return errors.New("server ack error") return errors.New("server ack error")
@@ -119,11 +119,11 @@ func (r *Regedit) Set(key, value string) error {
if err != nil { if err != nil {
return err return err
} }
ackbytes = ack.Decrypt(r.seq) err = ack.Decrypt(r.seq)
if ackbytes == nil { if err != nil {
return errors.New("decrypt ack error") return errors.New("decrypt ack error")
} }
a = BytesToString(ackbytes) a = BytesToString(ack.data)
r.seq++ r.seq++
if a == "erro" { if a == "erro" {
return errors.New("server ack error") return errors.New("server ack error")
@@ -148,11 +148,11 @@ func (r *Regedit) Del(key string) error {
if err != nil { if err != nil {
return err return err
} }
ackbytes := ack.Decrypt(r.seq) err = ack.Decrypt(r.seq)
if ackbytes == nil { if err != nil {
return errors.New("decrypt ack error") return errors.New("decrypt ack error")
} }
a := BytesToString(ackbytes) a := BytesToString(ack.data)
r.seq++ r.seq++
if a == "erro" { if a == "erro" {
return errors.New("server ack error") return errors.New("server ack error")