diff --git a/cmd.go b/cmd.go index aa34658..345a752 100644 --- a/cmd.go +++ b/cmd.go @@ -20,6 +20,15 @@ const ( CMDDAT ) +const ( + ACKNONE uint8 = iota<<4 + 3 + ACKSUCC + ACKDATA + ACKNULL + ACKNEQU + ACKERRO +) + var ( ErrMd5Mismatch = errors.New("cmdpacket.decrypt: md5 mismatch") ) @@ -91,13 +100,12 @@ func (c *CmdPacket) Refresh(cmd uint8, data []byte, t *tea.TEA) { c.md5 = md5.Sum(data) } -//go:nosplit -func (c *CmdPacket) ClearData() { - c.data = nil -} - //go:nosplit func (c *CmdPacket) ReadFrom(f io.Reader) (n int64, err error) { + if c.cmd > 0 { + err = io.EOF + return + } buf := (*[1 + 1 + 16 + 255]byte)(unsafe.Pointer(&c.rawCmdPacket)) cnt, err := io.ReadFull(f, buf[:1+1+16]) if err != nil { @@ -152,13 +160,14 @@ func (c *CmdPacket) Decrypt(seq uint8) error { //go:nosplit func (c *CmdPacket) Put() { + c.cmd = 0 c.data = nil pool.Put(c) } //go:nosplit func setseq(t *tea.TEA, seq uint8) { - *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(t)) + uintptr(15))) = seq + *(*uint8)(unsafe.Add(unsafe.Pointer(t), 15)) = seq } // TEA encoding sumtable diff --git a/go.mod b/go.mod index fef6818..2fe1502 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/fumiama/go-registry go 1.17 -require github.com/fumiama/gofastTEA v0.0.9 +require github.com/fumiama/gofastTEA v0.0.10 diff --git a/go.sum b/go.sum index 4a0f386..9882633 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/fumiama/gofastTEA v0.0.9 h1:adaWz+014vMShnLUNWIHLBs0Yv6JNUohcaXZNtct5J0= -github.com/fumiama/gofastTEA v0.0.9/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk= +github.com/fumiama/gofastTEA v0.0.10 h1:JJJ+brWD4kie+mmK2TkspDXKzqq0IjXm89aGYfoGhhQ= +github.com/fumiama/gofastTEA v0.0.10/go.mod h1:RIdbYZyB4MbH6ZBlPymRaXn3cD6SedlCu5W/HHfMPBk= diff --git a/reg.go b/reg.go index d6941cf..8ad9fb7 100644 --- a/reg.go +++ b/reg.go @@ -107,10 +107,10 @@ func (r *Regedit) Get(key string) (string, error) { return "", ErrDecAck } a := string(p.data) - if a == "erro" { + if a == "erro" && p.cmd == ACKERRO { return "", ErrInternalServer } - if a == "null" { + if a == "null" && p.cmd == ACKNULL { return "", ErrNoSuchKey } return a, nil @@ -145,10 +145,10 @@ func (r *Regedit) Set(key, value string) error { return ErrDecAck } a := BytesToString(ack.data) - if a == "erro" { + if a == "erro" || ack.cmd == ACKERRO { return ErrInternalServer } - if a != "data" { + if a != "data" && ack.cmd != ACKDATA { return ErrUnknownAck } p.Refresh(CMDDAT, StringToBytes(value), r.ts) @@ -167,10 +167,10 @@ func (r *Regedit) Set(key, value string) error { return ErrDecAck } a = BytesToString(ack.data) - if a == "erro" { + if a == "erro" || ack.cmd == ACKERRO { return ErrInternalServer } - if a != "succ" { + if a != "succ" && ack.cmd != ACKSUCC { return ErrUnknownAck } return nil @@ -202,20 +202,20 @@ func (r *Regedit) Del(key string) error { return ErrDecAck } a := BytesToString(ack.data) - if a == "erro" { + if a == "erro" || ack.cmd == ACKERRO { return ErrInternalServer } - if a == "null" { + if a == "null" || ack.cmd == ACKNULL { return ErrNoSuchKey } - if a != "succ" { + if a != "succ" && ack.cmd != ACKSUCC { return ErrUnknownAck } return nil } func (r *Regedit) ack(c *CmdPacket) error { - // c.ClearData() + c.cmd = 0 _, err := io.Copy(c, r.conn) return err } diff --git a/reg_test.go b/reg_test.go index 581670c..17aff34 100644 --- a/reg_test.go +++ b/reg_test.go @@ -11,7 +11,7 @@ func TestReg(t *testing.T) { t.Fatal(err) } v, err := r.Get("test") - if err != nil { + if err != nil && err != ErrNoSuchKey { t.Fatal(err) } t.Log(v) @@ -29,10 +29,10 @@ func TestReg(t *testing.T) { t.Fatal(err) } v, err = r.Get("test") - if err != nil { + t.Log(v) + if err != ErrNoSuchKey { t.Fatal(err) } - t.Log(v) err = r.Close() if err != nil { t.Fatal(err)