mirror of
https://github.com/fumiama/deepinfra.git
synced 2026-06-05 00:32:46 +08:00
feat: add chat
This commit is contained in:
46
chat/batch.go
Normal file
46
chat/batch.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package chat
|
||||
|
||||
import "strings"
|
||||
|
||||
type batch struct {
|
||||
lst *Log
|
||||
items []item
|
||||
}
|
||||
|
||||
func (l *Log) newbatch(sz int) *batch {
|
||||
if sz == 0 {
|
||||
panic("sz cannot be 0")
|
||||
}
|
||||
return &batch{
|
||||
lst: l,
|
||||
items: make([]item, 0, sz),
|
||||
}
|
||||
}
|
||||
|
||||
func (cl *batch) String() string {
|
||||
sb := strings.Builder{}
|
||||
for _, item := range cl.items {
|
||||
sb.WriteString(cl.lst.sep)
|
||||
item.writeToBuilder(&sb, cl.lst.atprefix, cl.lst.namel, cl.lst.namer)
|
||||
}
|
||||
return sb.String()[len(cl.lst.sep):]
|
||||
}
|
||||
|
||||
// add without mutex
|
||||
func (cl *batch) add(item item) *batch {
|
||||
v := cl.items
|
||||
defer func() {
|
||||
cl.items = v
|
||||
}()
|
||||
if cap(v) == 1 {
|
||||
v[0] = item
|
||||
return cl
|
||||
}
|
||||
if len(v) < cap(v) {
|
||||
v = append(v, item)
|
||||
return cl
|
||||
}
|
||||
copy(v, v[1:])
|
||||
v[len(v)-1] = item
|
||||
return cl
|
||||
}
|
||||
91
chat/chat.go
Normal file
91
chat/chat.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package chat
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/fumiama/deepinfra"
|
||||
"github.com/fumiama/deepinfra/model"
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
mu sync.RWMutex
|
||||
cap int
|
||||
sep string
|
||||
defaultprompt string
|
||||
namel, namer string
|
||||
atprefix string
|
||||
m map[int64][]*batch
|
||||
}
|
||||
|
||||
func NewLog(cap int, sep, defaultprompt, namel, namer, atprefix string) Log {
|
||||
if cap < 2 {
|
||||
panic("cap cannot < 2")
|
||||
}
|
||||
if cap%2 != 0 {
|
||||
panic("cap % 2 must be 0")
|
||||
}
|
||||
return Log{
|
||||
cap: cap,
|
||||
sep: sep,
|
||||
defaultprompt: defaultprompt,
|
||||
namel: namel,
|
||||
namer: namer,
|
||||
atprefix: atprefix,
|
||||
m: make(map[int64][]*batch, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Log) Add(grp int64, usr, txt string, isbot, isatme bool) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
msgs, ok := l.m[grp]
|
||||
if !ok {
|
||||
msgs = make([]*batch, 1, l.cap)
|
||||
msgs[0] = l.newbatch(l.cap).add(item{
|
||||
isatme: isatme,
|
||||
usr: usr, txt: txt,
|
||||
})
|
||||
l.m[grp] = msgs
|
||||
return
|
||||
}
|
||||
isprevusr := len(msgs)%2 != 0
|
||||
if (isprevusr && !isbot) || (!isprevusr && isbot) { // is same
|
||||
_ = msgs[len(msgs)-1].add(item{
|
||||
isatme: isatme,
|
||||
usr: usr, txt: txt,
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(msgs) < cap(msgs) {
|
||||
msgs = append(msgs, l.newbatch(l.cap).add(item{
|
||||
isatme: isatme,
|
||||
usr: usr, txt: txt,
|
||||
}))
|
||||
l.m[grp] = msgs
|
||||
return
|
||||
}
|
||||
copy(msgs, msgs[2:])
|
||||
msgs[len(msgs)-2] = l.newbatch(l.cap).add(item{
|
||||
isatme: isatme,
|
||||
usr: usr, txt: txt,
|
||||
})
|
||||
l.m[grp] = msgs[:len(msgs)-1]
|
||||
}
|
||||
|
||||
func (l *Log) Modelize(p model.Protocol, grp int64, sysp string) deepinfra.Model {
|
||||
m := p.System(sysp)
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
sz := len(l.m[grp])
|
||||
if sz == 0 {
|
||||
return m.User(l.defaultprompt)
|
||||
}
|
||||
for i, msg := range l.m[grp] {
|
||||
if i%2 == 0 { // is user
|
||||
_ = m.User(msg.String())
|
||||
} else {
|
||||
_ = m.Assistant(msg.String())
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
18
chat/item.go
Normal file
18
chat/item.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package chat
|
||||
|
||||
import "strings"
|
||||
|
||||
type item struct {
|
||||
isatme bool
|
||||
usr, txt string
|
||||
}
|
||||
|
||||
func (item *item) writeToBuilder(sb *strings.Builder, atprefix, namel, namer string) {
|
||||
if item.isatme {
|
||||
sb.WriteString(atprefix)
|
||||
}
|
||||
sb.WriteString(namel)
|
||||
sb.WriteString(item.usr)
|
||||
sb.WriteString(namer)
|
||||
sb.WriteString(item.txt)
|
||||
}
|
||||
Reference in New Issue
Block a user