From 2d458b6be9c1dd4fbaee4afbab838919ae640a6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 22 Feb 2025 14:46:52 +0900 Subject: [PATCH] feat: add chat --- chat/batch.go | 46 ++++++++++++++++++++++++++ chat/chat.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++ chat/item.go | 18 ++++++++++ 3 files changed, 155 insertions(+) create mode 100644 chat/batch.go create mode 100644 chat/chat.go create mode 100644 chat/item.go diff --git a/chat/batch.go b/chat/batch.go new file mode 100644 index 0000000..c1ec1ee --- /dev/null +++ b/chat/batch.go @@ -0,0 +1,46 @@ +package chat + +import "strings" + +type batch struct { + lst *Log + items []item +} + +func (l *Log) newbatch(sz int) *batch { + if sz == 0 { + panic("sz cannot be 0") + } + return &batch{ + lst: l, + items: make([]item, 0, sz), + } +} + +func (cl *batch) String() string { + sb := strings.Builder{} + for _, item := range cl.items { + sb.WriteString(cl.lst.sep) + item.writeToBuilder(&sb, cl.lst.atprefix, cl.lst.namel, cl.lst.namer) + } + return sb.String()[len(cl.lst.sep):] +} + +// add without mutex +func (cl *batch) add(item item) *batch { + v := cl.items + defer func() { + cl.items = v + }() + if cap(v) == 1 { + v[0] = item + return cl + } + if len(v) < cap(v) { + v = append(v, item) + return cl + } + copy(v, v[1:]) + v[len(v)-1] = item + return cl +} diff --git a/chat/chat.go b/chat/chat.go new file mode 100644 index 0000000..c6785ad --- /dev/null +++ b/chat/chat.go @@ -0,0 +1,91 @@ +package chat + +import ( + "sync" + + "github.com/fumiama/deepinfra" + "github.com/fumiama/deepinfra/model" +) + +type Log struct { + mu sync.RWMutex + cap int + sep string + defaultprompt string + namel, namer string + atprefix string + m map[int64][]*batch +} + +func NewLog(cap int, sep, defaultprompt, namel, namer, atprefix string) Log { + if cap < 2 { + panic("cap cannot < 2") + } + if cap%2 != 0 { + panic("cap % 2 must be 0") + } + return Log{ + cap: cap, + sep: sep, + defaultprompt: defaultprompt, + namel: namel, + namer: namer, + atprefix: atprefix, + m: make(map[int64][]*batch, 64), + } +} + +func (l *Log) Add(grp int64, usr, txt string, isbot, isatme bool) { + l.mu.Lock() + defer l.mu.Unlock() + msgs, ok := l.m[grp] + if !ok { + msgs = make([]*batch, 1, l.cap) + msgs[0] = l.newbatch(l.cap).add(item{ + isatme: isatme, + usr: usr, txt: txt, + }) + l.m[grp] = msgs + return + } + isprevusr := len(msgs)%2 != 0 + if (isprevusr && !isbot) || (!isprevusr && isbot) { // is same + _ = msgs[len(msgs)-1].add(item{ + isatme: isatme, + usr: usr, txt: txt, + }) + return + } + if len(msgs) < cap(msgs) { + msgs = append(msgs, l.newbatch(l.cap).add(item{ + isatme: isatme, + usr: usr, txt: txt, + })) + l.m[grp] = msgs + return + } + copy(msgs, msgs[2:]) + msgs[len(msgs)-2] = l.newbatch(l.cap).add(item{ + isatme: isatme, + usr: usr, txt: txt, + }) + l.m[grp] = msgs[:len(msgs)-1] +} + +func (l *Log) Modelize(p model.Protocol, grp int64, sysp string) deepinfra.Model { + m := p.System(sysp) + l.mu.RLock() + defer l.mu.RUnlock() + sz := len(l.m[grp]) + if sz == 0 { + return m.User(l.defaultprompt) + } + for i, msg := range l.m[grp] { + if i%2 == 0 { // is user + _ = m.User(msg.String()) + } else { + _ = m.Assistant(msg.String()) + } + } + return m +} diff --git a/chat/item.go b/chat/item.go new file mode 100644 index 0000000..34e1692 --- /dev/null +++ b/chat/item.go @@ -0,0 +1,18 @@ +package chat + +import "strings" + +type item struct { + isatme bool + usr, txt string +} + +func (item *item) writeToBuilder(sb *strings.Builder, atprefix, namel, namer string) { + if item.isatme { + sb.WriteString(atprefix) + } + sb.WriteString(namel) + sb.WriteString(item.usr) + sb.WriteString(namer) + sb.WriteString(item.txt) +}