mirror of
https://github.com/fumiama/go-registry.git
synced 2026-06-24 04:30:28 +08:00
优化速度
This commit is contained in:
60
cmd.go
60
cmd.go
@@ -2,6 +2,7 @@ package registry
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
|
"errors"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
tea "github.com/fumiama/gofastTEA"
|
tea "github.com/fumiama/gofastTEA"
|
||||||
@@ -18,22 +19,36 @@ const (
|
|||||||
CMDDAT
|
CMDDAT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMd5Mismatch = errors.New("cmdpacket.decrypt: md5 mismatch")
|
||||||
|
)
|
||||||
|
|
||||||
type CmdPacket struct {
|
type CmdPacket struct {
|
||||||
cmd uint8
|
|
||||||
md5 [16]byte
|
|
||||||
t tea.TEA
|
t tea.TEA
|
||||||
data []byte
|
data []byte
|
||||||
|
rawCmdPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type rawCmdPacket struct {
|
||||||
|
cmd uint8
|
||||||
|
len uint8
|
||||||
|
md5 [16]byte
|
||||||
|
raw [255]byte // raw will expand to len
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:nosplit
|
||||||
func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket {
|
func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket {
|
||||||
return &CmdPacket{
|
return &CmdPacket{
|
||||||
cmd: cmd,
|
|
||||||
md5: md5.Sum(data),
|
|
||||||
t: *t,
|
t: *t,
|
||||||
data: data,
|
data: data,
|
||||||
|
rawCmdPacket: rawCmdPacket{
|
||||||
|
cmd: cmd,
|
||||||
|
md5: md5.Sum(data),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:nosplit
|
||||||
func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
|
func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
|
||||||
if len(data) < 1+1+16 {
|
if len(data) < 1+1+16 {
|
||||||
return nil
|
return nil
|
||||||
@@ -41,34 +56,41 @@ func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
|
|||||||
if len(data)-1-1-16 < int(data[1]) {
|
if len(data)-1-1-16 < int(data[1]) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var md5 [16]byte
|
r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data)))
|
||||||
copy(md5[:], data[2:18])
|
c := &CmdPacket{
|
||||||
return &CmdPacket{
|
t: *t,
|
||||||
cmd: data[0],
|
rawCmdPacket: rawCmdPacket{
|
||||||
md5: md5,
|
cmd: r.cmd,
|
||||||
t: *t,
|
len: r.len,
|
||||||
data: data[18 : data[1]+18],
|
md5: r.md5,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
copy(c.raw[:], data[1+1+16:])
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:nosplit
|
||||||
func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
|
func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
|
||||||
setseq(&c.t, seq)
|
setseq(&c.t, seq)
|
||||||
d := c.t.EncryptLittleEndian(c.data, sumtable)
|
c.len = uint8(c.t.EncryptLittleEndianTo(c.data, sumtable, c.raw[:]))
|
||||||
raw = append(raw, c.cmd, uint8(len(d)))
|
(*slice)(unsafe.Pointer(&raw)).Data = unsafe.Pointer(&c.rawCmdPacket)
|
||||||
raw = append(raw, c.md5[:]...)
|
(*slice)(unsafe.Pointer(&raw)).Len = 1 + 1 + 16 + int(c.len)
|
||||||
raw = append(raw, d...)
|
(*slice)(unsafe.Pointer(&raw)).Cap = 1 + 1 + 16 + 255
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CmdPacket) Decrypt(seq uint8) []byte {
|
//go:nosplit
|
||||||
|
func (c *CmdPacket) Decrypt(seq uint8) error {
|
||||||
setseq(&c.t, seq)
|
setseq(&c.t, seq)
|
||||||
d := c.t.DecryptLittleEndian(c.data, sumtable)
|
d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable)
|
||||||
if d != nil && c.md5 == md5.Sum(d) {
|
if d != nil && c.md5 == md5.Sum(d) {
|
||||||
return d
|
c.data = d
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
return ErrMd5Mismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:nosplit
|
||||||
func setseq(t *tea.TEA, seq uint8) {
|
func setseq(t *tea.TEA, seq uint8) {
|
||||||
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq
|
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq
|
||||||
}
|
}
|
||||||
|
|||||||
47
cmd_test.go
47
cmd_test.go
@@ -18,7 +18,12 @@ func TestCmdPacket(t *testing.T) {
|
|||||||
p := NewCmdPacket(CMDGET, []byte("test"), &tp)
|
p := NewCmdPacket(CMDGET, []byte("test"), &tp)
|
||||||
conn.Write(p.Encrypt(seq))
|
conn.Write(p.Encrypt(seq))
|
||||||
seq++
|
seq++
|
||||||
a := string(ack(t, conn, &tp).Decrypt(seq))
|
ackp := ack(t, conn, &tp)
|
||||||
|
err = ackp.Decrypt(seq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a := string(ackp.data)
|
||||||
t.Log(a)
|
t.Log(a)
|
||||||
if a != "null" {
|
if a != "null" {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
@@ -27,7 +32,13 @@ func TestCmdPacket(t *testing.T) {
|
|||||||
p = NewCmdPacket(CMDSET, []byte("test"), &ts)
|
p = NewCmdPacket(CMDSET, []byte("test"), &ts)
|
||||||
conn.Write(p.Encrypt(seq))
|
conn.Write(p.Encrypt(seq))
|
||||||
seq++
|
seq++
|
||||||
a = string(ack(t, conn, &tp).Decrypt(seq))
|
|
||||||
|
ackp = ack(t, conn, &tp)
|
||||||
|
err = ackp.Decrypt(seq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a = string(ackp.data)
|
||||||
t.Log(a)
|
t.Log(a)
|
||||||
if a != "data" {
|
if a != "data" {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
@@ -36,7 +47,13 @@ func TestCmdPacket(t *testing.T) {
|
|||||||
p = NewCmdPacket(CMDDAT, []byte("测试"), &ts)
|
p = NewCmdPacket(CMDDAT, []byte("测试"), &ts)
|
||||||
conn.Write(p.Encrypt(seq))
|
conn.Write(p.Encrypt(seq))
|
||||||
seq++
|
seq++
|
||||||
a = string(ack(t, conn, &tp).Decrypt(seq))
|
|
||||||
|
ackp = ack(t, conn, &tp)
|
||||||
|
err = ackp.Decrypt(seq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a = string(ackp.data)
|
||||||
t.Log(a)
|
t.Log(a)
|
||||||
if a != "succ" {
|
if a != "succ" {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
@@ -45,7 +62,13 @@ func TestCmdPacket(t *testing.T) {
|
|||||||
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
|
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
|
||||||
conn.Write(p.Encrypt(seq))
|
conn.Write(p.Encrypt(seq))
|
||||||
seq++
|
seq++
|
||||||
a = string(ack(t, conn, &tp).Decrypt(seq))
|
|
||||||
|
ackp = ack(t, conn, &tp)
|
||||||
|
err = ackp.Decrypt(seq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a = string(ackp.data)
|
||||||
t.Log(a)
|
t.Log(a)
|
||||||
if a != "测试" {
|
if a != "测试" {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
@@ -54,7 +77,13 @@ func TestCmdPacket(t *testing.T) {
|
|||||||
p = NewCmdPacket(CMDDEL, []byte("test"), &ts)
|
p = NewCmdPacket(CMDDEL, []byte("test"), &ts)
|
||||||
conn.Write(p.Encrypt(seq))
|
conn.Write(p.Encrypt(seq))
|
||||||
seq++
|
seq++
|
||||||
a = string(ack(t, conn, &tp).Decrypt(seq))
|
|
||||||
|
ackp = ack(t, conn, &tp)
|
||||||
|
err = ackp.Decrypt(seq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a = string(ackp.data)
|
||||||
t.Log(a)
|
t.Log(a)
|
||||||
if a != "succ" {
|
if a != "succ" {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
@@ -63,7 +92,13 @@ func TestCmdPacket(t *testing.T) {
|
|||||||
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
|
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
|
||||||
conn.Write(p.Encrypt(seq))
|
conn.Write(p.Encrypt(seq))
|
||||||
seq++
|
seq++
|
||||||
a = string(ack(t, conn, &tp).Decrypt(seq))
|
|
||||||
|
ackp = ack(t, conn, &tp)
|
||||||
|
err = ackp.Decrypt(seq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a = string(ackp.data)
|
||||||
t.Log(a)
|
t.Log(a)
|
||||||
if a != "null" {
|
if a != "null" {
|
||||||
t.Fail()
|
t.Fail()
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -2,4 +2,4 @@ module github.com/fumiama/go-registry
|
|||||||
|
|
||||||
go 1.17
|
go 1.17
|
||||||
|
|
||||||
require github.com/fumiama/gofastTEA v0.0.7
|
require github.com/fumiama/gofastTEA v0.0.9
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -1,2 +1,2 @@
|
|||||||
github.com/fumiama/gofastTEA v0.0.7 h1:Dbce+22jNm+7jpmAeju0C+esIAAnIrq75f5TsCYprS8=
|
github.com/fumiama/gofastTEA v0.0.9 h1:adaWz+014vMShnLUNWIHLBs0Yv6JNUohcaXZNtct5J0=
|
||||||
github.com/fumiama/gofastTEA v0.0.7/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk=
|
github.com/fumiama/gofastTEA v0.0.9/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk=
|
||||||
|
|||||||
24
reg.go
24
reg.go
@@ -68,11 +68,11 @@ func (r *Regedit) Get(key string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
ackbytes := ack.Decrypt(r.seq)
|
err = ack.Decrypt(r.seq)
|
||||||
if ackbytes == nil {
|
if err != nil {
|
||||||
return "", errors.New("decrypt ack error")
|
return "", errors.New("decrypt ack error")
|
||||||
}
|
}
|
||||||
a := BytesToString(ackbytes)
|
a := BytesToString(ack.data)
|
||||||
r.seq++
|
r.seq++
|
||||||
if a == "erro" {
|
if a == "erro" {
|
||||||
return "", errors.New("server ack error")
|
return "", errors.New("server ack error")
|
||||||
@@ -100,11 +100,11 @@ func (r *Regedit) Set(key, value string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ackbytes := ack.Decrypt(r.seq)
|
err = ack.Decrypt(r.seq)
|
||||||
if ackbytes == nil {
|
if err != nil {
|
||||||
return errors.New("decrypt ack error")
|
return errors.New("decrypt ack error")
|
||||||
}
|
}
|
||||||
a := BytesToString(ackbytes)
|
a := BytesToString(ack.data)
|
||||||
r.seq++
|
r.seq++
|
||||||
if a == "erro" {
|
if a == "erro" {
|
||||||
return errors.New("server ack error")
|
return errors.New("server ack error")
|
||||||
@@ -119,11 +119,11 @@ func (r *Regedit) Set(key, value string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ackbytes = ack.Decrypt(r.seq)
|
err = ack.Decrypt(r.seq)
|
||||||
if ackbytes == nil {
|
if err != nil {
|
||||||
return errors.New("decrypt ack error")
|
return errors.New("decrypt ack error")
|
||||||
}
|
}
|
||||||
a = BytesToString(ackbytes)
|
a = BytesToString(ack.data)
|
||||||
r.seq++
|
r.seq++
|
||||||
if a == "erro" {
|
if a == "erro" {
|
||||||
return errors.New("server ack error")
|
return errors.New("server ack error")
|
||||||
@@ -148,11 +148,11 @@ func (r *Regedit) Del(key string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ackbytes := ack.Decrypt(r.seq)
|
err = ack.Decrypt(r.seq)
|
||||||
if ackbytes == nil {
|
if err != nil {
|
||||||
return errors.New("decrypt ack error")
|
return errors.New("decrypt ack error")
|
||||||
}
|
}
|
||||||
a := BytesToString(ackbytes)
|
a := BytesToString(ack.data)
|
||||||
r.seq++
|
r.seq++
|
||||||
if a == "erro" {
|
if a == "erro" {
|
||||||
return errors.New("server ack error")
|
return errors.New("server ack error")
|
||||||
|
|||||||
Reference in New Issue
Block a user