From e3d1b92cc3a10dea35989ef30fa82b8453ac7c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 21 Sep 2025 01:00:49 +0800 Subject: [PATCH] optimize(chat): use more general batch item T --- chat/batch.go | 27 ++++++++++-------- chat/chat.go | 78 ++++++++++++++++++++++----------------------------- chat/item.go | 20 ------------- 3 files changed, 48 insertions(+), 77 deletions(-) delete mode 100644 chat/item.go diff --git a/chat/batch.go b/chat/batch.go index c1ec1ee..0ec9227 100644 --- a/chat/batch.go +++ b/chat/batch.go @@ -1,33 +1,36 @@ package chat -import "strings" +import ( + "fmt" + "strings" +) -type batch struct { - lst *Log - items []item +type batch[T fmt.Stringer] struct { + lst *Log[T] + items []T } -func (l *Log) newbatch(sz int) *batch { +func (l *Log[T]) newbatch(sz int) *batch[T] { if sz == 0 { panic("sz cannot be 0") } - return &batch{ + return &batch[T]{ lst: l, - items: make([]item, 0, sz), + items: make([]T, 0, sz), } } -func (cl *batch) String() string { +func (cl *batch[T]) String() string { sb := strings.Builder{} for _, item := range cl.items { - sb.WriteString(cl.lst.sep) - item.writeToBuilder(&sb, cl.lst.atprefix, cl.lst.namel, cl.lst.namer) + sb.WriteString(cl.lst.itemsep) + sb.WriteString(item.String()) } - return sb.String()[len(cl.lst.sep):] + return sb.String()[len(cl.lst.itemsep):] } // add without mutex -func (cl *batch) add(item item) *batch { +func (cl *batch[T]) add(item T) *batch[T] { v := cl.items defer func() { cl.items = v diff --git a/chat/chat.go b/chat/chat.go index b564196..5377f41 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -1,78 +1,66 @@ package chat import ( + "fmt" "sync" "github.com/fumiama/deepinfra" "github.com/fumiama/deepinfra/model" ) -type Log struct { - mu sync.RWMutex - cap int - sep string - defaultprompt string - namel, namer string - atprefix string - m map[int64][]*batch +type Log[T fmt.Stringer] struct { + mu sync.RWMutex + batchcap, itemscap int + itemsep string + defaultprompt string + m map[int64][]*batch[T] } -func NewLog(cap int, sep, defaultprompt, namel, namer, atprefix string) Log { - if cap < 2 { - panic("cap cannot < 2") +func NewLog[T fmt.Stringer](batchcap, itemscap int, itemsep, defaultprompt string) Log[T] { + if batchcap < 2 { + panic("batchcap cannot < 2") } - if cap%2 != 0 { - panic("cap % 2 must be 0") + if batchcap%2 != 0 { + panic("batchcap % 2 must be 0") } - return Log{ - cap: cap, - sep: sep, + if itemscap < 1 { + panic("itemscap cannot < 1") + } + return Log[T]{ + batchcap: batchcap, + itemscap: itemscap, + itemsep: itemsep, defaultprompt: defaultprompt, - namel: namel, - namer: namer, - atprefix: atprefix, - m: make(map[int64][]*batch, 64), + m: make(map[int64][]*batch[T], 64), } } -func (l *Log) Add(grp int64, usr, txt string, isbot, isatme bool) { +func (l *Log[T]) Add(grp int64, item T, isbot bool) { l.mu.Lock() defer l.mu.Unlock() msgs, ok := l.m[grp] if !ok { - msgs = make([]*batch, 1, l.cap) - msgs[0] = l.newbatch(l.cap).add(item{ - isatme: isatme, - usr: usr, txt: txt, - }) + msgs = make([]*batch[T], 1, l.batchcap) + msgs[0] = l.newbatch(l.itemscap).add(item) l.m[grp] = msgs return } isprevusr := len(msgs)%2 != 0 if (isprevusr && !isbot) || (!isprevusr && isbot) { // is same - _ = msgs[len(msgs)-1].add(item{ - isatme: isatme, - usr: usr, txt: txt, - }) + _ = msgs[len(msgs)-1].add(item) return } if len(msgs) < cap(msgs) { - msgs = append(msgs, l.newbatch(l.cap).add(item{ - isatme: isatme, - usr: usr, txt: txt, - })) + msgs = append(msgs, l.newbatch(l.itemscap).add(item)) l.m[grp] = msgs return } copy(msgs, msgs[2:]) - msgs[len(msgs)-2] = l.newbatch(l.cap).add(item{ - isatme: isatme, - usr: usr, txt: txt, - }) + msgs[len(msgs)-2] = l.newbatch(l.itemscap).add(item) l.m[grp] = msgs[:len(msgs)-1] } -func (l *Log) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model { +func (l *Log[T]) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model { m := p if sysp != "" && !isusersystem { m.System(sysp) @@ -98,14 +86,14 @@ func (l *Log) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bo } // Modelize into any type from index and message -func Modelize[T any](l *Log, grp int64, f func(int, string) T) []T { +func Modelize[X any, T fmt.Stringer](l *Log[T], grp int64, f func(int, string) X) []X { l.mu.RLock() defer l.mu.RUnlock() sz := len(l.m[grp]) if sz == 0 { - return []T{f(0, l.defaultprompt)} + return []X{f(0, l.defaultprompt)} } - t := make([]T, sz) + t := make([]X, sz) for i, msg := range l.m[grp] { t[i] = f(i, msg.String()) } @@ -113,14 +101,14 @@ func Modelize[T any](l *Log, grp int64, f func(int, string) T) []T { } // Reset clears all conversation logs while preserving configuration -func (l *Log) Reset() { +func (l *Log[T]) Reset() { l.mu.Lock() defer l.mu.Unlock() - l.m = make(map[int64][]*batch, 64) + l.m = make(map[int64][]*batch[T], 64) } // ResetIn removes specified groups from the conversation logs -func (l *Log) ResetIn(grps ...int64) { +func (l *Log[T]) ResetIn(grps ...int64) { l.mu.Lock() defer l.mu.Unlock() for _, grp := range grps { diff --git a/chat/item.go b/chat/item.go deleted file mode 100644 index 7915eda..0000000 --- a/chat/item.go +++ /dev/null @@ -1,20 +0,0 @@ -package chat - -import "strings" - -type item struct { - isatme bool - usr, txt string -} - -func (item *item) writeToBuilder(sb *strings.Builder, atprefix, namel, namer string) { - if item.isatme { - sb.WriteString(atprefix) - } - if item.usr != "" { - sb.WriteString(namel) - sb.WriteString(item.usr) - sb.WriteString(namer) - } - sb.WriteString(item.txt) -}