diff --git a/trie_node.go b/trie_node.go index 0a37de7..2890db2 100644 --- a/trie_node.go +++ b/trie_node.go @@ -2,6 +2,7 @@ package jiebago import ( "bufio" + "bytes" "crypto/md5" "encoding/gob" "fmt" @@ -21,6 +22,53 @@ type Trie struct { Freq map[string]float64 } +func (t Trie) MarshalBinary() ([]byte, error) { + var b bytes.Buffer + enc := gob.NewEncoder(&b) + err := enc.Encode(t.Nodes.ToSlice()) + log.Println(err) + if err != nil { + return nil, err + } + err = enc.Encode(t.MinFreq) + if err != nil { + return nil, err + } + err = enc.Encode(t.Total) + if err != nil { + return nil, err + } + err = enc.Encode(t.Freq) + if err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func (t *Trie) UnmarshalBinary(data []byte) error { + b := bytes.NewBuffer(data) + dec := gob.NewDecoder(b) + var nodes []interface{} + err := dec.Decode(&nodes) + if err != nil { + return err + } + t.Nodes = mapset.NewSetFromSlice(nodes) + err = dec.Decode(&t.MinFreq) + if err != nil { + return err + } + err = dec.Decode(&t.Total) + if err != nil { + return err + } + err = dec.Decode(&t.Freq) + if err != nil { + return err + } + return nil +} + func newTrie(fileName string) (*Trie, error) { var filePath string var trie *Trie