diff --git a/reg.go b/reg.go index 035774b..607c85a 100644 --- a/reg.go +++ b/reg.go @@ -4,6 +4,7 @@ import ( "errors" "io" "net" + "sync" "time" tea "github.com/fumiama/gofastTEA" @@ -21,11 +22,13 @@ var ( ) type Regedit struct { - conn net.Conn - addr string - tp tea.TEA - ts *tea.TEA - seq byte + sync.Mutex + conn net.Conn + addr string + tp tea.TEA + ts *tea.TEA + seq byte + isopen bool } func NewRegedit(addr, pwd, sps string) *Regedit { @@ -52,21 +55,37 @@ func NewRegReader(addr, pwd string) *Regedit { } func (r *Regedit) Connect() (err error) { + r.Lock() r.conn, err = net.Dial("tcp", r.addr) + if err != nil { + r.isopen = true + } + r.Unlock() return } func (r *Regedit) ConnectIn(wait time.Duration) (err error) { + r.Lock() r.conn, err = net.DialTimeout("tcp", r.addr, wait) + if err != nil { + r.isopen = true + } + r.Unlock() return } func (r *Regedit) Close() (err error) { - p := NewCmdPacket(CMDEND, []byte("fill"), &r.tp) - r.conn.Write(p.Encrypt(r.seq)) - p.Put() - r.seq = 0 - return r.conn.Close() + r.Lock() + defer r.Unlock() + if r.isopen { + p := NewCmdPacket(CMDEND, []byte("fill"), &r.tp) + r.conn.Write(p.Encrypt(r.seq)) + p.Put() + r.seq = 0 + r.isopen = false + return r.conn.Close() + } + return } func (r *Regedit) Get(key string) (string, error) { @@ -75,18 +94,21 @@ func (r *Regedit) Get(key string) (string, error) { } p := NewCmdPacket(CMDGET, StringToBytes(key), &r.tp) defer p.Put() + r.Lock() r.conn.Write(p.Encrypt(r.seq)) r.seq++ err := r.ack(p) if err != nil { + r.Unlock() return "", err } err = p.Decrypt(r.seq) + r.seq++ + r.Unlock() if err != nil { return "", ErrDecAck } a := string(p.data) - r.seq++ if a == "erro" { return "", ErrInternalServer } @@ -108,20 +130,23 @@ func (r *Regedit) Set(key, value string) error { } p := NewCmdPacket(CMDSET, StringToBytes(key), r.ts) defer p.Put() + r.Lock() r.conn.Write(p.Encrypt(r.seq)) r.seq++ ack := NewCmdPacket(CMDACK, nil, &r.tp) defer ack.Put() err := r.ack(ack) if err != nil { + r.Unlock() return err } err = ack.Decrypt(r.seq) + r.seq++ + r.Unlock() if err != nil { return ErrDecAck } a := BytesToString(ack.data) - r.seq++ if a == "erro" { return ErrInternalServer } @@ -129,18 +154,21 @@ func (r *Regedit) Set(key, value string) error { return ErrUnknownAck } p.Refresh(CMDDAT, StringToBytes(value), r.ts) + r.Lock() r.conn.Write(p.Encrypt(r.seq)) r.seq++ err = r.ack(ack) if err != nil { + r.Unlock() return err } err = ack.Decrypt(r.seq) + r.seq++ + r.Unlock() if err != nil { return ErrDecAck } a = BytesToString(ack.data) - r.seq++ if a == "erro" { return ErrInternalServer } @@ -159,20 +187,23 @@ func (r *Regedit) Del(key string) error { } p := NewCmdPacket(CMDDEL, StringToBytes(key), r.ts) defer p.Put() + r.Lock() r.conn.Write(p.Encrypt(r.seq)) r.seq++ ack := NewCmdPacket(CMDACK, nil, &r.tp) defer ack.Put() err := r.ack(ack) if err != nil { + r.Unlock() return err } err = ack.Decrypt(r.seq) + r.seq++ + r.Unlock() if err != nil { return ErrDecAck } a := BytesToString(ack.data) - r.seq++ if a == "erro" { return ErrInternalServer }