diff --git a/api.go b/api.go
index 62290ea..cf12ef0 100644
--- a/api.go
+++ b/api.go
@@ -52,5 +52,5 @@ func (api *API) Request(model Model) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	return model.Output(), nil
+	return model.Output().String(), nil
 }
diff --git a/chat/chat.go b/chat/chat.go
index 5377f41..1280679 100644
--- a/chat/chat.go
+++ b/chat/chat.go
@@ -63,24 +63,24 @@ func (l *Log[T]) Add(grp int64, item T, isbot bool) {
 func (l *Log[T]) Modelize(p model.Protocol, grp int64, sysp string, isusersystem bool) deepinfra.Model {
 	m := p
 	if sysp != "" && !isusersystem {
-		m.System(sysp)
+		m.System(model.NewContentText(sysp))
 	}
 	l.mu.RLock()
 	defer l.mu.RUnlock()
 	sz := len(l.m[grp])
 	if sz == 0 {
-		return m.User(l.defaultprompt)
+		return m.User(model.NewContentText(l.defaultprompt))
 	}
 	for i, msg := range l.m[grp] {
 		if i%2 == 0 { // is user
 			if i == 0 && isusersystem {
-				_ = m.User(sysp + "\n\n" + msg.String())
+				_ = m.User(model.NewContentText(sysp + "\n\n" + msg.String()))
 				continue
 			}
-			_ = m.User(msg.String())
+			_ = m.User(model.NewContentText(msg.String()))
 			continue
 		}
-		_ = m.Assistant(msg.String())
+		_ = m.Assistant(model.NewContentText(msg.String()))
 	}
 	return m
 }
diff --git a/go.mod b/go.mod
index fe83726..ed7e172 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,21 @@
 module github.com/fumiama/deepinfra
 
 go 1.20
+
+require (
+	github.com/FloatTech/gg v1.1.3
+	github.com/FloatTech/imgfactory v0.2.1
+	github.com/fumiama/imgsz v0.0.4
+)
+
+require (
+	github.com/disintegration/imaging v1.6.2 // indirect
+	github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4 // indirect
+	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
+	golang.org/x/text v0.15.0 // indirect
+)
+
+require (
+	github.com/pkg/errors v0.9.1
+	golang.org/x/image v0.16.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..370c3aa
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,20 @@
+github.com/FloatTech/gg v1.1.3 h1:+GlL02lTKsxJQr4WCuNwVxC1/eBZrCvypCIBtxuOFb4=
+github.com/FloatTech/gg v1.1.3/go.mod h1:/9oLP54CMfq4r+71XL26uaFTJ1uL1boAyX67680/1HE=
+github.com/FloatTech/imgfactory v0.2.1 h1:XoVwy0Xu0AvTRtzlhv5teZcuZlAcHrYjeQ8MynJ/zlk=
+github.com/FloatTech/imgfactory v0.2.1/go.mod h1:QBJKHbzpE+x/9Wn7mXebWap/K/xUJSjgiaelAElwU9Q=
+github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
+github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
+github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4 h1:BBade+JlV/f7JstZ4pitd4tHhpN+w+6I+LyOS7B4fyU=
+github.com/ericpauley/go-quantize v0.0.0-20200331213906-ae555eb2afa4/go.mod h1:H7chHJglrhPPzetLdzBleF8d22WYOv7UM/lEKYiwlKM=
+github.com/fumiama/imgsz v0.0.4 h1:Lsasu2hdSSFS+vnD+nvR1UkiRMK7hcpyYCC0FzgSMFI=
+github.com/fumiama/imgsz v0.0.4/go.mod h1:bISOQVTlw9sRytPwe8ir7tAaEmyz9hSNj9n8mXMBG0E=
+github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
+github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
+golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
+golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
diff --git a/model/api.go b/model/api.go
index a73af94..d784e8f 100644
--- a/model/api.go
+++ b/model/api.go
@@ -12,8 +12,8 @@ type Inputer interface {
 }
 
 type Outputer interface {
-	Output() string
-	OutputRaw() string
+	Output() Contents
+	OutputRaw() Contents
 }
 
 type Requester interface {
@@ -22,9 +22,9 @@ type Requester interface {
 }
 
 type MessageBuilder[T any] interface {
-	System(prompt string) T
-	User(prompt string) T
-	Assistant(prompt string) T
+	System(prompt ...Content) T
+	User(prompt ...Content) T
+	Assistant(prompt ...Content) T
 }
 
 type Protocol interface {
diff --git a/model/content.go b/model/content.go
new file mode 100644
index 0000000..a1d2485
--- /dev/null
+++ b/model/content.go
@@ -0,0 +1,111 @@
+package model
+
+import (
+	"bytes"
+	"encoding/base64"
+	"io"
+	"mime"
+	"strings"
+
+	"github.com/fumiama/imgsz"
+	"github.com/pkg/errors"
+)
+
+var (
+	// ErrUnsupportedImageExtension is returned when no MIME type is
+	// registered for the extension of a decoded image.
+	ErrUnsupportedImageExtension = errors.New("unsupported image extension")
+)
+
+// ContentType is the value of the "type" field of a Content.
+type ContentType string
+
+const (
+	ContentTypeText     ContentType = "text"
+	ContentTypeImageURL ContentType = "image_url"
+)
+
+// Contents is an ordered list of message parts.
+type Contents []Content
+
+// String flattens cs into one string. A single text part is returned
+// verbatim; otherwise the text parts and image URLs are concatenated in
+// order. Parts of an unknown type are skipped.
+func (cs Contents) String() string {
+	switch len(cs) {
+	case 0:
+		return ""
+	case 1:
+		if cs[0].Type == ContentTypeText {
+			return cs[0].Text
+		}
+		fallthrough
+	default:
+		sb := strings.Builder{}
+		for _, c := range cs {
+			switch c.Type {
+			case ContentTypeText:
+				sb.WriteString("")
+				sb.WriteString(c.Text)
+				sb.WriteString("")
+			case ContentTypeImageURL:
+				if c.ImageURL != nil {
+					sb.WriteString("")
+					sb.WriteString(c.ImageURL.URL)
+					sb.WriteString("")
+				}
+			default:
+				// Contents may be decoded from a remote response, so an
+				// unknown type must not panic here; it simply has no
+				// textual form.
+			}
+		}
+		return sb.String()
+	}
+}
+
+// Content is one part of a message: a piece of text or an image URL.
+type Content struct {
+	Type     ContentType      `json:"type"`
+	Text     string           `json:"text,omitempty"`
+	ImageURL *ContentImageURL `json:"image_url,omitempty"`
+}
+
+// NewContentText returns a text Content holding txt.
+func NewContentText(txt string) Content {
+	return Content{Type: ContentTypeText, Text: txt}
+}
+
+// ContentImageURL is the payload of an image_url Content.
+type ContentImageURL struct {
+	URL string `json:"url"`
+}
+
+// NewContentImageURL returns an image Content pointing at u.
+func NewContentImageURL(u string) Content {
+	return Content{Type: ContentTypeImageURL, ImageURL: &ContentImageURL{URL: u}}
+}
+
+// NewContentImageDataBase64URL encodes raw image bytes as a
+// "data:<mime>;base64,<payload>" URL. The image format is detected from
+// data itself; it returns ErrUnsupportedImageExtension (wrapped) when no
+// MIME type is known for the detected extension.
+func NewContentImageDataBase64URL(data []byte) (string, error) {
+	_, ext, err := imgsz.DecodeSize(bytes.NewReader(data))
+	if err != nil {
+		return "", err
+	}
+	typ := mime.TypeByExtension("." + ext)
+	if len(typ) == 0 {
+		return "", errors.Wrap(ErrUnsupportedImageExtension, ext)
+	}
+	sb := strings.Builder{}
+	sb.WriteString("data:")
+	sb.WriteString(typ)
+	sb.WriteString(";base64,")
+	enc := base64.NewEncoder(base64.StdEncoding, &sb)
+	// Writes into a strings.Builder never fail, so the errors are ignored.
+	_, _ = io.Copy(enc, bytes.NewReader(data))
+	_ = enc.Close()
+	return sb.String(), nil
+}
diff --git a/model/content_test.go b/model/content_test.go
new file mode 100644
index 0000000..10a278b
--- /dev/null
+++ b/model/content_test.go
@@ -0,0 +1,33 @@
+package model
+
+import (
+	"encoding/base64"
+	"strings"
+	"testing"
+)
+
+const contentTextSmallImage = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
+
+func TestContentImageDataBytes(t *testing.T) {
+	// Extract the base64 payload from the small image data URL.
+	parts := strings.Split(contentTextSmallImage, ",")
+	if len(parts) != 2 {
+		t.Fatal("Invalid data URL format")
+	}
+
+	// Decode the base64 payload.
+	data, err := base64.StdEncoding.DecodeString(parts[1])
+	if err != nil {
+		t.Fatalf("Failed to decode base64: %v", err)
+	}
+
+	s, err := NewContentImageDataBase64URL(data)
+	if err != nil {
+		t.Fatalf("NewContentImageDataBase64URL failed: %v", err)
+	}
+
+	// Compare the result with the original data URL.
+	if s != contentTextSmallImage {
+		t.Errorf("Expected %s, got %s", contentTextSmallImage, s)
+	}
+}
diff --git a/model/genai.go b/model/genai.go
index b2b66b9..cc26e4a 100644
--- a/model/genai.go
+++ b/model/genai.go
@@ -13,16 +13,22 @@ const (
 	ModelGemini15Flash = "models/gemini-1.5-flash"
 )
 
-type Text struct {
-	Text string `json:"text"`
+type GenAIInlineData struct {
+	MimeType string `json:"mime_type"`
+	Data     string `json:"data"` // Data is base64 repr
 }
 
-type Content struct {
-	Parts []Text `json:"parts"`
-	Role  string `json:"role,omitempty"`
+type GenAIPart struct {
+	Text       string           `json:"text,omitempty"`
+	InlineData *GenAIInlineData `json:"inline_data,omitempty"`
 }
 
-func (c *Content) String() string {
+type GenAIContent struct {
+	Parts []GenAIPart `json:"parts"`
+	Role  string      `json:"role,omitempty"`
+}
+
+func (c *GenAIContent) String() string {
 	sb := strings.Builder{}
 	for _, p := range c.Parts {
 		sb.WriteString(p.Text)
@@ -30,10 +36,10 @@ func (c *Content) String() string {
 	return sb.String()
 }
 
-type Candidate struct {
-	Content      Content `json:"content"`
-	FinishReason string  `json:"finishReason"`
-	Index        int     `json:"index"`
+type GenAICandidate struct {
+	Content      GenAIContent `json:"content"`
+	FinishReason string       `json:"finishReason"`
+	Index        int          `json:"index"`
 }
 
 // GenAI is Goole API format
@@ -41,8 +47,8 @@ type GenAI struct {
 	model    string `json:"-"`
 	Protocol `json:"-"`
 	// request only
-	Contents          []Content `json:"contents,omitempty"`
-	SystemInstruction *Content  `json:"systemInstruction,omitempty"`
+	Contents          []GenAIContent `json:"contents,omitempty"`
+	SystemInstruction *GenAIContent  `json:"systemInstruction,omitempty"`
 	GenerationConfig  struct {
 		Temperature      float32 `json:"temperature,omitempty"`
 		ResponseMimeType string  `json:"responseMimeType,omitempty"`
@@ -50,7 +56,7 @@ type GenAI struct {
 		MaxOutputTokens  int     `json:"maxOutputTokens,omitempty"`
 	} `json:"generationConfig"`
 	// callback only
-	Candidates []Candidate `json:"candidates,omitempty"`
+	Candidates []GenAICandidate `json:"candidates,omitempty"`
 }
 
 // NewGenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
@@ -85,35 +91,122 @@ func (opai *GenAI) Parse(body io.Reader) error {
 	return json.NewDecoder(body).Decode(&opai)
 }
 
-func (opai *GenAI) Output() string {
+// Output returns the same Contents as OutputRaw.
+func (opai *GenAI) Output() Contents {
+	return opai.OutputRaw()
+}
+
+// OutputRaw converts the parts of the first candidate into Contents. It
+// returns nil when the response holds no candidate.
+func (opai *GenAI) OutputRaw() Contents {
 	if len(opai.Candidates) == 0 {
-		return ""
+		return nil
 	}
-	return opai.Candidates[0].Content.String()
+	raw := opai.Candidates[0].Content
+	cs := make(Contents, len(raw.Parts))
+	for i, c := range raw.Parts {
+		switch {
+		case c.Text == "" && c.InlineData != nil:
+			cs[i].Type = ContentTypeImageURL
+			if strings.HasPrefix(c.InlineData.MimeType, "image/") {
+				cs[i].ImageURL = &ContentImageURL{
+					URL: "data:" + c.InlineData.MimeType + ";base64," + c.InlineData.Data,
+				}
+			}
+		default:
+			// Parts come from the remote API, so one with neither text
+			// nor inline data becomes an empty text instead of a panic.
+			cs[i].Type = ContentTypeText
+			cs[i].Text = c.Text
+		}
+	}
+	return cs
 }
 
-func (opai *GenAI) OutputRaw() string {
-	return opai.Output()
+// ToGenAIParts converts cs into GenAI parts. Images are always sent as
+// inline base64 data: a data: URL is split in place and any other URL is
+// downloaded first. Images that cannot be converted are skipped, and an
+// unknown content type panics.
+func (cs Contents) ToGenAIParts() []GenAIPart {
+	ps := make([]GenAIPart, 0, len(cs))
+	for _, c := range cs {
+		switch c.Type {
+		case ContentTypeText:
+			ps = append(ps, GenAIPart{Text: c.Text})
+		case ContentTypeImageURL:
+			if c.ImageURL == nil {
+				continue
+			}
+			u := c.ImageURL.URL
+			if !strings.HasPrefix(u, "data:") {
+				var ok bool
+				if u, ok = fetchImageDataURL(u); !ok {
+					continue
+				}
+			}
+			typ, dat, ok := splitBase64DataURL(u)
+			if !ok {
+				continue
+			}
+			ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
+				MimeType: typ,
+				Data:     dat,
+			}})
+		default:
+			panic("unsupported ContentType " + c.Type)
+		}
+	}
+	return ps
 }
 
+// splitBase64DataURL splits "data:<mime>;base64,<payload>" into its MIME
+// type and payload. ok is false when u does not have that shape.
+func splitBase64DataURL(u string) (typ, dat string, ok bool) {
+	if !strings.HasPrefix(u, "data:") {
+		return "", "", false
+	}
+	typ, dat, ok = strings.Cut(u[len("data:"):], ";base64,")
+	return typ, dat, ok && typ != "" && dat != ""
+}
+
+// fetchImageDataURL downloads the image at u and re-encodes it as a base64
+// data URL. It is a separate function so that every response body is closed
+// as soon as its download finishes rather than when the caller returns.
+func fetchImageDataURL(u string) (string, bool) {
+	resp, err := http.Get(u)
+	if err != nil {
+		return "", false
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return "", false
+	}
+	data, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", false
+	}
+	s, err := NewContentImageDataBase64URL(data)
+	return s, err == nil
+}
+
-func (opai *GenAI) System(prompt string) Protocol {
-	opai.SystemInstruction = &Content{
-		Parts: []Text{{prompt}},
+func (opai *GenAI) System(prompt ...Content) Protocol {
+	opai.SystemInstruction = &GenAIContent{
+		Parts: Contents(prompt).ToGenAIParts(),
 	}
 	return opai
 }
 
-func (opai *GenAI) User(prompt string) Protocol {
-	opai.Contents = append(opai.Contents, Content{
-		Parts: []Text{{prompt}},
+func (opai *GenAI) User(prompt ...Content) Protocol {
+	opai.Contents = append(opai.Contents, GenAIContent{
+		Parts: Contents(prompt).ToGenAIParts(),
 		Role:  "user",
 	})
 	return opai
 }
 
-func (opai *GenAI) Assistant(prompt string) Protocol {
-	opai.Contents = append(opai.Contents, Content{
-		Parts: []Text{{prompt}},
+func (opai *GenAI) Assistant(prompt ...Content) Protocol {
+	opai.Contents = append(opai.Contents, GenAIContent{
+		Parts: Contents(prompt).ToGenAIParts(),
 		Role:  "model",
 	})
 	return opai
diff --git a/model/ollama.go b/model/ollama.go
index a45779b..7eadb4b 100644
--- a/model/ollama.go
+++ b/model/ollama.go
@@ -7,22 +7,27 @@ import (
 	"net/http"
 )
 
+type OLLaMAMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
 // OLLaMA as an specified example.
 type OLLaMA struct {
 	sep      string
 	Protocol `json:"-"`
 	// callback only
-	ID       string    `json:"id,omitempty"`
-	Object   string    `json:"object,omitempty"`
-	Created  int       `json:"created,omitempty"`
-	Messages []Message `json:"messages"`
+	ID       string          `json:"id,omitempty"`
+	Object   string          `json:"object,omitempty"`
+	Created  int             `json:"created,omitempty"`
+	Messages []OLLaMAMessage `json:"messages"`
 	// callback/request
-	Model       string   `json:"model"`
-	Message     *Message `json:"message,omitempty"`
-	Temperature float32  `json:"temperature"` // Temperature 0.7
-	TopP        float32  `json:"top_p"`       // TopP 0.9
-	MaxTokens   int      `json:"max_tokens"`  // MaxTokens 4096
-	Stream      bool     `json:"stream"`
+	Model       string         `json:"model"`
+	Message     *OLLaMAMessage `json:"message,omitempty"`
+	Temperature float32        `json:"temperature"` // Temperature 0.7
+	TopP        float32        `json:"top_p"`       // TopP 0.9
+	MaxTokens   int            `json:"max_tokens"`  // MaxTokens 4096
+	Stream      bool           `json:"stream"`
 }
 
 // NewOLLaMA use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
@@ -58,41 +63,54 @@ func (ollm *OLLaMA) Parse(body io.Reader) error {
 	return json.NewDecoder(body).Decode(&ollm)
 }
 
-func (ollm *OLLaMA) Output() string {
+// Output returns the reply passed through CutLast with ollm.sep, or nil
+// when no message was received.
+func (ollm *OLLaMA) Output() Contents {
 	if ollm.Message == nil {
-		return ""
+		return nil
 	}
-	return CutLast(ollm.Message.Content, ollm.sep)
+	return Contents{NewContentText(CutLast(ollm.Message.Content, ollm.sep))}
 }
 
-func (ollm *OLLaMA) OutputRaw() string {
+// OutputRaw returns the reply text unmodified, or nil when no message was
+// received.
+func (ollm *OLLaMA) OutputRaw() Contents {
 	if ollm.Message == nil {
-		return ""
+		return nil
 	}
-	return ollm.Message.Content
+	return Contents{NewContentText(ollm.Message.Content)}
 }
 
-func (ollm *OLLaMA) System(prompt string) Protocol {
-	ollm.Messages = make([]Message, 1, 8)
-	ollm.Messages[0] = Message{
+// ollamaText flattens prompt into the single string an OLLaMA message
+// carries by concatenating the Text of every part in order. An empty
+// prompt yields "" instead of panicking on prompt[0].
+func ollamaText(prompt []Content) string {
+	s := ""
+	for _, c := range prompt {
+		s += c.Text
+	}
+	return s
+}
+
+func (ollm *OLLaMA) System(prompt ...Content) Protocol {
+	ollm.Messages = make([]OLLaMAMessage, 1, 8)
+	ollm.Messages[0] = OLLaMAMessage{
 		Role:    "system",
-		Content: prompt,
+		Content: ollamaText(prompt),
 	}
 	return ollm
 }
 
-func (ollm *OLLaMA) User(prompt string) Protocol {
-	ollm.Messages = append(ollm.Messages, Message{
+func (ollm *OLLaMA) User(prompt ...Content) Protocol {
+	ollm.Messages = append(ollm.Messages, OLLaMAMessage{
 		Role:    "user",
-		Content: prompt,
+		Content: ollamaText(prompt),
 	})
 	return ollm
 }
 
-func (ollm *OLLaMA) Assistant(prompt string) Protocol {
-	ollm.Messages = append(ollm.Messages, Message{
+func (ollm *OLLaMA) Assistant(prompt ...Content) Protocol {
+	ollm.Messages = append(ollm.Messages, OLLaMAMessage{
 		Role:    "assistant",
-		Content: prompt,
+		Content: ollamaText(prompt),
 	})
 	return ollm
 }
diff --git a/model/openai.go b/model/openai.go
index e63c944..d94d99c 100644
--- a/model/openai.go
+++ b/model/openai.go
@@ -8,18 +8,18 @@ import (
 )
 
 const (
-	ModelDeepDeek = "deepseek-ai/DeepSeek-R1"
+	ModelDeepDeek = "deepseek-ai/DeepSeek-V3.1-Terminus"
 )
 
-type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+type OpenAIMessage struct {
+	Role    string          `json:"role"`
+	Content json.RawMessage `json:"content"` // Contents or string
 }
 
-type Choice struct {
-	Index        int     `json:"index"`
-	Message      Message `json:"message"`
-	FinishReason string  `json:"finish_reason"`
+type OpenAIChoice struct {
+	Index        int           `json:"index"`
+	Message      OpenAIMessage `json:"message"`
+	FinishReason string        `json:"finish_reason"`
 }
 
 // OpenAI as an specified example.
@@ -27,18 +27,18 @@ type OpenAI struct {
 	sep      string
 	Protocol `json:"-"`
 	// callback only
-	ID      string   `json:"id,omitempty"`
-	Object  string   `json:"object,omitempty"`
-	Created int      `json:"created,omitempty"`
-	Choices []Choice `json:"choices,omitempty"`
+	ID      string         `json:"id,omitempty"`
+	Object  string         `json:"object,omitempty"`
+	Created int            `json:"created,omitempty"`
+	Choices []OpenAIChoice `json:"choices,omitempty"`
 	// callback/request
-	Model       string    `json:"model"`
-	Messages    []Message `json:"messages"`
-	Temperature float32   `json:"temperature,omitempty"` // Temperature 0.7
-	TopP        float32   `json:"top_p,omitempty"`       // TopP 0.9
-	MaxTokens   int       `json:"max_tokens,omitempty"`  // MaxTokens 4096
+	Model       string          `json:"model"`
+	Messages    []OpenAIMessage `json:"messages"`
+	Temperature float32         `json:"temperature,omitempty"` // Temperature 0.7
+	TopP        float32         `json:"top_p,omitempty"`       // TopP 0.9
+	MaxTokens   int             `json:"max_tokens,omitempty"`  // MaxTokens 4096
 	// extra body
-	EnableThinking bool `json:"enable_thinking"`
+	EnableThinking bool `json:"enable_thinking"` // EnableThinking is always false in non-stream mode, to adapt to Bailian (百炼)
 }
 
 // NewOpenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
@@ -74,41 +74,78 @@ func (opai *OpenAI) Parse(body io.Reader) error {
 	return json.NewDecoder(body).Decode(&opai)
 }
 
-func (opai *OpenAI) Output() string {
+// Output returns the Contents of OutputRaw with CutLast applied to every
+// text part.
+func (opai *OpenAI) Output() Contents {
+	cs := opai.OutputRaw()
+	for i := range cs {
+		if cs[i].Type == ContentTypeText {
+			cs[i].Text = CutLast(cs[i].Text, opai.sep)
+		}
+	}
+	return cs
+}
+
+// OutputRaw decodes the content of the last choice, which may be an array
+// of parts or a plain string. It returns nil when there is no choice or
+// the content has neither shape.
+func (opai *OpenAI) OutputRaw() Contents {
 	if len(opai.Choices) == 0 {
-		return ""
+		return nil
 	}
-	return CutLast(opai.Choices[len(opai.Choices)-1].Message.Content, opai.sep)
+	c := opai.Choices[len(opai.Choices)-1].Message.Content
+	cs := make(Contents, 0, 8)
+	if err := json.Unmarshal(c, &cs); err == nil {
+		return cs
+	}
+	s := ""
+	if err := json.Unmarshal(c, &s); err == nil {
+		return Contents{{Type: ContentTypeText, Text: s}}
+	}
+	return nil
 }
 
-func (opai *OpenAI) OutputRaw() string {
-	if len(opai.Choices) == 0 {
-		return ""
-	}
-	return opai.Choices[len(opai.Choices)-1].Message.Content
+// marshalOpenAIContent encodes prompt as the "content" of a message. An
+// empty prompt becomes "" (not null) and a single text part becomes a
+// plain JSON string, the form used before multi-part contents existed;
+// anything else is sent as an array of parts.
+func marshalOpenAIContent(prompt []Content) json.RawMessage {
+	var v any = prompt
+	switch {
+	case len(prompt) == 0:
+		v = ""
+	case len(prompt) == 1 && prompt[0].Type == ContentTypeText:
+		v = prompt[0].Text
+	}
+	raw, err := json.Marshal(v)
+	if err != nil {
+		// Content holds only strings, so this is not expected to happen.
+		panic(err)
+	}
+	return raw
 }
 
-func (opai *OpenAI) System(prompt string) Protocol {
-	opai.Messages = make([]Message, 1, 8)
-	opai.Messages[0] = Message{
+func (opai *OpenAI) System(prompt ...Content) Protocol {
+	opai.Messages = make([]OpenAIMessage, 1, 8)
+	opai.Messages[0] = OpenAIMessage{
 		Role:    "system",
-		Content: prompt,
+		Content: marshalOpenAIContent(prompt),
 	}
 	return opai
 }
 
-func (opai *OpenAI) User(prompt string) Protocol {
-	opai.Messages = append(opai.Messages, Message{
+func (opai *OpenAI) User(prompt ...Content) Protocol {
+	opai.Messages = append(opai.Messages, OpenAIMessage{
 		Role:    "user",
-		Content: prompt,
+		Content: marshalOpenAIContent(prompt),
 	})
 	return opai
 }
 
-func (opai *OpenAI) Assistant(prompt string) Protocol {
-	opai.Messages = append(opai.Messages, Message{
+func (opai *OpenAI) Assistant(prompt ...Content) Protocol {
+	opai.Messages = append(opai.Messages, OpenAIMessage{
 		Role:    "assistant",
-		Content: prompt,
+		Content: marshalOpenAIContent(prompt),
 	})
 	return opai
 }