mirror of
https://github.com/fumiama/deepinfra.git
synced 2026-06-05 00:32:46 +08:00
feat: support image upload
This commit is contained in:
2
api.go
2
api.go
@@ -52,5 +52,5 @@ func (api *API) Request(model Model) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return model.Output(), nil
|
||||
return model.Output().String(), nil
|
||||
}
|
||||
|
||||
10
chat/chat.go
10
chat/chat.go
@@ -63,24 +63,24 @@ func (l *Log[T]) Add(grp int64, item T, isbot bool) {
|
||||
func (l *Log[T]) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model {
|
||||
m := p
|
||||
if sysp != "" && !isusersystem {
|
||||
m.System(sysp)
|
||||
m.System(model.NewContentText(sysp))
|
||||
}
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
sz := len(l.m[grp])
|
||||
if sz == 0 {
|
||||
return m.User(l.defaultprompt)
|
||||
return m.User(model.NewContentText(l.defaultprompt))
|
||||
}
|
||||
for i, msg := range l.m[grp] {
|
||||
if i%2 == 0 { // is user
|
||||
if i == 0 && isusersystem {
|
||||
_ = m.User(sysp + "\n\n" + msg.String())
|
||||
_ = m.User(model.NewContentText(sysp + "\n\n" + msg.String()))
|
||||
continue
|
||||
}
|
||||
_ = m.User(msg.String())
|
||||
_ = m.User(model.NewContentText(msg.String()))
|
||||
continue
|
||||
}
|
||||
_ = m.Assistant(msg.String())
|
||||
_ = m.Assistant(model.NewContentText(msg.String()))
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
18
go.mod
18
go.mod
@@ -1,3 +1,21 @@
|
||||
module github.com/fumiama/deepinfra
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/FloatTech/gg v1.1.3
|
||||
github.com/FloatTech/imgfactory v0.2.1
|
||||
github.com/fumiama/imgsz v0.0.4
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/disintegration/imaging v1.6.2 // indirect
|
||||
github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4 // indirect
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/pkg/errors v0.9.1
|
||||
golang.org/x/image v0.16.0 // indirect
|
||||
)
|
||||
|
||||
20
go.sum
Normal file
20
go.sum
Normal file
@@ -0,0 +1,20 @@
|
||||
github.com/FloatTech/gg v1.1.3 h1:+GlL02lTKsxJQr4WCuNwVxC1/eBZrCvypCIBtxuOFb4=
|
||||
github.com/FloatTech/gg v1.1.3/go.mod h1:/9oLP54CMfq4r+71XL26uaFTJ1uL1boAyX67680/1HE=
|
||||
github.com/FloatTech/imgfactory v0.2.1 h1:XoVwy0Xu0AvTRtzlhv5teZcuZlAcHrYjeQ8MynJ/zlk=
|
||||
github.com/FloatTech/imgfactory v0.2.1/go.mod h1:QBJKHbzpE+x/9Wn7mXebWap/K/xUJSjgiaelAElwU9Q=
|
||||
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
|
||||
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
|
||||
github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4 h1:BBade+JlV/f7JstZ4pitd4tHhpN+w+6I+LyOS7B4fyU=
|
||||
github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4/go.mod h1:H7chHJglrhPPzetLdzBleF8d22WYOv7UM/lEKYiwlKM=
|
||||
github.com/fumiama/imgsz v0.0.4 h1:Lsasu2hdSSFS+vnD+nvR1UkiRMK7hcpyYCC0FzgSMFI=
|
||||
github.com/fumiama/imgsz v0.0.4/go.mod h1:bISOQVTlw9sRytPwe8ir7tAaEmyz9hSNj9n8mXMBG0E=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
|
||||
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
10
model/api.go
10
model/api.go
@@ -12,8 +12,8 @@ type Inputer interface {
|
||||
}
|
||||
|
||||
type Outputer interface {
|
||||
Output() string
|
||||
OutputRaw() string
|
||||
Output() Contents
|
||||
OutputRaw() Contents
|
||||
}
|
||||
|
||||
type Requester interface {
|
||||
@@ -22,9 +22,9 @@ type Requester interface {
|
||||
}
|
||||
|
||||
type MessageBuilder[T any] interface {
|
||||
System(prompt string) T
|
||||
User(prompt string) T
|
||||
Assistant(prompt string) T
|
||||
System(prompt ...Content) T
|
||||
User(prompt ...Content) T
|
||||
Assistant(prompt ...Content) T
|
||||
}
|
||||
|
||||
type Protocol interface {
|
||||
|
||||
93
model/content.go
Normal file
93
model/content.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"mime"
|
||||
"strings"
|
||||
|
||||
"github.com/fumiama/imgsz"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedImageExtension = errors.New("unsupported image extension")
|
||||
)
|
||||
|
||||
type ContentType string
|
||||
|
||||
const (
|
||||
ContentTypeText ContentType = "text"
|
||||
ContentTypeImageURL ContentType = "image_url"
|
||||
)
|
||||
|
||||
type Contents []Content
|
||||
|
||||
func (cs Contents) String() string {
|
||||
switch len(cs) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
if cs[0].Type == ContentTypeText {
|
||||
return cs[0].Text
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
sb := strings.Builder{}
|
||||
for _, c := range cs {
|
||||
switch c.Type {
|
||||
case ContentTypeText:
|
||||
sb.WriteString("<text>")
|
||||
sb.WriteString(c.Text)
|
||||
sb.WriteString("</text>")
|
||||
case ContentTypeImageURL:
|
||||
if c.ImageURL != nil {
|
||||
sb.WriteString("<image_url>")
|
||||
sb.WriteString(c.ImageURL.URL)
|
||||
sb.WriteString("</image_url>")
|
||||
}
|
||||
default:
|
||||
panic("unsupported ContentType " + c.Type)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Type ContentType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ContentImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
func NewContentText(txt string) Content {
|
||||
return Content{Type: ContentTypeText, Text: txt}
|
||||
}
|
||||
|
||||
type ContentImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func NewContentImageURL(u string) Content {
|
||||
return Content{Type: ContentTypeImageURL, ImageURL: &ContentImageURL{URL: u}}
|
||||
}
|
||||
|
||||
func NewContentImageDataBase64URL(data []byte) (string, error) {
|
||||
_, ext, err := imgsz.DecodeSize(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
typ := mime.TypeByExtension("." + ext)
|
||||
if len(typ) == 0 {
|
||||
return "", errors.Wrap(ErrUnsupportedImageExtension, ext)
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString("data:")
|
||||
sb.WriteString(typ)
|
||||
sb.WriteString(";base64,")
|
||||
enc := base64.NewEncoder(base64.StdEncoding, &sb)
|
||||
_, _ = io.Copy(enc, bytes.NewReader(data))
|
||||
_ = enc.Close()
|
||||
return sb.String(), nil
|
||||
}
|
||||
33
model/content_test.go
Normal file
33
model/content_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const contentTextSmallImage = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
|
||||
|
||||
func TestContentImageDataBytes(t *testing.T) {
|
||||
// 从 smallimg 中提取 base64 数据
|
||||
parts := strings.Split(contentTextSmallImage, ",")
|
||||
if len(parts) != 2 {
|
||||
t.Fatal("Invalid data URL format")
|
||||
}
|
||||
|
||||
// 解码 base64 数据
|
||||
data, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode base64: %v", err)
|
||||
}
|
||||
|
||||
s, err := NewContentImageDataBase64URL(data)
|
||||
if err != nil {
|
||||
t.Fatalf("NewContentImageDataBytes failed: %v", err)
|
||||
}
|
||||
|
||||
// 比较结果
|
||||
if s != contentTextSmallImage {
|
||||
t.Errorf("Expected %s, got %s", contentTextSmallImage, s)
|
||||
}
|
||||
}
|
||||
124
model/genai.go
124
model/genai.go
@@ -13,16 +13,22 @@ const (
|
||||
ModelGemini15Flash = "models/gemini-1.5-flash"
|
||||
)
|
||||
|
||||
type Text struct {
|
||||
Text string `json:"text"`
|
||||
type GenAIInlineData struct {
|
||||
MimeType string `json:"mime_type"`
|
||||
Data string `json:"data"` // Data is base64 repr
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Parts []Text `json:"parts"`
|
||||
Role string `json:"role,omitempty"`
|
||||
type GenAIPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GenAIInlineData `json:"inline_data,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Content) String() string {
|
||||
type GenAIContent struct {
|
||||
Parts []GenAIPart `json:"parts"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
func (c *GenAIContent) String() string {
|
||||
sb := strings.Builder{}
|
||||
for _, p := range c.Parts {
|
||||
sb.WriteString(p.Text)
|
||||
@@ -30,10 +36,10 @@ func (c *Content) String() string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content Content `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
type GenAICandidate struct {
|
||||
Content GenAIContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// GenAI is Goole API format
|
||||
@@ -41,8 +47,8 @@ type GenAI struct {
|
||||
model string `json:"-"`
|
||||
Protocol `json:"-"`
|
||||
// request only
|
||||
Contents []Content `json:"contents,omitempty"`
|
||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||
Contents []GenAIContent `json:"contents,omitempty"`
|
||||
SystemInstruction *GenAIContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
@@ -50,7 +56,7 @@ type GenAI struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
} `json:"generationConfig"`
|
||||
// callback only
|
||||
Candidates []Candidate `json:"candidates,omitempty"`
|
||||
Candidates []GenAICandidate `json:"candidates,omitempty"`
|
||||
}
|
||||
|
||||
// NewGenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
|
||||
@@ -85,35 +91,99 @@ func (opai *GenAI) Parse(body io.Reader) error {
|
||||
return json.NewDecoder(body).Decode(&opai)
|
||||
}
|
||||
|
||||
func (opai *GenAI) Output() string {
|
||||
func (opai *GenAI) Output() Contents {
|
||||
return opai.OutputRaw()
|
||||
}
|
||||
|
||||
func (opai *GenAI) OutputRaw() Contents {
|
||||
if len(opai.Candidates) == 0 {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return opai.Candidates[0].Content.String()
|
||||
raw := opai.Candidates[0].Content
|
||||
cs := make(Contents, len(raw.Parts))
|
||||
for i, c := range raw.Parts {
|
||||
switch {
|
||||
case c.Text != "":
|
||||
cs[i].Type = ContentTypeText
|
||||
cs[i].Text = c.Text
|
||||
case c.InlineData != nil:
|
||||
cs[i].Type = ContentTypeImageURL
|
||||
if strings.HasPrefix(c.InlineData.MimeType, "image/") {
|
||||
cs[i].ImageURL = &ContentImageURL{
|
||||
URL: "data:" + c.InlineData.MimeType + ";base64," + c.InlineData.Data,
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic("unsupported genai part")
|
||||
}
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
func (opai *GenAI) OutputRaw() string {
|
||||
return opai.Output()
|
||||
func (cs Contents) ToGenAIParts() []GenAIPart {
|
||||
ps := make([]GenAIPart, 0, len(cs))
|
||||
for _, c := range cs {
|
||||
switch c.Type {
|
||||
case ContentTypeText:
|
||||
ps = append(ps, GenAIPart{Text: c.Text})
|
||||
case ContentTypeImageURL:
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:") {
|
||||
typ, dat, ok := strings.Cut(strings.TrimPrefix(c.ImageURL.URL, "data:"), ";")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
|
||||
MimeType: typ,
|
||||
Data: dat[1:], // skip ;
|
||||
}})
|
||||
continue
|
||||
}
|
||||
resp, err := http.Get(c.ImageURL.URL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s, err := NewContentImageDataBase64URL(data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
typ, dat, ok := strings.Cut(strings.TrimPrefix(s, "data:"), ",")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
|
||||
MimeType: typ,
|
||||
Data: dat[1:], // skip ,
|
||||
}})
|
||||
default:
|
||||
panic("unsupported ContentType " + c.Type)
|
||||
}
|
||||
}
|
||||
return ps
|
||||
}
|
||||
|
||||
func (opai *GenAI) System(prompt string) Protocol {
|
||||
opai.SystemInstruction = &Content{
|
||||
Parts: []Text{{prompt}},
|
||||
func (opai *GenAI) System(prompt ...Content) Protocol {
|
||||
opai.SystemInstruction = &GenAIContent{
|
||||
Parts: Contents(prompt).ToGenAIParts(),
|
||||
}
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *GenAI) User(prompt string) Protocol {
|
||||
opai.Contents = append(opai.Contents, Content{
|
||||
Parts: []Text{{prompt}},
|
||||
func (opai *GenAI) User(prompt ...Content) Protocol {
|
||||
opai.Contents = append(opai.Contents, GenAIContent{
|
||||
Parts: Contents(prompt).ToGenAIParts(),
|
||||
Role: "user",
|
||||
})
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *GenAI) Assistant(prompt string) Protocol {
|
||||
opai.Contents = append(opai.Contents, Content{
|
||||
Parts: []Text{{prompt}},
|
||||
func (opai *GenAI) Assistant(prompt ...Content) Protocol {
|
||||
opai.Contents = append(opai.Contents, GenAIContent{
|
||||
Parts: Contents(prompt).ToGenAIParts(),
|
||||
Role: "model",
|
||||
})
|
||||
return opai
|
||||
|
||||
@@ -7,22 +7,27 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OLLaMAMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OLLaMA as an specified example.
|
||||
type OLLaMA struct {
|
||||
sep string
|
||||
Protocol `json:"-"`
|
||||
// callback only
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Messages []Message `json:"messages"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Messages []OLLaMAMessage `json:"messages"`
|
||||
// callback/request
|
||||
Model string `json:"model"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Temperature float32 `json:"temperature"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens"` // MaxTokens 4096
|
||||
Stream bool `json:"stream"`
|
||||
Model string `json:"model"`
|
||||
Message *OLLaMAMessage `json:"message,omitempty"`
|
||||
Temperature float32 `json:"temperature"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens"` // MaxTokens 4096
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// NewOLLaMA use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
|
||||
@@ -58,41 +63,41 @@ func (ollm *OLLaMA) Parse(body io.Reader) error {
|
||||
return json.NewDecoder(body).Decode(&ollm)
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) Output() string {
|
||||
func (ollm *OLLaMA) Output() Contents {
|
||||
if ollm.Message == nil {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return CutLast(ollm.Message.Content, ollm.sep)
|
||||
return Contents{NewContentText(CutLast(ollm.Message.Content, ollm.sep))}
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) OutputRaw() string {
|
||||
func (ollm *OLLaMA) OutputRaw() Contents {
|
||||
if ollm.Message == nil {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return ollm.Message.Content
|
||||
return Contents{NewContentText(ollm.Message.Content)}
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) System(prompt string) Protocol {
|
||||
ollm.Messages = make([]Message, 1, 8)
|
||||
ollm.Messages[0] = Message{
|
||||
func (ollm *OLLaMA) System(prompt ...Content) Protocol {
|
||||
ollm.Messages = make([]OLLaMAMessage, 1, 8)
|
||||
ollm.Messages[0] = OLLaMAMessage{
|
||||
Role: "system",
|
||||
Content: prompt,
|
||||
Content: prompt[0].Text,
|
||||
}
|
||||
return ollm
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) User(prompt string) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, Message{
|
||||
func (ollm *OLLaMA) User(prompt ...Content) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, OLLaMAMessage{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
Content: prompt[0].Text,
|
||||
})
|
||||
return ollm
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) Assistant(prompt string) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, Message{
|
||||
func (ollm *OLLaMA) Assistant(prompt ...Content) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, OLLaMAMessage{
|
||||
Role: "assistant",
|
||||
Content: prompt,
|
||||
Content: prompt[0].Text,
|
||||
})
|
||||
return ollm
|
||||
}
|
||||
|
||||
107
model/openai.go
107
model/openai.go
@@ -8,18 +8,18 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ModelDeepDeek = "deepseek-ai/DeepSeek-R1"
|
||||
ModelDeepDeek = "deepseek-ai/DeepSeek-V3.1-Terminus"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
type OpenAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"` // Contents or string
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
type OpenAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message OpenAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// OpenAI as an specified example.
|
||||
@@ -27,18 +27,18 @@ type OpenAI struct {
|
||||
sep string
|
||||
Protocol `json:"-"`
|
||||
// callback only
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Choices []Choice `json:"choices,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Choices []OpenAIChoice `json:"choices,omitempty"`
|
||||
// callback/request
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float32 `json:"temperature,omitempty"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p,omitempty"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens 4096
|
||||
Model string `json:"model"`
|
||||
Messages []OpenAIMessage `json:"messages"`
|
||||
Temperature float32 `json:"temperature,omitempty"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p,omitempty"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens 4096
|
||||
// extra body
|
||||
EnableThinking bool `json:"enable_thinking"`
|
||||
EnableThinking bool `json:"enable_thinking"` // EnableThinking is always false in non-stream mode, adapt to 百炼
|
||||
}
|
||||
|
||||
// NewOpenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
|
||||
@@ -74,41 +74,80 @@ func (opai *OpenAI) Parse(body io.Reader) error {
|
||||
return json.NewDecoder(body).Decode(&opai)
|
||||
}
|
||||
|
||||
func (opai *OpenAI) Output() string {
|
||||
func (opai *OpenAI) Output() Contents {
|
||||
if len(opai.Choices) == 0 {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return CutLast(opai.Choices[len(opai.Choices)-1].Message.Content, opai.sep)
|
||||
cs := make(Contents, 0, 8)
|
||||
c := opai.Choices[len(opai.Choices)-1].Message.Content
|
||||
err := json.Unmarshal(c, &cs)
|
||||
if err == nil {
|
||||
for i := range cs {
|
||||
if cs[i].Type == ContentTypeText {
|
||||
cs[i].Text = CutLast(cs[i].Text, opai.sep)
|
||||
}
|
||||
}
|
||||
return cs
|
||||
}
|
||||
s := ""
|
||||
err = json.Unmarshal(c, &s)
|
||||
if err == nil {
|
||||
return Contents{{Type: ContentTypeText, Text: CutLast(s, opai.sep)}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opai *OpenAI) OutputRaw() string {
|
||||
func (opai *OpenAI) OutputRaw() Contents {
|
||||
if len(opai.Choices) == 0 {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return opai.Choices[len(opai.Choices)-1].Message.Content
|
||||
cs := make(Contents, 0, 8)
|
||||
c := opai.Choices[len(opai.Choices)-1].Message.Content
|
||||
err := json.Unmarshal(c, &cs)
|
||||
if err == nil {
|
||||
return cs
|
||||
}
|
||||
s := ""
|
||||
err = json.Unmarshal(c, &s)
|
||||
if err == nil {
|
||||
return Contents{{Type: ContentTypeText, Text: s}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opai *OpenAI) System(prompt string) Protocol {
|
||||
opai.Messages = make([]Message, 1, 8)
|
||||
opai.Messages[0] = Message{
|
||||
func (opai *OpenAI) System(prompt ...Content) Protocol {
|
||||
opai.Messages = make([]OpenAIMessage, 1, 8)
|
||||
raw, err := json.Marshal(&prompt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opai.Messages[0] = OpenAIMessage{
|
||||
Role: "system",
|
||||
Content: prompt,
|
||||
Content: raw,
|
||||
}
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *OpenAI) User(prompt string) Protocol {
|
||||
opai.Messages = append(opai.Messages, Message{
|
||||
func (opai *OpenAI) User(prompt ...Content) Protocol {
|
||||
raw, err := json.Marshal(&prompt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opai.Messages = append(opai.Messages, OpenAIMessage{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
Content: raw,
|
||||
})
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *OpenAI) Assistant(prompt string) Protocol {
|
||||
opai.Messages = append(opai.Messages, Message{
|
||||
func (opai *OpenAI) Assistant(prompt ...Content) Protocol {
|
||||
raw, err := json.Marshal(&prompt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opai.Messages = append(opai.Messages, OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: prompt,
|
||||
Content: raw,
|
||||
})
|
||||
return opai
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user