diff --git a/model/openai.go b/model/openai.go index fc09bc1..121d9ae 100644 --- a/model/openai.go +++ b/model/openai.go @@ -128,26 +128,30 @@ func (opai *OpenAI) System(prompt string) Protocol { return opai } -func (opai *OpenAI) User(prompt ...Content) Protocol { - raw, err := json.Marshal(&prompt) +func (opai *OpenAI) normal(role string, prompt ...Content) Protocol { + var ( + raw json.RawMessage + err error + ) + if len(prompt) == 1 && prompt[0].Type == ContentTypeText { + raw, err = json.Marshal(&prompt[0].Text) + } else { + raw, err = json.Marshal(&prompt) + } if err != nil { panic(err) } opai.Messages = append(opai.Messages, OpenAIMessage{ - Role: "user", + Role: role, Content: raw, }) return opai } -func (opai *OpenAI) Assistant(prompt ...Content) Protocol { - raw, err := json.Marshal(&prompt) - if err != nil { - panic(err) - } - opai.Messages = append(opai.Messages, OpenAIMessage{ - Role: "assistant", - Content: raw, - }) - return opai +func (opai *OpenAI) User(prompt ...Content) Protocol { + return opai.normal("user", prompt...) +} + +func (opai *OpenAI) Assistant(prompt ...Content) Protocol { + return opai.normal("assistant", prompt...) }