1
0
mirror of https://github.com/fumiama/deepinfra.git synced 2026-06-05 00:32:46 +08:00

feat: add chat

This commit is contained in:
源文雨
2025-02-22 14:46:52 +09:00
parent 7c01751fe6
commit 2d458b6be9
3 changed files with 155 additions and 0 deletions

46
chat/batch.go Normal file
View File

@@ -0,0 +1,46 @@
package chat
import "strings"
type batch struct {
lst *Log
items []item
}
func (l *Log) newbatch(sz int) *batch {
if sz == 0 {
panic("sz cannot be 0")
}
return &batch{
lst: l,
items: make([]item, 0, sz),
}
}
func (cl *batch) String() string {
sb := strings.Builder{}
for _, item := range cl.items {
sb.WriteString(cl.lst.sep)
item.writeToBuilder(&sb, cl.lst.atprefix, cl.lst.namel, cl.lst.namer)
}
return sb.String()[len(cl.lst.sep):]
}
// add without mutex
func (cl *batch) add(item item) *batch {
v := cl.items
defer func() {
cl.items = v
}()
if cap(v) == 1 {
v[0] = item
return cl
}
if len(v) < cap(v) {
v = append(v, item)
return cl
}
copy(v, v[1:])
v[len(v)-1] = item
return cl
}

91
chat/chat.go Normal file
View File

@@ -0,0 +1,91 @@
package chat
import (
"sync"
"github.com/fumiama/deepinfra"
"github.com/fumiama/deepinfra/model"
)
type Log struct {
mu sync.RWMutex
cap int
sep string
defaultprompt string
namel, namer string
atprefix string
m map[int64][]*batch
}
func NewLog(cap int, sep, defaultprompt, namel, namer, atprefix string) Log {
if cap < 2 {
panic("cap cannot < 2")
}
if cap%2 != 0 {
panic("cap % 2 must be 0")
}
return Log{
cap: cap,
sep: sep,
defaultprompt: defaultprompt,
namel: namel,
namer: namer,
atprefix: atprefix,
m: make(map[int64][]*batch, 64),
}
}
func (l *Log) Add(grp int64, usr, txt string, isbot, isatme bool) {
l.mu.Lock()
defer l.mu.Unlock()
msgs, ok := l.m[grp]
if !ok {
msgs = make([]*batch, 1, l.cap)
msgs[0] = l.newbatch(l.cap).add(item{
isatme: isatme,
usr: usr, txt: txt,
})
l.m[grp] = msgs
return
}
isprevusr := len(msgs)%2 != 0
if (isprevusr && !isbot) || (!isprevusr && isbot) { // is same
_ = msgs[len(msgs)-1].add(item{
isatme: isatme,
usr: usr, txt: txt,
})
return
}
if len(msgs) < cap(msgs) {
msgs = append(msgs, l.newbatch(l.cap).add(item{
isatme: isatme,
usr: usr, txt: txt,
}))
l.m[grp] = msgs
return
}
copy(msgs, msgs[2:])
msgs[len(msgs)-2] = l.newbatch(l.cap).add(item{
isatme: isatme,
usr: usr, txt: txt,
})
l.m[grp] = msgs[:len(msgs)-1]
}
func (l *Log) Modelize(p model.Protocol, grp int64, sysp string) deepinfra.Model {
m := p.System(sysp)
l.mu.RLock()
defer l.mu.RUnlock()
sz := len(l.m[grp])
if sz == 0 {
return m.User(l.defaultprompt)
}
for i, msg := range l.m[grp] {
if i%2 == 0 { // is user
_ = m.User(msg.String())
} else {
_ = m.Assistant(msg.String())
}
}
return m
}

18
chat/item.go Normal file
View File

@@ -0,0 +1,18 @@
package chat
import "strings"
type item struct {
isatme bool
usr, txt string
}
func (item *item) writeToBuilder(sb *strings.Builder, atprefix, namel, namer string) {
if item.isatme {
sb.WriteString(atprefix)
}
sb.WriteString(namel)
sb.WriteString(item.usr)
sb.WriteString(namer)
sb.WriteString(item.txt)
}