Files
Tradarr/backend/internal/ai/ollama.go

96 lines
1.9 KiB
Go

package ai
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
)
// ollamaProvider is an AI provider backed by an Ollama server reached over HTTP.
type ollamaProvider struct {
	endpoint, model string       // server base URL (paths are appended to it) and model name sent with requests
	client          *http.Client // client used to perform every HTTP request
}
// newOllama returns an Ollama provider for the given server endpoint and model.
//
// An empty endpoint defaults to "http://ollama:11434" and an empty model to
// "llama3". Trailing slashes are stripped from endpoint because request URLs
// are built by appending absolute paths such as "/api/generate"; without the
// trim, "http://host:11434/" would produce "http://host:11434//api/generate".
//
// The HTTP client is created without a Timeout, so the lifetime of each
// request is bounded only by the context passed to the provider's methods.
func newOllama(endpoint, model string) *ollamaProvider {
	// Trim before defaulting so an endpoint of just "/" also falls back to the default.
	endpoint = strings.TrimRight(endpoint, "/")
	if endpoint == "" {
		endpoint = "http://ollama:11434"
	}
	if model == "" {
		model = "llama3"
	}
	return &ollamaProvider{
		endpoint: endpoint,
		model:    model,
		client:   &http.Client{},
	}
}
// Name reports the fixed identifier of this provider, "ollama".
func (p *ollamaProvider) Name() string {
	const providerName = "ollama"
	return providerName
}
// Summarize sends prompt to the server's /api/generate endpoint as a single
// non-streaming request (stream=false) and returns the generated text from
// the "response" field of the reply.
//
// The request asks for a 32768-token context window via options.num_ctx.
// Cancellation and deadlines come from ctx.
//
// It returns an error if the request cannot be built or sent, if the body
// cannot be read, if the server replies with a non-200 status (the raw body
// is included in the message), if the reply is not valid JSON, or if the
// reply carries a non-empty "error" field.
func (p *ollamaProvider) Summarize(ctx context.Context, prompt string) (string, error) {
	body := map[string]interface{}{
		"model":  p.model,
		"prompt": prompt,
		"stream": false,
		"options": map[string]interface{}{
			"num_ctx": 32768,
		},
	}
	b, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("ollama: encoding generate request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.endpoint+"/api/generate", bytes.NewReader(b))
	if err != nil {
		return "", fmt.Errorf("ollama: building generate request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := p.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("ollama: sending generate request: %w", err)
	}
	defer resp.Body.Close()
	// A failed read must not be ignored: a truncated body would otherwise
	// surface as a confusing JSON decode error or a partial error message.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("ollama: reading generate response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("ollama API error %d: %s", resp.StatusCode, raw)
	}
	var result struct {
		Response string `json:"response"`
		Error    string `json:"error"`
	}
	if err := json.Unmarshal(raw, &result); err != nil {
		return "", fmt.Errorf("ollama: decoding generate response: %w", err)
	}
	// Defensive: if the server ever reports a failure in the body of a 200
	// reply, surface it instead of returning an empty summary.
	if result.Error != "" {
		return "", fmt.Errorf("ollama API error: %s", result.Error)
	}
	return result.Response, nil
}
// ListModels queries the server's /api/tags endpoint and returns the "name"
// of every entry in the reply's "models" array, in the order received.
//
// It returns a nil slice when the server reports no models. It returns an
// error if the request fails, if the body cannot be read, if the server
// replies with a non-200 status (the raw body is included in the message),
// or if the reply is not valid JSON.
func (p *ollamaProvider) ListModels(ctx context.Context) ([]string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.endpoint+"/api/tags", nil)
	if err != nil {
		return nil, fmt.Errorf("ollama: building tags request: %w", err)
	}
	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("ollama: sending tags request: %w", err)
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("ollama: reading tags response: %w", err)
	}
	// Without this check, an error reply whose body is valid JSON (e.g.
	// {"error":"..."}) would decode to zero models and be reported as success.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("ollama API error %d: %s", resp.StatusCode, raw)
	}
	var result struct {
		Models []struct {
			Name string `json:"name"`
		} `json:"models"`
	}
	if err := json.Unmarshal(raw, &result); err != nil {
		return nil, fmt.Errorf("ollama: decoding tags response: %w", err)
	}
	// Build with append from a nil slice so "no models" stays a nil result,
	// as before.
	var models []string
	for _, m := range result.Models {
		models = append(models, m.Name)
	}
	return models, nil
}