53 lines
1.1 KiB
Go
53 lines
1.1 KiB
Go
package ai
|
|
|
|
import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)
|
|
|
|
type openAIProvider struct {
|
|
client *openai.Client
|
|
model string
|
|
}
|
|
|
|
func newOpenAI(apiKey, model string) *openAIProvider {
|
|
if model == "" {
|
|
model = openai.GPT4oMini
|
|
}
|
|
return &openAIProvider{
|
|
client: openai.NewClient(apiKey),
|
|
model: model,
|
|
}
|
|
}
|
|
|
|
// Name reports the fixed identifier of this provider.
func (p *openAIProvider) Name() string {
	const providerName = "openai"
	return providerName
}
|
|
|
|
func (p *openAIProvider) Summarize(ctx context.Context, prompt string, _ GenOptions) (string, error) {
|
|
resp, err := p.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
|
Model: p.model,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{Role: openai.ChatMessageRoleUser, Content: prompt},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", nil
|
|
}
|
|
return resp.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
func (p *openAIProvider) ListModels(ctx context.Context) ([]string, error) {
|
|
resp, err := p.client.ListModels(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var models []string
|
|
for _, m := range resp.Models {
|
|
models = append(models, m.ID)
|
|
}
|
|
return models, nil
|
|
}
|