1
0
mirror of https://github.com/fumiama/jieba.git synced 2026-06-12 13:10:25 +08:00

优化 TextRanker float

This commit is contained in:
源文雨
2022-11-30 13:15:12 +08:00
parent a8d1e81f73
commit ae85ccb20a
3 changed files with 12 additions and 12 deletions

View File

@@ -20,7 +20,7 @@ func Example_extractTags() {
} }
func Example_textRank() { func Example_textRank() {
t, err := LoadDictionary("../dict.txt") t, err := NewTextRanker("../dict.txt")
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -16,7 +16,7 @@ var (
) )
type edge struct { type edge struct {
weight float64 weight uint64
start string start string
end string end string
} }
@@ -47,7 +47,7 @@ func newUndirectWeightedGraph() *undirectWeightedGraph {
} }
} }
func (u *undirectWeightedGraph) addEdge(start, end string, weight float64) { func (u *undirectWeightedGraph) addEdge(start, end string, weight uint64) {
if _, ok := u.graph[start]; !ok { if _, ok := u.graph[start]; !ok {
u.keys = append(u.keys, start) u.keys = append(u.keys, start)
u.graph[start] = edges{&edge{start: start, end: end, weight: weight}} u.graph[start] = edges{&edge{start: start, end: end, weight: weight}}
@@ -69,7 +69,7 @@ func (u *undirectWeightedGraph) rank() Segments {
} }
ws := make(map[string]float64, len(u.graph)*2) ws := make(map[string]float64, len(u.graph)*2)
outSum := make(map[string]float64, len(u.graph)*2) outSum := make(map[string]uint64, len(u.graph)*2)
wsdef := 1.0 wsdef := 1.0
if len(u.graph) > 0 { if len(u.graph) > 0 {
@@ -77,7 +77,7 @@ func (u *undirectWeightedGraph) rank() Segments {
} }
for n, out := range u.graph { for n, out := range u.graph {
ws[n] = wsdef ws[n] = wsdef
sum := 0.0 sum := uint64(0)
for _, e := range out { for _, e := range out {
sum += e.weight sum += e.weight
} }
@@ -89,7 +89,7 @@ func (u *undirectWeightedGraph) rank() Segments {
s := 0.0 s := 0.0
inedges := u.graph[n] inedges := u.graph[n]
for _, e := range inedges { for _, e := range inedges {
s += e.weight / outSum[e.end] * ws[e.end] s += float64(e.weight) * ws[e.end] / float64(outSum[e.end])
} }
ws[n] = (1 - dampingFactor) + dampingFactor*s ws[n] = (1 - dampingFactor) + dampingFactor*s
} }
@@ -121,7 +121,7 @@ func (t *TextRanker) TextRankWithPOS(sentence string, topK int, allowPOS []strin
posFilt[pos] = 1 posFilt[pos] = 1
} }
g := newUndirectWeightedGraph() g := newUndirectWeightedGraph()
cm := make(map[uint64]float64, 256) cm := make(map[uint64]uint64, 256)
hm := make(map[uint64][2]string, 256) hm := make(map[uint64][2]string, 256)
gethash := func(a, b string) uint64 { gethash := func(a, b string) uint64 {
h := crc64.New(crc64.MakeTable(crc64.ISO)) h := crc64.New(crc64.MakeTable(crc64.ISO))
@@ -143,10 +143,10 @@ func (t *TextRanker) TextRankWithPOS(sentence string, topK int, allowPOS []strin
} }
h := gethash(pairs[i].Text(), pairs[j].Text()) h := gethash(pairs[i].Text(), pairs[j].Text())
if _, ok := cm[h]; !ok { if _, ok := cm[h]; !ok {
cm[h] = 1.0 cm[h] = 1
hm[h] = [2]string{pairs[i].Text(), pairs[j].Text()} hm[h] = [2]string{pairs[i].Text(), pairs[j].Text()}
} else { } else {
cm[h] += 1.0 cm[h]++
} }
} }
} }
@@ -171,8 +171,8 @@ func (t *TextRanker) TextRank(sentence string, topK int) Segments {
// TextRanker is used to extract tags from sentence. // TextRanker is used to extract tags from sentence.
type TextRanker posseg.Segmenter type TextRanker posseg.Segmenter
// LoadDictionary reads a given file and create a new dictionary file for Textranker. // NewTextRanker reads a given file and create a new dictionary file for Textranker.
func LoadDictionary(fileName string) (TextRanker, error) { func NewTextRanker(fileName string) (TextRanker, error) {
seg := posseg.Segmenter{} seg := posseg.Segmenter{}
return TextRanker(seg), seg.LoadDictionary(fileName) return TextRanker(seg), seg.LoadDictionary(fileName)
} }

View File

@@ -23,7 +23,7 @@ var (
) )
func TestTextRank(t *testing.T) { func TestTextRank(t *testing.T) {
tr, err := LoadDictionary("../dict.txt") tr, err := NewTextRanker("../dict.txt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }