From 12ba4605888504f001e1677cb359ff8268e9d070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 14 Feb 2025 16:29:37 +0900 Subject: [PATCH] feat: add custom --- README.md | 14 ++++++++ model.go | 8 ++--- model/api.go | 33 ++++++++++++++++++ model/custom.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++ model/deepseek.go | 34 +++++++----------- model/utils.go | 3 ++ 6 files changed, 155 insertions(+), 26 deletions(-) create mode 100644 model/api.go create mode 100644 model/custom.go diff --git a/README.md b/README.md index 3e60126..8221ad2 100644 --- a/README.md +++ b/README.md @@ -13,3 +13,17 @@ if err != nil { fmt.Println(txt) // Hello! How can I assist you today? ``` + +## Custom Call +```go +api := NewAPI(APIDeepInfra, "PUT YOUR API KEY HERE") +txt, err := api.Request(model.NewCustom("fumiama/ninus", "", 0.7, 0.9, 1024). + System("你正在QQ群与用户聊天,用户发送了消息。按自己的心情简短思考后,条理清晰地回应**一句话**,禁止回应多句。"). + User("总不能什么都查吧").User("后面DOGE就成恶龙了 很常见的场景"), +) +if err != nil { + panic(err) +} +fmt.Println(txt) +// 要不我给你查一下? +``` diff --git a/model.go b/model.go index 77e3af9..35d13f8 100644 --- a/model.go +++ b/model.go @@ -1,12 +1,10 @@ package deepinfra import ( - "bytes" - "io" + "github.com/fumiama/deepinfra/model" ) type Model interface { - Body() *bytes.Buffer - Parse(io.Reader) error - Output() string + model.Inputer + model.Outputer } diff --git a/model/api.go b/model/api.go new file mode 100644 index 0000000..883fdbf --- /dev/null +++ b/model/api.go @@ -0,0 +1,33 @@ +package model + +import ( + "bytes" + "io" +) + +type Inputer interface { + Body() *bytes.Buffer + Parse(io.Reader) error +} + +type Outputer interface { + Output() string + OutputRaw() string +} + +type MessageBuilder[T any] interface { + System(prompt string) T + User(prompt string) T + Assistant(prompt string) T +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Choice struct { + Index int `json:"index"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` +} diff --git a/model/custom.go b/model/custom.go new file mode 100644 index 0000000..7fa45a4 --- /dev/null +++ b/model/custom.go @@ -0,0 +1,89 @@ +package model + +import ( + "bytes" + "encoding/json" + "io" +) + +// Custom as an compatible example. +type Custom struct { + Inputer + Outputer + MessageBuilder[*DeepSeek] + sep string + // callback only + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Created int `json:"created,omitempty"` + Choices []Choice `json:"choices,omitempty"` + // callback/request + Model string `json:"model"` + Messages []Message `json:"messages"` + Temperature float32 `json:"temperature"` // Temperature 0.7 + TopP float32 `json:"top_p"` // TopP 0.9 + MaxTokens int `json:"max_tokens"` // MaxTokens 16384 + +} + +func NewCustom(model, sep string, temp, topp float32, maxn uint) *Custom { + c := new(Custom) + c.sep = sep + c.Model = model + c.Temperature = temp + c.TopP = topp + c.MaxTokens = int(maxn) + return c +} + +func (c *Custom) Parse(body io.Reader) error { + return json.NewDecoder(body).Decode(&c) +} + +func (c *Custom) Output() string { + if len(c.Choices) == 0 { + return "" + } + return CutLast(c.Choices[len(c.Choices)-1].Message.Content, c.sep) +} + +func (c *Custom) OutputRaw() string { + if len(c.Choices) == 0 { + return "" + } + return c.Choices[len(c.Choices)-1].Message.Content +} + +func (ds *Custom) System(prompt string) *Custom { + ds.Messages = make([]Message, 1, 8) + ds.Messages[0] = Message{ + Role: "system", + Content: prompt, + } + return ds +} + +func (ds *Custom) User(prompt string) *Custom { + ds.Messages = append(ds.Messages, Message{ + Role: "user", + Content: prompt, + }) + return ds +} + +func (ds *Custom) Assistant(prompt string) *Custom { + ds.Messages = append(ds.Messages, Message{ + Role: "assistant", + Content: prompt, + }) + return ds +} + +func (ds *Custom) Body() *bytes.Buffer { + w := bytes.NewBuffer(make([]byte, 0, 16384)) + err := json.NewEncoder(w).Encode(ds) + if err != nil { + panic(err) + } + return w +} diff --git a/model/deepseek.go b/model/deepseek.go index c584767..ae282ef 100644 --- a/model/deepseek.go +++ b/model/deepseek.go @@ -10,8 +10,11 @@ const ( modelDeepDeek = "deepseek-ai/DeepSeek-R1" ) -// DeepSeek as an example. +// DeepSeek as an specified example. type DeepSeek struct { + Inputer + Outputer + MessageBuilder[*DeepSeek] // callback only ID string `json:"id,omitempty"` Object string `json:"object,omitempty"` @@ -26,17 +29,6 @@ type DeepSeek struct { } -type Message struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type Choice struct { - Index int `json:"index"` - Message Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - // NewDeepSeek 0.7, 0.9 func NewDeepSeek(temp, topp float32, maxn uint) *DeepSeek { ds := new(DeepSeek) @@ -47,6 +39,15 @@ func NewDeepSeek(temp, topp float32, maxn uint) *DeepSeek { return ds } +func (ds *DeepSeek) Body() *bytes.Buffer { + w := bytes.NewBuffer(make([]byte, 0, 16384)) + err := json.NewEncoder(w).Encode(ds) + if err != nil { + panic(err) + } + return w +} + func (ds *DeepSeek) Parse(body io.Reader) error { return json.NewDecoder(body).Decode(&ds) } @@ -89,12 +90,3 @@ func (ds *DeepSeek) Assistant(prompt string) *DeepSeek { }) return ds } - -func (ds *DeepSeek) Body() *bytes.Buffer { - w := bytes.NewBuffer(make([]byte, 0, 16384)) - err := json.NewEncoder(w).Encode(ds) - if err != nil { - panic(err) - } - return w -} diff --git a/model/utils.go b/model/utils.go index 3e8d5e7..688a143 100644 --- a/model/utils.go +++ b/model/utils.go @@ -7,6 +7,9 @@ const ( ) func CutLast(txt, sep string) string { + if sep == "" { // no need to cut + return txt + } a := strings.LastIndex(txt, sep) if a < 0 { return ""