1
0
mirror of https://github.com/fumiama/jieba.git synced 2026-06-05 00:32:51 +08:00

优化 posseg

This commit is contained in:
源文雨
2022-11-30 15:34:39 +08:00
parent 7c685f789e
commit a433e052c5

View File

@@ -32,33 +32,32 @@ func (pss probStates) Swap(i, j int) {
} }
func viterbi(obs []rune) []tag { func viterbi(obs []rune) []tag {
obsLength := len(obs) V := make([]map[uint16]float64, len(obs))
V := make([]map[uint16]float64, obsLength)
V[0] = make(map[uint16]float64) V[0] = make(map[uint16]float64)
memPath := make([]map[uint16]uint16, obsLength) memPath := make([]map[uint16]uint16, len(obs))
memPath[0] = make(map[uint16]uint16) memPath[0] = make(map[uint16]uint16)
ys := charStateTab.get(obs[0]) // default is all_states ys := charStateTab.get(obs[0]) // default is all_states
for _, y := range ys { for _, y := range ys {
V[0][y] = probEmit[y].get(obs[0]) + probStart[y] V[0][y] = probEmit[y].get(obs[0]) + probStart[y]
memPath[0][y] = 0 memPath[0][y] = 0
} }
for t := 1; t < obsLength; t++ { for t := 1; t < len(obs); t++ {
var prevStates []uint16 prevStates := make([]uint16, 0, 256)
for x := range memPath[t-1] { for x := range memPath[t-1] {
if len(probTrans[x]) > 0 { if len(probTrans[x]) > 0 {
prevStates = append(prevStates, x) prevStates = append(prevStates, x)
} }
} }
// use Go's map to implement Python's Set() // use Go's map to implement Python's Set()
prevStatesExpectNext := make(map[uint16]int) prevStatesExpectNext := make(map[uint16]struct{}, 256)
for _, x := range prevStates { for _, x := range prevStates {
for y := range probTrans[x] { for y := range probTrans[x] {
prevStatesExpectNext[y] = 1 prevStatesExpectNext[y] = struct{}{}
} }
} }
tmpObsStates := charStateTab.get(obs[t]) tmpObsStates := charStateTab.get(obs[t])
var obsStates []uint16 obsStates := make([]uint16, 0, 256)
for index := range tmpObsStates { for index := range tmpObsStates {
if _, ok := prevStatesExpectNext[tmpObsStates[index]]; ok { if _, ok := prevStatesExpectNext[tmpObsStates[index]]; ok {
obsStates = append(obsStates, tmpObsStates[index]) obsStates = append(obsStates, tmpObsStates[index])
@@ -79,7 +78,8 @@ func viterbi(obs []rune) []tag {
for i, y0 := range prevStates { for i, y0 := range prevStates {
ps = probState{ ps = probState{
prob: V[t-1][y0] + probTrans[y0].Get(y) + probEmit[y].get(obs[t]), prob: V[t-1][y0] + probTrans[y0].Get(y) + probEmit[y].get(obs[t]),
state: y0} state: y0,
}
if i == 0 || ps.prob > max.prob || (ps.prob == max.prob && ps.state > max.state) { if i == 0 || ps.prob > max.prob || (ps.prob == max.prob && ps.state > max.state) {
max = ps max = ps
} }
@@ -88,18 +88,18 @@ func viterbi(obs []rune) []tag {
memPath[t][y] = max.state memPath[t][y] = max.state
} }
} }
last := make(probStates, 0) last := make(probStates, len(memPath[len(memPath)-1]))
length := len(memPath) i := 0
vlength := len(V) for y := range memPath[len(memPath)-1] {
for y := range memPath[length-1] { last[i].prob = V[len(V)-1][y]
ps := probState{prob: V[vlength-1][y], state: y} last[i].state = y
last = append(last, ps) i++
} }
sort.Sort(sort.Reverse(last)) sort.Sort(sort.Reverse(last))
state := last[0].state state := last[0].state
route := make([]tag, len(obs)) route := make([]tag, len(obs))
for i := obsLength - 1; i >= 0; i-- { for i := len(obs) - 1; i >= 0; i-- {
route[i] = tag(state) route[i] = tag(state)
state = memPath[i][state] state = memPath[i][state]
} }