diff --git a/api.go b/api.go
index 62290ea..cf12ef0 100644
--- a/api.go
+++ b/api.go
@@ -52,5 +52,5 @@ func (api *API) Request(model Model) (string, error) {
if err != nil {
return "", err
}
- return model.Output(), nil
+ return model.Output().String(), nil
}
diff --git a/chat/chat.go b/chat/chat.go
index 5377f41..1280679 100644
--- a/chat/chat.go
+++ b/chat/chat.go
@@ -63,24 +63,24 @@ func (l *Log[T]) Add(grp int64, item T, isbot bool) {
func (l *Log[T]) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model {
m := p
if sysp != "" && !isusersystem {
- m.System(sysp)
+ m.System(model.NewContentText(sysp))
}
l.mu.RLock()
defer l.mu.RUnlock()
sz := len(l.m[grp])
if sz == 0 {
- return m.User(l.defaultprompt)
+ return m.User(model.NewContentText(l.defaultprompt))
}
for i, msg := range l.m[grp] {
if i%2 == 0 { // is user
if i == 0 && isusersystem {
- _ = m.User(sysp + "\n\n" + msg.String())
+ _ = m.User(model.NewContentText(sysp + "\n\n" + msg.String()))
continue
}
- _ = m.User(msg.String())
+ _ = m.User(model.NewContentText(msg.String()))
continue
}
- _ = m.Assistant(msg.String())
+ _ = m.Assistant(model.NewContentText(msg.String()))
}
return m
}
diff --git a/go.mod b/go.mod
index fe83726..ed7e172 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,21 @@
module github.com/fumiama/deepinfra
go 1.20
+
+require (
+ github.com/FloatTech/gg v1.1.3
+ github.com/FloatTech/imgfactory v0.2.1
+ github.com/fumiama/imgsz v0.0.4
+)
+
+require (
+ github.com/disintegration/imaging v1.6.2 // indirect
+ github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4 // indirect
+ github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
+ golang.org/x/text v0.15.0 // indirect
+)
+
+require (
+ github.com/pkg/errors v0.9.1
+ golang.org/x/image v0.16.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..370c3aa
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,20 @@
+github.com/FloatTech/gg v1.1.3 h1:+GlL02lTKsxJQr4WCuNwVxC1/eBZrCvypCIBtxuOFb4=
+github.com/FloatTech/gg v1.1.3/go.mod h1:/9oLP54CMfq4r+71XL26uaFTJ1uL1boAyX67680/1HE=
+github.com/FloatTech/imgfactory v0.2.1 h1:XoVwy0Xu0AvTRtzlhv5teZcuZlAcHrYjeQ8MynJ/zlk=
+github.com/FloatTech/imgfactory v0.2.1/go.mod h1:QBJKHbzpE+x/9Wn7mXebWap/K/xUJSjgiaelAElwU9Q=
+github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
+github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
+github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4 h1:BBade+JlV/f7JstZ4pitd4tHhpN+w+6I+LyOS7B4fyU=
+github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4/go.mod h1:H7chHJglrhPPzetLdzBleF8d22WYOv7UM/lEKYiwlKM=
+github.com/fumiama/imgsz v0.0.4 h1:Lsasu2hdSSFS+vnD+nvR1UkiRMK7hcpyYCC0FzgSMFI=
+github.com/fumiama/imgsz v0.0.4/go.mod h1:bISOQVTlw9sRytPwe8ir7tAaEmyz9hSNj9n8mXMBG0E=
+github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
+github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
+golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
+golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
diff --git a/model/api.go b/model/api.go
index a73af94..d784e8f 100644
--- a/model/api.go
+++ b/model/api.go
@@ -12,8 +12,8 @@ type Inputer interface {
}
type Outputer interface {
- Output() string
- OutputRaw() string
+ Output() Contents
+ OutputRaw() Contents
}
type Requester interface {
@@ -22,9 +22,9 @@ type Requester interface {
}
type MessageBuilder[T any] interface {
- System(prompt string) T
- User(prompt string) T
- Assistant(prompt string) T
+ System(prompt ...Content) T
+ User(prompt ...Content) T
+ Assistant(prompt ...Content) T
}
type Protocol interface {
diff --git a/model/content.go b/model/content.go
new file mode 100644
index 0000000..a1d2485
--- /dev/null
+++ b/model/content.go
@@ -0,0 +1,93 @@
+package model
+
+import (
+ "bytes"
+ "encoding/base64"
+ "io"
+ "mime"
+ "strings"
+
+ "github.com/fumiama/imgsz"
+ "github.com/pkg/errors"
+)
+
+var (
+ ErrUnsupportedImageExtension = errors.New("unsupported image extension")
+)
+
+type ContentType string
+
+const (
+ ContentTypeText ContentType = "text"
+ ContentTypeImageURL ContentType = "image_url"
+)
+
+type Contents []Content
+
+func (cs Contents) String() string {
+ switch len(cs) {
+ case 0:
+ return ""
+ case 1:
+ if cs[0].Type == ContentTypeText {
+ return cs[0].Text
+ }
+ fallthrough
+ default:
+ sb := strings.Builder{}
+ for _, c := range cs {
+ switch c.Type {
+ case ContentTypeText:
+ sb.WriteString("")
+ sb.WriteString(c.Text)
+ sb.WriteString("")
+ case ContentTypeImageURL:
+ if c.ImageURL != nil {
+ sb.WriteString("")
+ sb.WriteString(c.ImageURL.URL)
+ sb.WriteString("")
+ }
+ default:
+ panic("unsupported ContentType " + c.Type)
+ }
+ }
+ return sb.String()
+ }
+}
+
+type Content struct {
+ Type ContentType `json:"type"`
+ Text string `json:"text,omitempty"`
+ ImageURL *ContentImageURL `json:"image_url,omitempty"`
+}
+
+func NewContentText(txt string) Content {
+ return Content{Type: ContentTypeText, Text: txt}
+}
+
+type ContentImageURL struct {
+ URL string `json:"url"`
+}
+
+func NewContentImageURL(u string) Content {
+ return Content{Type: ContentTypeImageURL, ImageURL: &ContentImageURL{URL: u}}
+}
+
+func NewContentImageDataBase64URL(data []byte) (string, error) {
+ _, ext, err := imgsz.DecodeSize(bytes.NewReader(data))
+ if err != nil {
+ return "", err
+ }
+ typ := mime.TypeByExtension("." + ext)
+ if len(typ) == 0 {
+ return "", errors.Wrap(ErrUnsupportedImageExtension, ext)
+ }
+ sb := strings.Builder{}
+ sb.WriteString("data:")
+ sb.WriteString(typ)
+ sb.WriteString(";base64,")
+ enc := base64.NewEncoder(base64.StdEncoding, &sb)
+ _, _ = io.Copy(enc, bytes.NewReader(data))
+ _ = enc.Close()
+ return sb.String(), nil
+}
diff --git a/model/content_test.go b/model/content_test.go
new file mode 100644
index 0000000..10a278b
--- /dev/null
+++ b/model/content_test.go
@@ -0,0 +1,33 @@
+package model
+
+import (
+ "encoding/base64"
+ "strings"
+ "testing"
+)
+
+const contentTextSmallImage = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
+
+func TestContentImageDataBytes(t *testing.T) {
+ // 从 smallimg 中提取 base64 数据
+ parts := strings.Split(contentTextSmallImage, ",")
+ if len(parts) != 2 {
+ t.Fatal("Invalid data URL format")
+ }
+
+ // 解码 base64 数据
+ data, err := base64.StdEncoding.DecodeString(parts[1])
+ if err != nil {
+ t.Fatalf("Failed to decode base64: %v", err)
+ }
+
+ s, err := NewContentImageDataBase64URL(data)
+ if err != nil {
+ t.Fatalf("NewContentImageDataBytes failed: %v", err)
+ }
+
+ // 比较结果
+ if s != contentTextSmallImage {
+ t.Errorf("Expected %s, got %s", contentTextSmallImage, s)
+ }
+}
diff --git a/model/genai.go b/model/genai.go
index b2b66b9..cc26e4a 100644
--- a/model/genai.go
+++ b/model/genai.go
@@ -13,16 +13,22 @@ const (
ModelGemini15Flash = "models/gemini-1.5-flash"
)
-type Text struct {
- Text string `json:"text"`
+type GenAIInlineData struct {
+ MimeType string `json:"mime_type"`
+ Data string `json:"data"` // Data is base64 repr
}
-type Content struct {
- Parts []Text `json:"parts"`
- Role string `json:"role,omitempty"`
+type GenAIPart struct {
+ Text string `json:"text,omitempty"`
+ InlineData *GenAIInlineData `json:"inline_data,omitempty"`
}
-func (c *Content) String() string {
+type GenAIContent struct {
+ Parts []GenAIPart `json:"parts"`
+ Role string `json:"role,omitempty"`
+}
+
+func (c *GenAIContent) String() string {
sb := strings.Builder{}
for _, p := range c.Parts {
sb.WriteString(p.Text)
@@ -30,10 +36,10 @@ func (c *Content) String() string {
return sb.String()
}
-type Candidate struct {
- Content Content `json:"content"`
- FinishReason string `json:"finishReason"`
- Index int `json:"index"`
+type GenAICandidate struct {
+ Content GenAIContent `json:"content"`
+ FinishReason string `json:"finishReason"`
+ Index int `json:"index"`
}
// GenAI is Goole API format
@@ -41,8 +47,8 @@ type GenAI struct {
model string `json:"-"`
Protocol `json:"-"`
// request only
- Contents []Content `json:"contents,omitempty"`
- SystemInstruction *Content `json:"systemInstruction,omitempty"`
+ Contents []GenAIContent `json:"contents,omitempty"`
+ SystemInstruction *GenAIContent `json:"systemInstruction,omitempty"`
GenerationConfig struct {
Temperature float32 `json:"temperature,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
@@ -50,7 +56,7 @@ type GenAI struct {
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
} `json:"generationConfig"`
// callback only
- Candidates []Candidate `json:"candidates,omitempty"`
+ Candidates []GenAICandidate `json:"candidates,omitempty"`
}
// NewGenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
@@ -85,35 +91,99 @@ func (opai *GenAI) Parse(body io.Reader) error {
return json.NewDecoder(body).Decode(&opai)
}
-func (opai *GenAI) Output() string {
+func (opai *GenAI) Output() Contents {
+ return opai.OutputRaw()
+}
+
+func (opai *GenAI) OutputRaw() Contents {
if len(opai.Candidates) == 0 {
- return ""
+ return nil
}
- return opai.Candidates[0].Content.String()
+ raw := opai.Candidates[0].Content
+ cs := make(Contents, len(raw.Parts))
+ for i, c := range raw.Parts {
+ switch {
+ case c.Text != "":
+ cs[i].Type = ContentTypeText
+ cs[i].Text = c.Text
+ case c.InlineData != nil:
+ cs[i].Type = ContentTypeImageURL
+ if strings.HasPrefix(c.InlineData.MimeType, "image/") {
+ cs[i].ImageURL = &ContentImageURL{
+ URL: "data:" + c.InlineData.MimeType + ";base64," + c.InlineData.Data,
+ }
+ }
+ default:
+ panic("unsupported genai part")
+ }
+ }
+ return cs
}
-func (opai *GenAI) OutputRaw() string {
- return opai.Output()
+func (cs Contents) ToGenAIParts() []GenAIPart {
+ ps := make([]GenAIPart, 0, len(cs))
+ for _, c := range cs {
+ switch c.Type {
+ case ContentTypeText:
+ ps = append(ps, GenAIPart{Text: c.Text})
+ case ContentTypeImageURL:
+ if strings.HasPrefix(c.ImageURL.URL, "data:") {
+ typ, dat, ok := strings.Cut(strings.TrimPrefix(c.ImageURL.URL, "data:"), ";")
+ if !ok {
+ continue
+ }
+ ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
+ MimeType: typ,
+ Data: dat[1:], // skip ;
+ }})
+ continue
+ }
+ resp, err := http.Get(c.ImageURL.URL)
+ if err != nil {
+ continue
+ }
+ defer resp.Body.Close()
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ continue
+ }
+ s, err := NewContentImageDataBase64URL(data)
+ if err != nil {
+ continue
+ }
+ typ, dat, ok := strings.Cut(strings.TrimPrefix(s, "data:"), ",")
+ if !ok {
+ continue
+ }
+ ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
+ MimeType: typ,
+ Data: dat[1:], // skip ,
+ }})
+ default:
+ panic("unsupported ContentType " + c.Type)
+ }
+ }
+ return ps
}
-func (opai *GenAI) System(prompt string) Protocol {
- opai.SystemInstruction = &Content{
- Parts: []Text{{prompt}},
+func (opai *GenAI) System(prompt ...Content) Protocol {
+ opai.SystemInstruction = &GenAIContent{
+ Parts: Contents(prompt).ToGenAIParts(),
}
return opai
}
-func (opai *GenAI) User(prompt string) Protocol {
- opai.Contents = append(opai.Contents, Content{
- Parts: []Text{{prompt}},
+func (opai *GenAI) User(prompt ...Content) Protocol {
+ opai.Contents = append(opai.Contents, GenAIContent{
+ Parts: Contents(prompt).ToGenAIParts(),
Role: "user",
})
return opai
}
-func (opai *GenAI) Assistant(prompt string) Protocol {
- opai.Contents = append(opai.Contents, Content{
- Parts: []Text{{prompt}},
+func (opai *GenAI) Assistant(prompt ...Content) Protocol {
+ opai.Contents = append(opai.Contents, GenAIContent{
+ Parts: Contents(prompt).ToGenAIParts(),
Role: "model",
})
return opai
diff --git a/model/ollama.go b/model/ollama.go
index a45779b..7eadb4b 100644
--- a/model/ollama.go
+++ b/model/ollama.go
@@ -7,22 +7,27 @@ import (
"net/http"
)
+type OLLaMAMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
// OLLaMA as an specified example.
type OLLaMA struct {
sep string
Protocol `json:"-"`
// callback only
- ID string `json:"id,omitempty"`
- Object string `json:"object,omitempty"`
- Created int `json:"created,omitempty"`
- Messages []Message `json:"messages"`
+ ID string `json:"id,omitempty"`
+ Object string `json:"object,omitempty"`
+ Created int `json:"created,omitempty"`
+ Messages []OLLaMAMessage `json:"messages"`
// callback/request
- Model string `json:"model"`
- Message *Message `json:"message,omitempty"`
- Temperature float32 `json:"temperature"` // Temperature 0.7
- TopP float32 `json:"top_p"` // TopP 0.9
- MaxTokens int `json:"max_tokens"` // MaxTokens 4096
- Stream bool `json:"stream"`
+ Model string `json:"model"`
+ Message *OLLaMAMessage `json:"message,omitempty"`
+ Temperature float32 `json:"temperature"` // Temperature 0.7
+ TopP float32 `json:"top_p"` // TopP 0.9
+ MaxTokens int `json:"max_tokens"` // MaxTokens 4096
+ Stream bool `json:"stream"`
}
// NewOLLaMA use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
@@ -58,41 +63,41 @@ func (ollm *OLLaMA) Parse(body io.Reader) error {
return json.NewDecoder(body).Decode(&ollm)
}
-func (ollm *OLLaMA) Output() string {
+func (ollm *OLLaMA) Output() Contents {
if ollm.Message == nil {
- return ""
+ return nil
}
- return CutLast(ollm.Message.Content, ollm.sep)
+ return Contents{NewContentText(CutLast(ollm.Message.Content, ollm.sep))}
}
-func (ollm *OLLaMA) OutputRaw() string {
+func (ollm *OLLaMA) OutputRaw() Contents {
if ollm.Message == nil {
- return ""
+ return nil
}
- return ollm.Message.Content
+ return Contents{NewContentText(ollm.Message.Content)}
}
-func (ollm *OLLaMA) System(prompt string) Protocol {
- ollm.Messages = make([]Message, 1, 8)
- ollm.Messages[0] = Message{
+func (ollm *OLLaMA) System(prompt ...Content) Protocol {
+ ollm.Messages = make([]OLLaMAMessage, 1, 8)
+ ollm.Messages[0] = OLLaMAMessage{
Role: "system",
- Content: prompt,
+ Content: prompt[0].Text,
}
return ollm
}
-func (ollm *OLLaMA) User(prompt string) Protocol {
- ollm.Messages = append(ollm.Messages, Message{
+func (ollm *OLLaMA) User(prompt ...Content) Protocol {
+ ollm.Messages = append(ollm.Messages, OLLaMAMessage{
Role: "user",
- Content: prompt,
+ Content: prompt[0].Text,
})
return ollm
}
-func (ollm *OLLaMA) Assistant(prompt string) Protocol {
- ollm.Messages = append(ollm.Messages, Message{
+func (ollm *OLLaMA) Assistant(prompt ...Content) Protocol {
+ ollm.Messages = append(ollm.Messages, OLLaMAMessage{
Role: "assistant",
- Content: prompt,
+ Content: prompt[0].Text,
})
return ollm
}
diff --git a/model/openai.go b/model/openai.go
index e63c944..d94d99c 100644
--- a/model/openai.go
+++ b/model/openai.go
@@ -8,18 +8,18 @@ import (
)
const (
- ModelDeepDeek = "deepseek-ai/DeepSeek-R1"
+ ModelDeepDeek = "deepseek-ai/DeepSeek-V3.1-Terminus"
)
-type Message struct {
- Role string `json:"role"`
- Content string `json:"content"`
+type OpenAIMessage struct {
+ Role string `json:"role"`
+ Content json.RawMessage `json:"content"` // Contents or string
}
-type Choice struct {
- Index int `json:"index"`
- Message Message `json:"message"`
- FinishReason string `json:"finish_reason"`
+type OpenAIChoice struct {
+ Index int `json:"index"`
+ Message OpenAIMessage `json:"message"`
+ FinishReason string `json:"finish_reason"`
}
// OpenAI as an specified example.
@@ -27,18 +27,18 @@ type OpenAI struct {
sep string
Protocol `json:"-"`
// callback only
- ID string `json:"id,omitempty"`
- Object string `json:"object,omitempty"`
- Created int `json:"created,omitempty"`
- Choices []Choice `json:"choices,omitempty"`
+ ID string `json:"id,omitempty"`
+ Object string `json:"object,omitempty"`
+ Created int `json:"created,omitempty"`
+ Choices []OpenAIChoice `json:"choices,omitempty"`
// callback/request
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Temperature float32 `json:"temperature,omitempty"` // Temperature 0.7
- TopP float32 `json:"top_p,omitempty"` // TopP 0.9
- MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens 4096
+ Model string `json:"model"`
+ Messages []OpenAIMessage `json:"messages"`
+ Temperature float32 `json:"temperature,omitempty"` // Temperature 0.7
+ TopP float32 `json:"top_p,omitempty"` // TopP 0.9
+ MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens 4096
// extra body
- EnableThinking bool `json:"enable_thinking"`
+ EnableThinking bool `json:"enable_thinking"` // EnableThinking is always false in non-stream mode, adapt to 百炼
}
// NewOpenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
@@ -74,41 +74,80 @@ func (opai *OpenAI) Parse(body io.Reader) error {
return json.NewDecoder(body).Decode(&opai)
}
-func (opai *OpenAI) Output() string {
+func (opai *OpenAI) Output() Contents {
if len(opai.Choices) == 0 {
- return ""
+ return nil
}
- return CutLast(opai.Choices[len(opai.Choices)-1].Message.Content, opai.sep)
+ cs := make(Contents, 0, 8)
+ c := opai.Choices[len(opai.Choices)-1].Message.Content
+ err := json.Unmarshal(c, &cs)
+ if err == nil {
+ for i := range cs {
+ if cs[i].Type == ContentTypeText {
+ cs[i].Text = CutLast(cs[i].Text, opai.sep)
+ }
+ }
+ return cs
+ }
+ s := ""
+ err = json.Unmarshal(c, &s)
+ if err == nil {
+ return Contents{{Type: ContentTypeText, Text: CutLast(s, opai.sep)}}
+ }
+ return nil
}
-func (opai *OpenAI) OutputRaw() string {
+func (opai *OpenAI) OutputRaw() Contents {
if len(opai.Choices) == 0 {
- return ""
+ return nil
}
- return opai.Choices[len(opai.Choices)-1].Message.Content
+ cs := make(Contents, 0, 8)
+ c := opai.Choices[len(opai.Choices)-1].Message.Content
+ err := json.Unmarshal(c, &cs)
+ if err == nil {
+ return cs
+ }
+ s := ""
+ err = json.Unmarshal(c, &s)
+ if err == nil {
+ return Contents{{Type: ContentTypeText, Text: s}}
+ }
+ return nil
}
-func (opai *OpenAI) System(prompt string) Protocol {
- opai.Messages = make([]Message, 1, 8)
- opai.Messages[0] = Message{
+func (opai *OpenAI) System(prompt ...Content) Protocol {
+ opai.Messages = make([]OpenAIMessage, 1, 8)
+ raw, err := json.Marshal(&prompt)
+ if err != nil {
+ panic(err)
+ }
+ opai.Messages[0] = OpenAIMessage{
Role: "system",
- Content: prompt,
+ Content: raw,
}
return opai
}
-func (opai *OpenAI) User(prompt string) Protocol {
- opai.Messages = append(opai.Messages, Message{
+func (opai *OpenAI) User(prompt ...Content) Protocol {
+ raw, err := json.Marshal(&prompt)
+ if err != nil {
+ panic(err)
+ }
+ opai.Messages = append(opai.Messages, OpenAIMessage{
Role: "user",
- Content: prompt,
+ Content: raw,
})
return opai
}
-func (opai *OpenAI) Assistant(prompt string) Protocol {
- opai.Messages = append(opai.Messages, Message{
+func (opai *OpenAI) Assistant(prompt ...Content) Protocol {
+ raw, err := json.Marshal(&prompt)
+ if err != nil {
+ panic(err)
+ }
+ opai.Messages = append(opai.Messages, OpenAIMessage{
Role: "assistant",
- Content: prompt,
+ Content: raw,
})
return opai
}