mirror of
https://github.com/fumiama/deepinfra.git
synced 2026-06-05 00:32:46 +08:00
optimize(chat): use more general batch item T
This commit is contained in:
@@ -1,33 +1,36 @@
|
|||||||
package chat
|
package chat
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
type batch struct {
|
type batch[T fmt.Stringer] struct {
|
||||||
lst *Log
|
lst *Log[T]
|
||||||
items []item
|
items []T
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Log) newbatch(sz int) *batch {
|
func (l *Log[T]) newbatch(sz int) *batch[T] {
|
||||||
if sz == 0 {
|
if sz == 0 {
|
||||||
panic("sz cannot be 0")
|
panic("sz cannot be 0")
|
||||||
}
|
}
|
||||||
return &batch{
|
return &batch[T]{
|
||||||
lst: l,
|
lst: l,
|
||||||
items: make([]item, 0, sz),
|
items: make([]T, 0, sz),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cl *batch) String() string {
|
func (cl *batch[T]) String() string {
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
for _, item := range cl.items {
|
for _, item := range cl.items {
|
||||||
sb.WriteString(cl.lst.sep)
|
sb.WriteString(cl.lst.itemsep)
|
||||||
item.writeToBuilder(&sb, cl.lst.atprefix, cl.lst.namel, cl.lst.namer)
|
sb.WriteString(item.String())
|
||||||
}
|
}
|
||||||
return sb.String()[len(cl.lst.sep):]
|
return sb.String()[len(cl.lst.itemsep):]
|
||||||
}
|
}
|
||||||
|
|
||||||
// add without mutex
|
// add without mutex
|
||||||
func (cl *batch) add(item item) *batch {
|
func (cl *batch[T]) add(item T) *batch[T] {
|
||||||
v := cl.items
|
v := cl.items
|
||||||
defer func() {
|
defer func() {
|
||||||
cl.items = v
|
cl.items = v
|
||||||
|
|||||||
74
chat/chat.go
74
chat/chat.go
@@ -1,78 +1,66 @@
|
|||||||
package chat
|
package chat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/fumiama/deepinfra"
|
"github.com/fumiama/deepinfra"
|
||||||
"github.com/fumiama/deepinfra/model"
|
"github.com/fumiama/deepinfra/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Log struct {
|
type Log[T fmt.Stringer] struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
cap int
|
batchcap, itemscap int
|
||||||
sep string
|
itemsep string
|
||||||
defaultprompt string
|
defaultprompt string
|
||||||
namel, namer string
|
m map[int64][]*batch[T]
|
||||||
atprefix string
|
|
||||||
m map[int64][]*batch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLog(cap int, sep, defaultprompt, namel, namer, atprefix string) Log {
|
func NewLog[T fmt.Stringer](batchcap, itemscap int, itemsep, defaultprompt string) Log[T] {
|
||||||
if cap < 2 {
|
if batchcap < 2 {
|
||||||
panic("cap cannot < 2")
|
panic("batchcap cannot < 2")
|
||||||
}
|
}
|
||||||
if cap%2 != 0 {
|
if batchcap%2 != 0 {
|
||||||
panic("cap % 2 must be 0")
|
panic("batchcap % 2 must be 0")
|
||||||
}
|
}
|
||||||
return Log{
|
if itemscap < 1 {
|
||||||
cap: cap,
|
panic("itemscap cannot < 1")
|
||||||
sep: sep,
|
}
|
||||||
|
return Log[T]{
|
||||||
|
batchcap: batchcap,
|
||||||
|
itemscap: itemscap,
|
||||||
|
itemsep: itemsep,
|
||||||
defaultprompt: defaultprompt,
|
defaultprompt: defaultprompt,
|
||||||
namel: namel,
|
m: make(map[int64][]*batch[T], 64),
|
||||||
namer: namer,
|
|
||||||
atprefix: atprefix,
|
|
||||||
m: make(map[int64][]*batch, 64),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Log) Add(grp int64, usr, txt string, isbot, isatme bool) {
|
func (l *Log[T]) Add(grp int64, item T, isbot bool) {
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
defer l.mu.Unlock()
|
defer l.mu.Unlock()
|
||||||
msgs, ok := l.m[grp]
|
msgs, ok := l.m[grp]
|
||||||
if !ok {
|
if !ok {
|
||||||
msgs = make([]*batch, 1, l.cap)
|
msgs = make([]*batch[T], 1, l.batchcap)
|
||||||
msgs[0] = l.newbatch(l.cap).add(item{
|
msgs[0] = l.newbatch(l.itemscap).add(item)
|
||||||
isatme: isatme,
|
|
||||||
usr: usr, txt: txt,
|
|
||||||
})
|
|
||||||
l.m[grp] = msgs
|
l.m[grp] = msgs
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
isprevusr := len(msgs)%2 != 0
|
isprevusr := len(msgs)%2 != 0
|
||||||
if (isprevusr && !isbot) || (!isprevusr && isbot) { // is same
|
if (isprevusr && !isbot) || (!isprevusr && isbot) { // is same
|
||||||
_ = msgs[len(msgs)-1].add(item{
|
_ = msgs[len(msgs)-1].add(item)
|
||||||
isatme: isatme,
|
|
||||||
usr: usr, txt: txt,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(msgs) < cap(msgs) {
|
if len(msgs) < cap(msgs) {
|
||||||
msgs = append(msgs, l.newbatch(l.cap).add(item{
|
msgs = append(msgs, l.newbatch(l.itemscap).add(item))
|
||||||
isatme: isatme,
|
|
||||||
usr: usr, txt: txt,
|
|
||||||
}))
|
|
||||||
l.m[grp] = msgs
|
l.m[grp] = msgs
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
copy(msgs, msgs[2:])
|
copy(msgs, msgs[2:])
|
||||||
msgs[len(msgs)-2] = l.newbatch(l.cap).add(item{
|
msgs[len(msgs)-2] = l.newbatch(l.itemscap).add(item)
|
||||||
isatme: isatme,
|
|
||||||
usr: usr, txt: txt,
|
|
||||||
})
|
|
||||||
l.m[grp] = msgs[:len(msgs)-1]
|
l.m[grp] = msgs[:len(msgs)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Log) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model {
|
func (l *Log[T]) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model {
|
||||||
m := p
|
m := p
|
||||||
if sysp != "" && !isusersystem {
|
if sysp != "" && !isusersystem {
|
||||||
m.System(sysp)
|
m.System(sysp)
|
||||||
@@ -98,14 +86,14 @@ func (l *Log) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Modelize into any type from index and message
|
// Modelize into any type from index and message
|
||||||
func Modelize[T any](l *Log, grp int64, f func(int, string) T) []T {
|
func Modelize[X any, T fmt.Stringer](l *Log[T], grp int64, f func(int, string) X) []X {
|
||||||
l.mu.RLock()
|
l.mu.RLock()
|
||||||
defer l.mu.RUnlock()
|
defer l.mu.RUnlock()
|
||||||
sz := len(l.m[grp])
|
sz := len(l.m[grp])
|
||||||
if sz == 0 {
|
if sz == 0 {
|
||||||
return []T{f(0, l.defaultprompt)}
|
return []X{f(0, l.defaultprompt)}
|
||||||
}
|
}
|
||||||
t := make([]T, sz)
|
t := make([]X, sz)
|
||||||
for i, msg := range l.m[grp] {
|
for i, msg := range l.m[grp] {
|
||||||
t[i] = f(i, msg.String())
|
t[i] = f(i, msg.String())
|
||||||
}
|
}
|
||||||
@@ -113,14 +101,14 @@ func Modelize[T any](l *Log, grp int64, f func(int, string) T) []T {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset clears all conversation logs while preserving configuration
|
// Reset clears all conversation logs while preserving configuration
|
||||||
func (l *Log) Reset() {
|
func (l *Log[T]) Reset() {
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
defer l.mu.Unlock()
|
defer l.mu.Unlock()
|
||||||
l.m = make(map[int64][]*batch, 64)
|
l.m = make(map[int64][]*batch[T], 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetIn removes specified groups from the conversation logs
|
// ResetIn removes specified groups from the conversation logs
|
||||||
func (l *Log) ResetIn(grps ...int64) {
|
func (l *Log[T]) ResetIn(grps ...int64) {
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
defer l.mu.Unlock()
|
defer l.mu.Unlock()
|
||||||
for _, grp := range grps {
|
for _, grp := range grps {
|
||||||
|
|||||||
20
chat/item.go
20
chat/item.go
@@ -1,20 +0,0 @@
|
|||||||
package chat
|
|
||||||
|
|
||||||
import "strings"
|
|
||||||
|
|
||||||
type item struct {
|
|
||||||
isatme bool
|
|
||||||
usr, txt string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (item *item) writeToBuilder(sb *strings.Builder, atprefix, namel, namer string) {
|
|
||||||
if item.isatme {
|
|
||||||
sb.WriteString(atprefix)
|
|
||||||
}
|
|
||||||
if item.usr != "" {
|
|
||||||
sb.WriteString(namel)
|
|
||||||
sb.WriteString(item.usr)
|
|
||||||
sb.WriteString(namer)
|
|
||||||
}
|
|
||||||
sb.WriteString(item.txt)
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user