diff --git a/cmd.go b/cmd.go index b07db3e..aa34658 100644 --- a/cmd.go +++ b/cmd.go @@ -3,6 +3,7 @@ package registry import ( "crypto/md5" "errors" + "io" "unsafe" tea "github.com/fumiama/gofastTEA" @@ -24,7 +25,8 @@ var ( ) type CmdPacket struct { - t tea.TEA + io.ReaderFrom + t *tea.TEA data []byte rawCmdPacket } @@ -37,19 +39,17 @@ type rawCmdPacket struct { } //go:nosplit -func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) *CmdPacket { - return &CmdPacket{ - t: *t, - data: data, - rawCmdPacket: rawCmdPacket{ - cmd: cmd, - md5: md5.Sum(data), - }, - } +func NewCmdPacket(cmd uint8, data []byte, t *tea.TEA) (c *CmdPacket) { + c = pool.Get().(*CmdPacket) + c.t = t + c.data = data + c.cmd = cmd + c.md5 = md5.Sum(data) + return } //go:nosplit -func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket { +func ParseCmdPacket(data []byte, t *tea.TEA) (c *CmdPacket) { if len(data) < 1+1+16 { return nil } @@ -57,21 +57,81 @@ func ParseCmdPacket(data []byte, t *tea.TEA) *CmdPacket { return nil } r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&data))) - c := &CmdPacket{ - t: *t, - rawCmdPacket: rawCmdPacket{ - cmd: r.cmd, - len: r.len, - md5: r.md5, - }, - } + c = pool.Get().(*CmdPacket) + c.t = t + c.cmd = r.cmd + c.len = r.len + c.md5 = r.md5 copy(c.raw[:], data[1+1+16:]) return c } +//go:nosplit +func ReadCmdPacket(f io.Reader, t *tea.TEA) (c *CmdPacket, err error) { + c = pool.Get().(*CmdPacket) + buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket)) + _, err = io.ReadFull(f, buf[:1+1+16]) + if err != nil { + c.Put() + return nil, err + } + _, err = io.ReadFull(f, c.raw[:c.len]) + if err != nil { + c.Put() + return nil, err + } + return +} + +//go:nosplit +func (c *CmdPacket) Refresh(cmd uint8, data []byte, t *tea.TEA) { + c.t = t + c.data = data + c.cmd = cmd + c.md5 = md5.Sum(data) +} + +//go:nosplit +func (c *CmdPacket) ClearData() { + c.data = nil +} + +//go:nosplit +func (c *CmdPacket) ReadFrom(f io.Reader) (n int64, err error) { + buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket)) + cnt, err := io.ReadFull(f, buf[:1+1+16]) + if err != nil { + return int64(cnt), err + } + cnt, err = io.ReadFull(f, c.raw[:c.len]) + if err != nil { + return int64(cnt), err + } + return +} + +// Write should not be used due to the full-copy of buf +func (c *CmdPacket) Write(buf []byte) (n int, err error) { + oldlen := len(c.data) + c.data = append(c.data, buf...) + if len(c.data) < 1+1+16 { + return len(buf), nil + } + if len(c.data) < 1+1+16+int(c.len) { + return len(buf), nil + } + r := (*rawCmdPacket)(*(*unsafe.Pointer)(unsafe.Pointer(&c.data))) + c.cmd = r.cmd + c.len = r.len + c.md5 = r.md5 + copy(c.raw[:], r.raw[:c.len]) + c.data = nil + return 1 + 1 + 16 + int(c.len) - oldlen, nil +} + //go:nosplit func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) { - setseq(&c.t, seq) + setseq(c.t, seq) c.len = uint8(c.t.EncryptLittleEndianTo(c.data, sumtable, c.raw[:])) (*slice)(unsafe.Pointer(&raw)).Data = unsafe.Pointer(&c.rawCmdPacket) (*slice)(unsafe.Pointer(&raw)).Len = 1 + 1 + 16 + int(c.len) @@ -81,7 +141,7 @@ func (c *CmdPacket) Encrypt(seq uint8) (raw []byte) { //go:nosplit func (c *CmdPacket) Decrypt(seq uint8) error { - setseq(&c.t, seq) + setseq(c.t, seq) d := c.t.DecryptLittleEndian(c.raw[:c.len], sumtable) if d != nil && c.md5 == md5.Sum(d) { c.data = d @@ -90,6 +150,12 @@ func (c *CmdPacket) Decrypt(seq uint8) error { return ErrMd5Mismatch } +//go:nosplit +func (c *CmdPacket) Put() { + c.data = nil + pool.Put(c) +} + //go:nosplit func setseq(t *tea.TEA, seq uint8) { *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq diff --git a/cmd_test.go b/cmd_test.go index b33ce59..187b5d1 100644 --- a/cmd_test.go +++ b/cmd_test.go @@ -1,130 +1,40 @@ package registry import ( - "net" + "errors" "testing" - - tea "github.com/fumiama/gofastTEA" ) -func TestCmdPacket(t *testing.T) { - conn, err := net.Dial("tcp", "127.0.0.1:8888") +func TestRegedit(t *testing.T) { + reg := NewRegedit("127.0.0.1:8888", "testpwd", "testsps") + err := reg.Connect() if err != nil { t.Fatal(err) } - tp := tea.NewTeaCipherLittleEndian([]byte("testpwd\x00\x00\x00\x00\x00\x00\x00\x00\x00")) - ts := tea.NewTeaCipherLittleEndian([]byte("testsps\x00\x00\x00\x00\x00\x00\x00\x00\x00")) - var seq byte - p := NewCmdPacket(CMDGET, []byte("test"), &tp) - conn.Write(p.Encrypt(seq)) - seq++ - ackp := ack(t, conn, &tp) - err = ackp.Decrypt(seq) - if err != nil { + ret, err := reg.Get("test") + if err != nil && !errors.Is(err, ErrNoSuchKey) { t.Fatal(err) } - a := string(ackp.data) - t.Log(a) - if a != "null" { - t.Fail() - } - seq++ - p = NewCmdPacket(CMDSET, []byte("test"), &ts) - conn.Write(p.Encrypt(seq)) - seq++ - - ackp = ack(t, conn, &tp) - err = ackp.Decrypt(seq) - if err != nil { - t.Fatal(err) - } - a = string(ackp.data) - t.Log(a) - if a != "data" { - t.Fail() - } - seq++ - p = NewCmdPacket(CMDDAT, []byte("测试"), &ts) - conn.Write(p.Encrypt(seq)) - seq++ - - ackp = ack(t, conn, &tp) - err = ackp.Decrypt(seq) - if err != nil { - t.Fatal(err) - } - a = string(ackp.data) - t.Log(a) - if a != "succ" { - t.Fail() - } - seq++ - p = NewCmdPacket(CMDGET, []byte("test"), &tp) - conn.Write(p.Encrypt(seq)) - seq++ - - ackp = ack(t, conn, &tp) - err = ackp.Decrypt(seq) - if err != nil { - t.Fatal(err) - } - a = string(ackp.data) - t.Log(a) - if a != "测试" { - t.Fail() - } - seq++ - p = NewCmdPacket(CMDDEL, []byte("test"), &ts) - conn.Write(p.Encrypt(seq)) - seq++ - - ackp = ack(t, conn, &tp) - err = ackp.Decrypt(seq) - if err != nil { - t.Fatal(err) - } - a = string(ackp.data) - t.Log(a) - if a != "succ" { - t.Fail() - } - seq++ - p = NewCmdPacket(CMDGET, []byte("test"), &tp) - conn.Write(p.Encrypt(seq)) - seq++ - - ackp = ack(t, conn, &tp) - err = ackp.Decrypt(seq) - if err != nil { - t.Fatal(err) - } - a = string(ackp.data) - t.Log(a) - if a != "null" { - t.Fail() - } - seq++ -} - -func ack(t *testing.T, conn net.Conn, tp *tea.TEA) *CmdPacket { - var buf [1024]byte - n, err := conn.Read(buf[:]) - if err != nil { - t.Fatal(err) - } - for n < 1+1+16 { - m, err := conn.Read(buf[n:]) + t.Log(ret) + if ret != "" { + err = reg.Del("test") if err != nil { t.Fatal(err) } - n += m } - for n < 1+1+16+int(buf[1]) { - m, err := conn.Read(buf[n:]) - if err != nil { - t.Fatal(err) - } - n += m + err = reg.Set("test", "测试") + if err != nil { + t.Fatal(err) + } + ret, err = reg.Get("test") + if err != nil { + t.Fatal(err) + } + if ret != "测试" { + t.Fail() + } + err = reg.Close() + if err != nil { + t.Fatal(err) } - return ParseCmdPacket(buf[:], tp) } diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..cea4f45 --- /dev/null +++ b/pool.go @@ -0,0 +1,11 @@ +package registry + +import "sync" + +var pool sync.Pool + +func init() { + pool.New = func() interface{} { + return new(CmdPacket) + } +} diff --git a/reg.go b/reg.go index df92338..939b335 100644 --- a/reg.go +++ b/reg.go @@ -2,19 +2,30 @@ package registry import ( "errors" + "io" "net" "time" tea "github.com/fumiama/gofastTEA" ) +var ( + ErrGetKeyTooLong = errors.New("get key too long") + ErrDecAck = errors.New("decrypt ack error") + ErrInternalServer = errors.New("internal server error") + ErrPermissionDenied = errors.New("permission denied") + ErrSetKeyTooLong = errors.New("set key too long") + ErrSetValTooLong = errors.New("set val too long") + ErrUnknownAck = errors.New("unknown ack error") + ErrNoSuchKey = errors.New("no such key") +) + type Regedit struct { conn net.Conn addr string tp tea.TEA ts *tea.TEA seq byte - buf [255]byte } func NewRegedit(addr, pwd, sps string) *Regedit { @@ -53,137 +64,129 @@ func (r *Regedit) ConnectIn(wait time.Duration) (err error) { func (r *Regedit) Close() (err error) { p := NewCmdPacket(CMDEND, []byte("fill"), &r.tp) r.conn.Write(p.Encrypt(r.seq)) + p.Put() r.seq = 0 return r.conn.Close() } func (r *Regedit) Get(key string) (string, error) { if len(key) > 127 { - return "", errors.New("get key too long") + return "", ErrGetKeyTooLong } p := NewCmdPacket(CMDGET, StringToBytes(key), &r.tp) + defer p.Put() r.conn.Write(p.Encrypt(r.seq)) r.seq++ - ack, err := r.ack() + err := r.ack(p) if err != nil { return "", err } - err = ack.Decrypt(r.seq) + err = p.Decrypt(r.seq) if err != nil { - return "", errors.New("decrypt ack error") + return "", ErrDecAck } - a := BytesToString(ack.data) + a := BytesToString(p.data) r.seq++ if a == "erro" { - return "", errors.New("server ack error") + return "", ErrInternalServer } if a == "null" { - a = "" + return "", ErrNoSuchKey } return a, nil } func (r *Regedit) Set(key, value string) error { if r.ts == nil { - return errors.New("permission denied") + return ErrPermissionDenied } if len(key) > 127 { - return errors.New("set key too long") + return ErrSetKeyTooLong } if len(value) > 127 { - return errors.New("set val too long") + return ErrSetValTooLong } p := NewCmdPacket(CMDSET, StringToBytes(key), r.ts) + defer p.Put() r.conn.Write(p.Encrypt(r.seq)) r.seq++ - ack, err := r.ack() + ack := NewCmdPacket(CMDACK, nil, &r.tp) + defer ack.Put() + err := r.ack(ack) if err != nil { return err } err = ack.Decrypt(r.seq) if err != nil { - return errors.New("decrypt ack error") + return ErrDecAck } a := BytesToString(ack.data) r.seq++ if a == "erro" { - return errors.New("server ack error") + return ErrInternalServer } if a != "data" { - return errors.New("unknown ack error") + return ErrUnknownAck } - p = NewCmdPacket(CMDDAT, StringToBytes(value), r.ts) + p.Refresh(CMDDAT, StringToBytes(value), r.ts) r.conn.Write(p.Encrypt(r.seq)) r.seq++ - ack, err = r.ack() + err = r.ack(ack) if err != nil { return err } err = ack.Decrypt(r.seq) if err != nil { - return errors.New("decrypt ack error") + return ErrDecAck } a = BytesToString(ack.data) r.seq++ if a == "erro" { - return errors.New("server ack error") + return ErrInternalServer } if a != "succ" { - return errors.New("unknown ack error") + return ErrUnknownAck } return nil } func (r *Regedit) Del(key string) error { if r.ts == nil { - return errors.New("permission denied") + return ErrPermissionDenied } if len(key) > 127 { - return errors.New("get key too long") + return ErrGetKeyTooLong } p := NewCmdPacket(CMDDEL, StringToBytes(key), r.ts) + defer p.Put() r.conn.Write(p.Encrypt(r.seq)) r.seq++ - ack, err := r.ack() + ack := NewCmdPacket(CMDACK, nil, &r.tp) + defer ack.Put() + err := r.ack(ack) if err != nil { return err } err = ack.Decrypt(r.seq) if err != nil { - return errors.New("decrypt ack error") + return ErrDecAck } a := BytesToString(ack.data) r.seq++ if a == "erro" { - return errors.New("server ack error") + return ErrInternalServer } if a == "null" { - return errors.New("no such key") + return ErrNoSuchKey } if a != "succ" { - return errors.New("unknown ack error") + return ErrUnknownAck } return nil } -func (r *Regedit) ack() (*CmdPacket, error) { - n, err := r.conn.Read(r.buf[:]) - if err != nil { - return nil, err - } - for n < 1+1+16 { - m, err := r.conn.Read(r.buf[n:]) - if err != nil { - return nil, err - } - n += m - } - for n < 1+1+16+int(r.buf[1]) { - m, err := r.conn.Read(r.buf[n:]) - if err != nil { - return nil, err - } - n += m - } - return ParseCmdPacket(r.buf[:], &r.tp), nil +func (r *Regedit) ack(c *CmdPacket) error { + // c.ClearData() + _, err := io.Copy(c, r.conn) + return err }