diff --git a/jieba.go b/jieba.go index 7a5bad2..f5d6b92 100644 --- a/jieba.go +++ b/jieba.go @@ -9,7 +9,7 @@ import ( var ( Dictionary = "dict.txt" - TT *TopTrie + trie *Trie UserWordTagTab = make(map[string]string) ) @@ -70,39 +70,29 @@ func GetDAG(sentence string) map[int][]int { dag := make(map[int][]int) runes := []rune(sentence) n := len(runes) - p := TT.T - i, j := 0, 0 - var c rune - for { - if i >= n { - break - } - c = runes[j] - if _, ok := p.Nodes[c]; ok { - p = p.Nodes[c] - if p.IsLeaf { - if _, inDag := dag[i]; !inDag { - dag[i] = []int{j} - } else { - dag[i] = append(dag[i], j) - } + i := 0 + var frag string + for k := 0; k < n; k++ { + tmpList := make([]int, 0) + i = k + frag = string(runes[k]) + for { + if !trie.Nodes.Contains(frag) { + break } - j += 1 - if j >= n { - i += 1 - j = i - p = TT.T + if _, ok := trie.Freq[frag]; ok { + tmpList = append(tmpList, i) } - } else { - p = TT.T i += 1 - j = i + if i >= n { + break + } + frag = string(runes[k : i+1]) } - } - for i := 0; i < n; i++ { - if _, ok := dag[i]; !ok { - dag[i] = []int{i} + if len(tmpList) == 0 { + tmpList = append(tmpList, k) } + dag[k] = tmpList } return dag } @@ -122,10 +112,10 @@ func Calc(sentence string, dag map[int][]int, idx int) map[int]*Route { word = string(runes[idx : i+1]) } var route *Route - if _, ok := TT.Freq[word]; ok { - route = &Route{TT.Freq[word] + routes[i+1].Freq, i} + if _, ok := trie.Freq[word]; ok { + route = &Route{trie.Freq[word] + routes[i+1].Freq, i} } else { - route = &Route{TT.MinFreq + routes[i+1].Freq, i} + route = &Route{trie.MinFreq + routes[i+1].Freq, i} } candidates = append(candidates, route) } @@ -161,7 +151,7 @@ func cut_DAG(sentence string) []string { buf = make([]rune, 0) } else { bufString := string(buf) - if _, ok := TT.Freq[bufString]; !ok { + if _, ok := trie.Freq[bufString]; !ok { recognized := finalseg.Cut(bufString) for _, t := range recognized { result = append(result, t) @@ -184,7 +174,7 @@ func cut_DAG(sentence string) []string { result = append(result, string(buf)) } else { bufString := string(buf) - if _, ok := TT.Freq[bufString]; !ok { + if _, ok := trie.Freq[bufString]; !ok { recognized := finalseg.Cut(bufString) for _, t := range recognized { result = append(result, t) @@ -328,7 +318,7 @@ func CutForSearch(sentence string, hmm bool) []string { var gram2 string for i := 0; i < len(runes)-increment+1; i++ { gram2 = string(runes[i : i+increment]) - if _, ok := TT.Freq[gram2]; ok { + if _, ok := trie.Freq[gram2]; ok { result = append(result, gram2) } } @@ -340,6 +330,6 @@ func CutForSearch(sentence string, hmm bool) []string { } func SetDictionary(dict_path string) (err error) { - TT, err = newTopTrie(dict_path) + trie, err = newTrie(dict_path) return } diff --git a/tokenize.go b/tokenize.go index 235506f..3a1457b 100644 --- a/tokenize.go +++ b/tokenize.go @@ -24,7 +24,7 @@ func Tokenize(sentence string, mode string, HMM bool) []Token { if width > step { for i := 0; i < width-step+1; i++ { gram := string(runes[i : i+step]) - if _, ok := TT.Freq[gram]; ok { + if _, ok := trie.Freq[gram]; ok { tokens = append(tokens, Token{gram, start + i, start + i + step}) } } diff --git a/trie_node.go b/trie_node.go index 3726596..0a37de7 100644 --- a/trie_node.go +++ b/trie_node.go @@ -5,6 +5,7 @@ import ( "crypto/md5" "encoding/gob" "fmt" + mapset "github.com/deckarep/golang-set" "log" "math" "os" @@ -14,56 +15,47 @@ import ( ) type Trie struct { - Nodes map[rune]*Trie - IsLeaf bool -} - -func NewTrie() *Trie { - return &Trie{make(map[rune]*Trie), false} -} - -type TopTrie struct { - T *Trie + Nodes mapset.Set MinFreq float64 Total float64 Freq map[string]float64 } -func newTopTrie(filename string) (*TopTrie, error) { - var file_path string - var topTrie *TopTrie - if filepath.IsAbs(filename) { - file_path = filename +func newTrie(fileName string) (*Trie, error) { + var filePath string + var trie *Trie + if filepath.IsAbs(fileName) { + filePath = fileName } else { pwd, err := os.Getwd() if err != nil { return nil, err } - file_path = filepath.Clean(filepath.Join(pwd, filename)) + filePath = filepath.Clean(filepath.Join(pwd, fileName)) } - fi, err := os.Stat(file_path) + fi, err := os.Stat(filePath) if err != nil { return nil, err } - log.Printf("Building Trie..., from %s\n", file_path) - h := fmt.Sprintf("%x", md5.Sum([]byte(file_path))) - cache_file_name := fmt.Sprintf("jieba.%s.cache", h) - cache_path := filepath.Join(os.TempDir(), cache_file_name) + log.Printf("Building Trie..., from %s\n", filePath) + h := fmt.Sprintf("%x", md5.Sum([]byte(filePath))) + cacheFileName := fmt.Sprintf("jieba.%s.cache", h) + cacheFilePath := filepath.Join(os.TempDir(), cacheFileName) isDictCached := true - cache_fi, err := os.Stat(cache_path) + cacheFileInfo, err := os.Stat(cacheFilePath) if err != nil { isDictCached = false } if isDictCached { - isDictCached = cache_fi.ModTime().After(fi.ModTime()) + isDictCached = cacheFileInfo.ModTime().After(fi.ModTime()) } var cacheFile *os.File if isDictCached { - cacheFile, err = os.Open(cache_path) + cacheFile, err = os.Open(cacheFilePath) if err != nil { isDictCached = false } @@ -71,17 +63,19 @@ func newTopTrie(filename string) (*TopTrie, error) { } if isDictCached { dec := gob.NewDecoder(cacheFile) - err = dec.Decode(&topTrie) + err = dec.Decode(&trie) if err != nil { isDictCached = false } else { - log.Printf("loaded model from cache %s\n", cache_path) + log.Printf("loaded model from cache %s\n", cacheFilePath) } } if !isDictCached { - topTrie = &TopTrie{T: NewTrie(), MinFreq: 100.0, Total: 0.0, Freq: make(map[string]float64)} - file, openError := os.Open(file_path) + trie = &Trie{Nodes: mapset.NewSet(), MinFreq: 0.0, Total: 0.0, + Freq: make(map[string]float64)} + + file, openError := os.Open(filePath) if openError != nil { return nil, openError } @@ -93,55 +87,45 @@ func newTopTrie(filename string) (*TopTrie, error) { words := strings.Split(line, " ") word, freqStr := words[0], words[1] freq, _ := strconv.ParseFloat(freqStr, 64) - topTrie.Total += freq - topTrie.addWord(word, freq) + trie.addWord(word, freq) } if scanErr := scanner.Err(); scanErr != nil { return nil, scanErr } var val float64 - for key := range topTrie.Freq { - val = math.Log(topTrie.Freq[key] / topTrie.Total) - if val < topTrie.MinFreq { - topTrie.MinFreq = val + for key := range trie.Freq { + val = math.Log(trie.Freq[key] / trie.Total) + if val < trie.MinFreq { + trie.MinFreq = val } - topTrie.Freq[key] = val + trie.Freq[key] = val } - // dump topTrie - cacheFile, err = os.OpenFile(cache_path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + // dump trie + cacheFile, err = os.OpenFile(cacheFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { - return topTrie, err + return trie, err } defer cacheFile.Close() enc := gob.NewEncoder(cacheFile) - err := enc.Encode(topTrie) + err := enc.Encode(trie) if err != nil { - return topTrie, err + return trie, err } else { - log.Printf("dumped model from cache %s\n", cache_path) + log.Printf("dumped model from cache %s\n", cacheFilePath) } } - return topTrie, nil + return trie, nil } -func (tt *TopTrie) addWord(word string, freq float64) { - tt.Freq[word] = freq - var p *Trie +func (t *Trie) addWord(word string, freq float64) { + t.Freq[word] = freq + t.Total += freq runes := []rune(word) count := len(runes) - for index, key := range runes { - if index == 0 { - p = tt.T - } - if _, ok := p.Nodes[key]; !ok { - p.Nodes[key] = NewTrie() - } - if index == count-1 { - p.Nodes[key].IsLeaf = true - } - p = p.Nodes[key] + for i := 0; i < count; i++ { + t.Nodes.Add(string(runes[:i+1])) } } @@ -149,11 +133,11 @@ func addWord(word string, freq float64, tag string) { if len(tag) > 0 { UserWordTagTab[word] = strings.TrimSpace(tag) } - TT.addWord(word, freq) + trie.addWord(word, freq) } -func LoadUserDict(file_path string) error { - file, openError := os.Open(file_path) +func LoadUserDict(filePath string) error { + file, openError := os.Open(filePath) if openError != nil { return openError }