mirror of
https://github.com/fumiama/deepinfra.git
synced 2026-06-11 21:46:39 +08:00
feat: support image upload
This commit is contained in:
10
model/api.go
10
model/api.go
@@ -12,8 +12,8 @@ type Inputer interface {
|
||||
}
|
||||
|
||||
type Outputer interface {
|
||||
Output() string
|
||||
OutputRaw() string
|
||||
Output() Contents
|
||||
OutputRaw() Contents
|
||||
}
|
||||
|
||||
type Requester interface {
|
||||
@@ -22,9 +22,9 @@ type Requester interface {
|
||||
}
|
||||
|
||||
type MessageBuilder[T any] interface {
|
||||
System(prompt string) T
|
||||
User(prompt string) T
|
||||
Assistant(prompt string) T
|
||||
System(prompt ...Content) T
|
||||
User(prompt ...Content) T
|
||||
Assistant(prompt ...Content) T
|
||||
}
|
||||
|
||||
type Protocol interface {
|
||||
|
||||
93
model/content.go
Normal file
93
model/content.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"mime"
|
||||
"strings"
|
||||
|
||||
"github.com/fumiama/imgsz"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedImageExtension = errors.New("unsupported image extension")
|
||||
)
|
||||
|
||||
type ContentType string
|
||||
|
||||
const (
|
||||
ContentTypeText ContentType = "text"
|
||||
ContentTypeImageURL ContentType = "image_url"
|
||||
)
|
||||
|
||||
type Contents []Content
|
||||
|
||||
func (cs Contents) String() string {
|
||||
switch len(cs) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
if cs[0].Type == ContentTypeText {
|
||||
return cs[0].Text
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
sb := strings.Builder{}
|
||||
for _, c := range cs {
|
||||
switch c.Type {
|
||||
case ContentTypeText:
|
||||
sb.WriteString("<text>")
|
||||
sb.WriteString(c.Text)
|
||||
sb.WriteString("</text>")
|
||||
case ContentTypeImageURL:
|
||||
if c.ImageURL != nil {
|
||||
sb.WriteString("<image_url>")
|
||||
sb.WriteString(c.ImageURL.URL)
|
||||
sb.WriteString("</image_url>")
|
||||
}
|
||||
default:
|
||||
panic("unsupported ContentType " + c.Type)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Type ContentType `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *ContentImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
func NewContentText(txt string) Content {
|
||||
return Content{Type: ContentTypeText, Text: txt}
|
||||
}
|
||||
|
||||
type ContentImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func NewContentImageURL(u string) Content {
|
||||
return Content{Type: ContentTypeImageURL, ImageURL: &ContentImageURL{URL: u}}
|
||||
}
|
||||
|
||||
func NewContentImageDataBase64URL(data []byte) (string, error) {
|
||||
_, ext, err := imgsz.DecodeSize(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
typ := mime.TypeByExtension("." + ext)
|
||||
if len(typ) == 0 {
|
||||
return "", errors.Wrap(ErrUnsupportedImageExtension, ext)
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString("data:")
|
||||
sb.WriteString(typ)
|
||||
sb.WriteString(";base64,")
|
||||
enc := base64.NewEncoder(base64.StdEncoding, &sb)
|
||||
_, _ = io.Copy(enc, bytes.NewReader(data))
|
||||
_ = enc.Close()
|
||||
return sb.String(), nil
|
||||
}
|
||||
33
model/content_test.go
Normal file
33
model/content_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const contentTextSmallImage = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
|
||||
|
||||
func TestContentImageDataBytes(t *testing.T) {
|
||||
// 从 smallimg 中提取 base64 数据
|
||||
parts := strings.Split(contentTextSmallImage, ",")
|
||||
if len(parts) != 2 {
|
||||
t.Fatal("Invalid data URL format")
|
||||
}
|
||||
|
||||
// 解码 base64 数据
|
||||
data, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode base64: %v", err)
|
||||
}
|
||||
|
||||
s, err := NewContentImageDataBase64URL(data)
|
||||
if err != nil {
|
||||
t.Fatalf("NewContentImageDataBytes failed: %v", err)
|
||||
}
|
||||
|
||||
// 比较结果
|
||||
if s != contentTextSmallImage {
|
||||
t.Errorf("Expected %s, got %s", contentTextSmallImage, s)
|
||||
}
|
||||
}
|
||||
124
model/genai.go
124
model/genai.go
@@ -13,16 +13,22 @@ const (
|
||||
ModelGemini15Flash = "models/gemini-1.5-flash"
|
||||
)
|
||||
|
||||
type Text struct {
|
||||
Text string `json:"text"`
|
||||
type GenAIInlineData struct {
|
||||
MimeType string `json:"mime_type"`
|
||||
Data string `json:"data"` // Data is base64 repr
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Parts []Text `json:"parts"`
|
||||
Role string `json:"role,omitempty"`
|
||||
type GenAIPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GenAIInlineData `json:"inline_data,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Content) String() string {
|
||||
type GenAIContent struct {
|
||||
Parts []GenAIPart `json:"parts"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
func (c *GenAIContent) String() string {
|
||||
sb := strings.Builder{}
|
||||
for _, p := range c.Parts {
|
||||
sb.WriteString(p.Text)
|
||||
@@ -30,10 +36,10 @@ func (c *Content) String() string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content Content `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
type GenAICandidate struct {
|
||||
Content GenAIContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// GenAI is Goole API format
|
||||
@@ -41,8 +47,8 @@ type GenAI struct {
|
||||
model string `json:"-"`
|
||||
Protocol `json:"-"`
|
||||
// request only
|
||||
Contents []Content `json:"contents,omitempty"`
|
||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||
Contents []GenAIContent `json:"contents,omitempty"`
|
||||
SystemInstruction *GenAIContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
@@ -50,7 +56,7 @@ type GenAI struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
} `json:"generationConfig"`
|
||||
// callback only
|
||||
Candidates []Candidate `json:"candidates,omitempty"`
|
||||
Candidates []GenAICandidate `json:"candidates,omitempty"`
|
||||
}
|
||||
|
||||
// NewGenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
|
||||
@@ -85,35 +91,99 @@ func (opai *GenAI) Parse(body io.Reader) error {
|
||||
return json.NewDecoder(body).Decode(&opai)
|
||||
}
|
||||
|
||||
func (opai *GenAI) Output() string {
|
||||
func (opai *GenAI) Output() Contents {
|
||||
return opai.OutputRaw()
|
||||
}
|
||||
|
||||
func (opai *GenAI) OutputRaw() Contents {
|
||||
if len(opai.Candidates) == 0 {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return opai.Candidates[0].Content.String()
|
||||
raw := opai.Candidates[0].Content
|
||||
cs := make(Contents, len(raw.Parts))
|
||||
for i, c := range raw.Parts {
|
||||
switch {
|
||||
case c.Text != "":
|
||||
cs[i].Type = ContentTypeText
|
||||
cs[i].Text = c.Text
|
||||
case c.InlineData != nil:
|
||||
cs[i].Type = ContentTypeImageURL
|
||||
if strings.HasPrefix(c.InlineData.MimeType, "image/") {
|
||||
cs[i].ImageURL = &ContentImageURL{
|
||||
URL: "data:" + c.InlineData.MimeType + ";base64," + c.InlineData.Data,
|
||||
}
|
||||
}
|
||||
default:
|
||||
panic("unsupported genai part")
|
||||
}
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
func (opai *GenAI) OutputRaw() string {
|
||||
return opai.Output()
|
||||
func (cs Contents) ToGenAIParts() []GenAIPart {
|
||||
ps := make([]GenAIPart, 0, len(cs))
|
||||
for _, c := range cs {
|
||||
switch c.Type {
|
||||
case ContentTypeText:
|
||||
ps = append(ps, GenAIPart{Text: c.Text})
|
||||
case ContentTypeImageURL:
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:") {
|
||||
typ, dat, ok := strings.Cut(strings.TrimPrefix(c.ImageURL.URL, "data:"), ";")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
|
||||
MimeType: typ,
|
||||
Data: dat[1:], // skip ;
|
||||
}})
|
||||
continue
|
||||
}
|
||||
resp, err := http.Get(c.ImageURL.URL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s, err := NewContentImageDataBase64URL(data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
typ, dat, ok := strings.Cut(strings.TrimPrefix(s, "data:"), ",")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ps = append(ps, GenAIPart{InlineData: &GenAIInlineData{
|
||||
MimeType: typ,
|
||||
Data: dat[1:], // skip ,
|
||||
}})
|
||||
default:
|
||||
panic("unsupported ContentType " + c.Type)
|
||||
}
|
||||
}
|
||||
return ps
|
||||
}
|
||||
|
||||
func (opai *GenAI) System(prompt string) Protocol {
|
||||
opai.SystemInstruction = &Content{
|
||||
Parts: []Text{{prompt}},
|
||||
func (opai *GenAI) System(prompt ...Content) Protocol {
|
||||
opai.SystemInstruction = &GenAIContent{
|
||||
Parts: Contents(prompt).ToGenAIParts(),
|
||||
}
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *GenAI) User(prompt string) Protocol {
|
||||
opai.Contents = append(opai.Contents, Content{
|
||||
Parts: []Text{{prompt}},
|
||||
func (opai *GenAI) User(prompt ...Content) Protocol {
|
||||
opai.Contents = append(opai.Contents, GenAIContent{
|
||||
Parts: Contents(prompt).ToGenAIParts(),
|
||||
Role: "user",
|
||||
})
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *GenAI) Assistant(prompt string) Protocol {
|
||||
opai.Contents = append(opai.Contents, Content{
|
||||
Parts: []Text{{prompt}},
|
||||
func (opai *GenAI) Assistant(prompt ...Content) Protocol {
|
||||
opai.Contents = append(opai.Contents, GenAIContent{
|
||||
Parts: Contents(prompt).ToGenAIParts(),
|
||||
Role: "model",
|
||||
})
|
||||
return opai
|
||||
|
||||
@@ -7,22 +7,27 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type OLLaMAMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// OLLaMA as an specified example.
|
||||
type OLLaMA struct {
|
||||
sep string
|
||||
Protocol `json:"-"`
|
||||
// callback only
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Messages []Message `json:"messages"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Messages []OLLaMAMessage `json:"messages"`
|
||||
// callback/request
|
||||
Model string `json:"model"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Temperature float32 `json:"temperature"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens"` // MaxTokens 4096
|
||||
Stream bool `json:"stream"`
|
||||
Model string `json:"model"`
|
||||
Message *OLLaMAMessage `json:"message,omitempty"`
|
||||
Temperature float32 `json:"temperature"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens"` // MaxTokens 4096
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// NewOLLaMA use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
|
||||
@@ -58,41 +63,41 @@ func (ollm *OLLaMA) Parse(body io.Reader) error {
|
||||
return json.NewDecoder(body).Decode(&ollm)
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) Output() string {
|
||||
func (ollm *OLLaMA) Output() Contents {
|
||||
if ollm.Message == nil {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return CutLast(ollm.Message.Content, ollm.sep)
|
||||
return Contents{NewContentText(CutLast(ollm.Message.Content, ollm.sep))}
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) OutputRaw() string {
|
||||
func (ollm *OLLaMA) OutputRaw() Contents {
|
||||
if ollm.Message == nil {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return ollm.Message.Content
|
||||
return Contents{NewContentText(ollm.Message.Content)}
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) System(prompt string) Protocol {
|
||||
ollm.Messages = make([]Message, 1, 8)
|
||||
ollm.Messages[0] = Message{
|
||||
func (ollm *OLLaMA) System(prompt ...Content) Protocol {
|
||||
ollm.Messages = make([]OLLaMAMessage, 1, 8)
|
||||
ollm.Messages[0] = OLLaMAMessage{
|
||||
Role: "system",
|
||||
Content: prompt,
|
||||
Content: prompt[0].Text,
|
||||
}
|
||||
return ollm
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) User(prompt string) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, Message{
|
||||
func (ollm *OLLaMA) User(prompt ...Content) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, OLLaMAMessage{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
Content: prompt[0].Text,
|
||||
})
|
||||
return ollm
|
||||
}
|
||||
|
||||
func (ollm *OLLaMA) Assistant(prompt string) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, Message{
|
||||
func (ollm *OLLaMA) Assistant(prompt ...Content) Protocol {
|
||||
ollm.Messages = append(ollm.Messages, OLLaMAMessage{
|
||||
Role: "assistant",
|
||||
Content: prompt,
|
||||
Content: prompt[0].Text,
|
||||
})
|
||||
return ollm
|
||||
}
|
||||
|
||||
107
model/openai.go
107
model/openai.go
@@ -8,18 +8,18 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ModelDeepDeek = "deepseek-ai/DeepSeek-R1"
|
||||
ModelDeepDeek = "deepseek-ai/DeepSeek-V3.1-Terminus"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
type OpenAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"` // Contents or string
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Message Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
type OpenAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message OpenAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// OpenAI as an specified example.
|
||||
@@ -27,18 +27,18 @@ type OpenAI struct {
|
||||
sep string
|
||||
Protocol `json:"-"`
|
||||
// callback only
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Choices []Choice `json:"choices,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Created int `json:"created,omitempty"`
|
||||
Choices []OpenAIChoice `json:"choices,omitempty"`
|
||||
// callback/request
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float32 `json:"temperature,omitempty"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p,omitempty"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens 4096
|
||||
Model string `json:"model"`
|
||||
Messages []OpenAIMessage `json:"messages"`
|
||||
Temperature float32 `json:"temperature,omitempty"` // Temperature 0.7
|
||||
TopP float32 `json:"top_p,omitempty"` // TopP 0.9
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // MaxTokens 4096
|
||||
// extra body
|
||||
EnableThinking bool `json:"enable_thinking"`
|
||||
EnableThinking bool `json:"enable_thinking"` // EnableThinking is always false in non-stream mode, adapt to 百炼
|
||||
}
|
||||
|
||||
// NewOpenAI use temp 0.7, topp 0.9, maxn 4096 if you don't know the meaning.
|
||||
@@ -74,41 +74,80 @@ func (opai *OpenAI) Parse(body io.Reader) error {
|
||||
return json.NewDecoder(body).Decode(&opai)
|
||||
}
|
||||
|
||||
func (opai *OpenAI) Output() string {
|
||||
func (opai *OpenAI) Output() Contents {
|
||||
if len(opai.Choices) == 0 {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return CutLast(opai.Choices[len(opai.Choices)-1].Message.Content, opai.sep)
|
||||
cs := make(Contents, 0, 8)
|
||||
c := opai.Choices[len(opai.Choices)-1].Message.Content
|
||||
err := json.Unmarshal(c, &cs)
|
||||
if err == nil {
|
||||
for i := range cs {
|
||||
if cs[i].Type == ContentTypeText {
|
||||
cs[i].Text = CutLast(cs[i].Text, opai.sep)
|
||||
}
|
||||
}
|
||||
return cs
|
||||
}
|
||||
s := ""
|
||||
err = json.Unmarshal(c, &s)
|
||||
if err == nil {
|
||||
return Contents{{Type: ContentTypeText, Text: CutLast(s, opai.sep)}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opai *OpenAI) OutputRaw() string {
|
||||
func (opai *OpenAI) OutputRaw() Contents {
|
||||
if len(opai.Choices) == 0 {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
return opai.Choices[len(opai.Choices)-1].Message.Content
|
||||
cs := make(Contents, 0, 8)
|
||||
c := opai.Choices[len(opai.Choices)-1].Message.Content
|
||||
err := json.Unmarshal(c, &cs)
|
||||
if err == nil {
|
||||
return cs
|
||||
}
|
||||
s := ""
|
||||
err = json.Unmarshal(c, &s)
|
||||
if err == nil {
|
||||
return Contents{{Type: ContentTypeText, Text: s}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opai *OpenAI) System(prompt string) Protocol {
|
||||
opai.Messages = make([]Message, 1, 8)
|
||||
opai.Messages[0] = Message{
|
||||
func (opai *OpenAI) System(prompt ...Content) Protocol {
|
||||
opai.Messages = make([]OpenAIMessage, 1, 8)
|
||||
raw, err := json.Marshal(&prompt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opai.Messages[0] = OpenAIMessage{
|
||||
Role: "system",
|
||||
Content: prompt,
|
||||
Content: raw,
|
||||
}
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *OpenAI) User(prompt string) Protocol {
|
||||
opai.Messages = append(opai.Messages, Message{
|
||||
func (opai *OpenAI) User(prompt ...Content) Protocol {
|
||||
raw, err := json.Marshal(&prompt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opai.Messages = append(opai.Messages, OpenAIMessage{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
Content: raw,
|
||||
})
|
||||
return opai
|
||||
}
|
||||
|
||||
func (opai *OpenAI) Assistant(prompt string) Protocol {
|
||||
opai.Messages = append(opai.Messages, Message{
|
||||
func (opai *OpenAI) Assistant(prompt ...Content) Protocol {
|
||||
raw, err := json.Marshal(&prompt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opai.Messages = append(opai.Messages, OpenAIMessage{
|
||||
Role: "assistant",
|
||||
Content: prompt,
|
||||
Content: raw,
|
||||
})
|
||||
return opai
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user