1
0
mirror of https://github.com/fumiama/go-registry.git synced 2026-06-04 23:40:27 +08:00

优化速度

This commit is contained in:
fumiama
2022-02-15 13:26:16 +08:00
parent e7c05d8bc8
commit b6a85b186c
4 changed files with 171 additions and 181 deletions

108
cmd.go
View File

@@ -3,6 +3,7 @@ package registry
import (
"crypto/md5"
"errors"
"io"
"unsafe"
tea "github.com/fumiama/gofastTEA"
@@ -24,7 +25,8 @@ var (
)
type CmdPacket struct {
t tea.TEA
io.ReaderFrom
t *tea.TEA
data []byte
rawCmdPacket
}
@@ -37,19 +39,17 @@ type rawCmdPacket struct {
}
//go:nosplit
func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket {
return &CmdPacket{
t: *t,
data: data,
rawCmdPacket: rawCmdPacket{
cmd: cmd,
md5: md5.Sum(data),
},
}
func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) (c *CmdPacket) {
c = pool.Get().(*CmdPacket)
c.t = t
c.data = data
c.cmd = cmd
c.md5 = md5.Sum(data)
return
}
//go:nosplit
func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
func ParseCmdPacket(data []byte, t *tea.TEA) (c *CmdPacket) {
if len(data) < 1+1+16 {
return nil
}
@@ -57,21 +57,81 @@ func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket {
return nil
}
r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data)))
c := &CmdPacket{
t: *t,
rawCmdPacket: rawCmdPacket{
cmd: r.cmd,
len: r.len,
md5: r.md5,
},
}
c = pool.Get().(*CmdPacket)
c.t = t
c.cmd = r.cmd
c.len = r.len
c.md5 = r.md5
copy(c.raw[:], data[1+1+16:])
return c
}
//go:nosplit
func ReadCmdPacket(f io.Reader, t *tea.TEA) (c *CmdPacket, err error) {
c = pool.Get().(*CmdPacket)
buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket))
_, err = io.ReadFull(f, buf[:1+1+16])
if err != nil {
c.Put()
return nil, err
}
_, err = io.ReadFull(f, c.raw[:c.len])
if err != nil {
c.Put()
return nil, err
}
return
}
//go:nosplit
func (c *CmdPacket) Refresh(cmd uint8, data []byte, t *tea.TEA) {
c.t = t
c.data = data
c.cmd = cmd
c.md5 = md5.Sum(data)
}
//go:nosplit
func (c *CmdPacket) ClearData() {
c.data = nil
}
//go:nosplit
func (c *CmdPacket) ReadFrom(f io.Reader) (n int64, err error) {
buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket))
cnt, err := io.ReadFull(f, buf[:1+1+16])
if err != nil {
return int64(cnt), err
}
cnt, err = io.ReadFull(f, c.raw[:c.len])
if err != nil {
return int64(cnt), err
}
return
}
// Write should not be used due to the full-copy of buf
func (c *CmdPacket) Write(buf []byte) (n int, err error) {
oldlen := len(c.data)
c.data = append(c.data, buf...)
if len(c.data) < 1+1+16 {
return len(buf), nil
}
if len(c.data) < 1+1+16+int(c.len) {
return len(buf), nil
}
r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&c.data)))
c.cmd = r.cmd
c.len = r.len
c.md5 = r.md5
copy(c.raw[:], r.raw[:c.len])
c.data = nil
return 1 + 1 + 16 + int(c.len) - oldlen, nil
}
//go:nosplit
func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
setseq(&c.t, seq)
setseq(c.t, seq)
c.len = uint8(c.t.EncryptLittleEndianTo(c.data, sumtable, c.raw[:]))
(*slice)(unsafe.Pointer(&raw)).Data = unsafe.Pointer(&c.rawCmdPacket)
(*slice)(unsafe.Pointer(&raw)).Len = 1 + 1 + 16 + int(c.len)
@@ -81,7 +141,7 @@ func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) {
//go:nosplit
func (c *CmdPacket) Decrypt(seq uint8) error {
setseq(&c.t, seq)
setseq(c.t, seq)
d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable)
if d != nil && c.md5 == md5.Sum(d) {
c.data = d
@@ -90,6 +150,12 @@ func (c *CmdPacket) Decrypt(seq uint8) error {
return ErrMd5Mismatch
}
//go:nosplit
func (c *CmdPacket) Put() {
c.data = nil
pool.Put(c)
}
//go:nosplit
func setseq(t *tea.TEA, seq uint8) {
*(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq

View File

@@ -1,130 +1,40 @@
package registry
import (
"net"
"errors"
"testing"
tea "github.com/fumiama/gofastTEA"
)
func TestCmdPacket(t *testing.T) {
conn, err := net.Dial("tcp", "127.0.0.1:8888")
func TestRegedit(t *testing.T) {
reg := NewRegedit("127.0.0.1:8888", "testpwd", "testsps")
err := reg.Connect()
if err != nil {
t.Fatal(err)
}
tp := tea.NewTeaCipherLittleEndian([]byte("testpwd\x00\x00\x00\x00\x00\x00\x00\x00\x00"))
ts := tea.NewTeaCipherLittleEndian([]byte("testsps\x00\x00\x00\x00\x00\x00\x00\x00\x00"))
var seq byte
p := NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq))
seq++
ackp := ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
ret, err := reg.Get("test")
if err != nil && !errors.Is(err, ErrNoSuchKey) {
t.Fatal(err)
}
a := string(ackp.data)
t.Log(a)
if a != "null" {
t.Fail()
}
seq++
p = NewCmdPacket(CMDSET, []byte("test"), &ts)
conn.Write(p.Encrypt(seq))
seq++
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "data" {
t.Fail()
}
seq++
p = NewCmdPacket(CMDDAT, []byte("测试"), &ts)
conn.Write(p.Encrypt(seq))
seq++
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "succ" {
t.Fail()
}
seq++
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq))
seq++
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "测试" {
t.Fail()
}
seq++
p = NewCmdPacket(CMDDEL, []byte("test"), &ts)
conn.Write(p.Encrypt(seq))
seq++
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "succ" {
t.Fail()
}
seq++
p = NewCmdPacket(CMDGET, []byte("test"), &tp)
conn.Write(p.Encrypt(seq))
seq++
ackp = ack(t, conn, &tp)
err = ackp.Decrypt(seq)
if err != nil {
t.Fatal(err)
}
a = string(ackp.data)
t.Log(a)
if a != "null" {
t.Fail()
}
seq++
}
func ack(t *testing.T, conn net.Conn, tp *tea.TEA) *CmdPacket {
var buf [1024]byte
n, err := conn.Read(buf[:])
if err != nil {
t.Fatal(err)
}
for n < 1+1+16 {
m, err := conn.Read(buf[n:])
t.Log(ret)
if ret != "" {
err = reg.Del("test")
if err != nil {
t.Fatal(err)
}
n += m
}
for n < 1+1+16+int(buf[1]) {
m, err := conn.Read(buf[n:])
if err != nil {
t.Fatal(err)
}
n += m
err = reg.Set("test", "测试")
if err != nil {
t.Fatal(err)
}
ret, err = reg.Get("test")
if err != nil {
t.Fatal(err)
}
if ret != "测试" {
t.Fail()
}
err = reg.Close()
if err != nil {
t.Fatal(err)
}
return ParseCmdPacket(buf[:], tp)
}

11
pool.go Normal file
View File

@@ -0,0 +1,11 @@
package registry
import "sync"
var pool sync.Pool
func init() {
pool.New = func() interface{} {
return new(CmdPacket)
}
}

97
reg.go
View File

@@ -2,19 +2,30 @@ package registry
import (
"errors"
"io"
"net"
"time"
tea "github.com/fumiama/gofastTEA"
)
var (
ErrGetKeyTooLong = errors.New("get key too long")
ErrDecAck = errors.New("decrypt ack error")
ErrInternalServer = errors.New("internal server error")
ErrPermissionDenied = errors.New("permission denied")
ErrSetKeyTooLong = errors.New("set key too long")
ErrSetValTooLong = errors.New("set val too long")
ErrUnknownAck = errors.New("unknown ack error")
ErrNoSuchKey = errors.New("no such key")
)
type Regedit struct {
conn net.Conn
addr string
tp tea.TEA
ts *tea.TEA
seq byte
buf [255]byte
}
func NewRegedit(addr, pwd, sps string) *Regedit {
@@ -53,137 +64,129 @@ func (r *Regedit) ConnectIn(wait time.Duration) (err error) {
func (r *Regedit) Close() (err error) {
p := NewCmdPacket(CMDEND, []byte("fill"), &r.tp)
r.conn.Write(p.Encrypt(r.seq))
p.Put()
r.seq = 0
return r.conn.Close()
}
func (r *Regedit) Get(key string) (string, error) {
if len(key) > 127 {
return "", errors.New("get key too long")
return "", ErrGetKeyTooLong
}
p := NewCmdPacket(CMDGET, StringToBytes(key), &r.tp)
defer p.Put()
r.conn.Write(p.Encrypt(r.seq))
r.seq++
ack, err := r.ack()
err := r.ack(p)
if err != nil {
return "", err
}
err = ack.Decrypt(r.seq)
err = p.Decrypt(r.seq)
if err != nil {
return "", errors.New("decrypt ack error")
return "", ErrDecAck
}
a := BytesToString(ack.data)
a := BytesToString(p.data)
r.seq++
if a == "erro" {
return "", errors.New("server ack error")
return "", ErrInternalServer
}
if a == "null" {
a = ""
return "", ErrNoSuchKey
}
return a, nil
}
func (r *Regedit) Set(key, value string) error {
if r.ts == nil {
return errors.New("permission denied")
return ErrPermissionDenied
}
if len(key) > 127 {
return errors.New("set key too long")
return ErrSetKeyTooLong
}
if len(value) > 127 {
return errors.New("set val too long")
return ErrSetValTooLong
}
p := NewCmdPacket(CMDSET, StringToBytes(key), r.ts)
defer p.Put()
r.conn.Write(p.Encrypt(r.seq))
r.seq++
ack, err := r.ack()
ack := NewCmdPacket(CMDACK, nil, &r.tp)
defer ack.Put()
err := r.ack(ack)
if err != nil {
return err
}
err = ack.Decrypt(r.seq)
if err != nil {
return errors.New("decrypt ack error")
return ErrDecAck
}
a := BytesToString(ack.data)
r.seq++
if a == "erro" {
return errors.New("server ack error")
return ErrInternalServer
}
if a != "data" {
return errors.New("unknown ack error")
return ErrUnknownAck
}
p = NewCmdPacket(CMDDAT, StringToBytes(value), r.ts)
p.Refresh(CMDDAT, StringToBytes(value), r.ts)
r.conn.Write(p.Encrypt(r.seq))
r.seq++
ack, err = r.ack()
err = r.ack(ack)
if err != nil {
return err
}
err = ack.Decrypt(r.seq)
if err != nil {
return errors.New("decrypt ack error")
return ErrDecAck
}
a = BytesToString(ack.data)
r.seq++
if a == "erro" {
return errors.New("server ack error")
return ErrInternalServer
}
if a != "succ" {
return errors.New("unknown ack error")
return ErrUnknownAck
}
return nil
}
func (r *Regedit) Del(key string) error {
if r.ts == nil {
return errors.New("permission denied")
return ErrPermissionDenied
}
if len(key) > 127 {
return errors.New("get key too long")
return ErrGetKeyTooLong
}
p := NewCmdPacket(CMDDEL, StringToBytes(key), r.ts)
defer p.Put()
r.conn.Write(p.Encrypt(r.seq))
r.seq++
ack, err := r.ack()
ack := NewCmdPacket(CMDACK, nil, &r.tp)
defer ack.Put()
err := r.ack(ack)
if err != nil {
return err
}
err = ack.Decrypt(r.seq)
if err != nil {
return errors.New("decrypt ack error")
return ErrDecAck
}
a := BytesToString(ack.data)
r.seq++
if a == "erro" {
return errors.New("server ack error")
return ErrInternalServer
}
if a == "null" {
return errors.New("no such key")
return ErrNoSuchKey
}
if a != "succ" {
return errors.New("unknown ack error")
return ErrUnknownAck
}
return nil
}
func (r *Regedit) ack() (*CmdPacket, error) {
n, err := r.conn.Read(r.buf[:])
if err != nil {
return nil, err
}
for n < 1+1+16 {
m, err := r.conn.Read(r.buf[n:])
if err != nil {
return nil, err
}
n += m
}
for n < 1+1+16+int(r.buf[1]) {
m, err := r.conn.Read(r.buf[n:])
if err != nil {
return nil, err
}
n += m
}
return ParseCmdPacket(r.buf[:], &r.tp), nil
func (r *Regedit) ack(c *CmdPacket) error {
// c.ClearData()
_, err := io.Copy(c, r.conn)
return err
}