feat: add download features to llm models
This commit is contained in:
@ -7,31 +7,51 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ollamaProvider struct {
|
||||
// OllamaModelInfo holds detailed info about an installed Ollama model.
|
||||
type OllamaModelInfo struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Details struct {
|
||||
ParameterSize string `json:"parameter_size"`
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
Family string `json:"family"`
|
||||
} `json:"details"`
|
||||
}
|
||||
|
||||
// OllamaProvider implements Provider for Ollama and also exposes model management operations.
|
||||
type OllamaProvider struct {
|
||||
endpoint string
|
||||
model string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func newOllama(endpoint, model string) *ollamaProvider {
|
||||
func newOllama(endpoint, model string) *OllamaProvider {
|
||||
if endpoint == "" {
|
||||
endpoint = "http://ollama:11434"
|
||||
}
|
||||
if model == "" {
|
||||
model = "llama3"
|
||||
}
|
||||
return &ollamaProvider{
|
||||
return &OllamaProvider{
|
||||
endpoint: endpoint,
|
||||
model: model,
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ollamaProvider) Name() string { return "ollama" }
|
||||
// NewOllamaManager creates an OllamaProvider for model management (pull/delete/list).
|
||||
func NewOllamaManager(endpoint string) *OllamaProvider {
|
||||
return newOllama(endpoint, "")
|
||||
}
|
||||
|
||||
func (p *ollamaProvider) Summarize(ctx context.Context, prompt string, opts GenOptions) (string, error) {
|
||||
|
||||
func (p *OllamaProvider) Name() string { return "ollama" }
|
||||
|
||||
func (p *OllamaProvider) Summarize(ctx context.Context, prompt string, opts GenOptions) (string, error) {
|
||||
numCtx := 32768
|
||||
if opts.NumCtx > 0 {
|
||||
numCtx = opts.NumCtx
|
||||
@ -72,7 +92,19 @@ func (p *ollamaProvider) Summarize(ctx context.Context, prompt string, opts GenO
|
||||
return result.Response, nil
|
||||
}
|
||||
|
||||
func (p *ollamaProvider) ListModels(ctx context.Context) ([]string, error) {
|
||||
func (p *OllamaProvider) ListModels(ctx context.Context) ([]string, error) {
|
||||
infos, err := p.ListModelsDetailed(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names := make([]string, len(infos))
|
||||
for i, m := range infos {
|
||||
names[i] = m.Name
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ListModelsDetailed(ctx context.Context) ([]OllamaModelInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.endpoint+"/api/tags", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -85,16 +117,52 @@ func (p *ollamaProvider) ListModels(ctx context.Context) ([]string, error) {
|
||||
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
var result struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"models"`
|
||||
Models []OllamaModelInfo `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var models []string
|
||||
for _, m := range result.Models {
|
||||
models = append(models, m.Name)
|
||||
}
|
||||
return models, nil
|
||||
return result.Models, nil
|
||||
}
|
||||
|
||||
// PullModel pulls (downloads) a model from Ollama Hub. Blocks until complete.
|
||||
func (p *OllamaProvider) PullModel(ctx context.Context, name string) error {
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": name, "stream": false})
|
||||
// Use a long-timeout client since model downloads can take many minutes
|
||||
client := &http.Client{Timeout: 60 * time.Minute}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.endpoint+"/api/pull", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("ollama pull error %d: %s", resp.StatusCode, raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteModel removes a model from local storage.
|
||||
func (p *OllamaProvider) DeleteModel(ctx context.Context, name string) error {
|
||||
body, _ := json.Marshal(map[string]string{"name": name})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, p.endpoint+"/api/delete", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("ollama delete error %d: %s", resp.StatusCode, raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -50,28 +50,36 @@ func (p *Pipeline) BuildProvider(name, apiKey, endpoint string) (Provider, error
|
||||
return NewProvider(name, apiKey, model, endpoint)
|
||||
}
|
||||
|
||||
// buildProviderForRole resolves and builds the AI provider for a given task role.
|
||||
func (p *Pipeline) buildProviderForRole(role string) (Provider, *models.AIProvider, error) {
|
||||
cfg, model, err := p.repo.GetRoleProvider(role)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get provider for role %s: %w", role, err)
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, nil, fmt.Errorf("no AI provider configured for role %s", role)
|
||||
}
|
||||
apiKey := ""
|
||||
if cfg.APIKeyEncrypted != "" {
|
||||
apiKey, err = p.enc.Decrypt(cfg.APIKeyEncrypted)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("decrypt API key for role %s: %w", role, err)
|
||||
}
|
||||
}
|
||||
provider, err := NewProvider(cfg.Name, apiKey, model, cfg.Endpoint)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("build provider for role %s: %w", role, err)
|
||||
}
|
||||
return provider, cfg, nil
|
||||
}
|
||||
|
||||
func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.Summary, error) {
|
||||
p.generating.Store(true)
|
||||
defer p.generating.Store(false)
|
||||
providerCfg, err := p.repo.GetActiveAIProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get active provider: %w", err)
|
||||
}
|
||||
if providerCfg == nil {
|
||||
return nil, fmt.Errorf("no active AI provider configured")
|
||||
}
|
||||
|
||||
apiKey := ""
|
||||
if providerCfg.APIKeyEncrypted != "" {
|
||||
apiKey, err = p.enc.Decrypt(providerCfg.APIKeyEncrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt API key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
provider, err := NewProvider(providerCfg.Name, apiKey, providerCfg.Model, providerCfg.Endpoint)
|
||||
provider, providerCfg, err := p.buildProviderForRole("summary")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build provider: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
assets, err := p.repo.GetUserAssets(userID)
|
||||
@ -105,9 +113,13 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
||||
|
||||
// Passe 1 : filtrage par pertinence — seulement si nettement plus d'articles que le max
|
||||
if len(articles) > maxArticles*2 {
|
||||
filterProvider, _, filterErr := p.buildProviderForRole("filter")
|
||||
if filterErr != nil {
|
||||
filterProvider = provider // fallback to summary provider
|
||||
}
|
||||
fmt.Printf("[pipeline] Passe 1 — filtrage : %d articles → sélection des %d plus pertinents…\n", len(articles), maxArticles)
|
||||
t1 := time.Now()
|
||||
articles = p.filterByRelevance(ctx, provider, symbols, articles, maxArticles)
|
||||
articles = p.filterByRelevance(ctx, filterProvider, symbols, articles, maxArticles)
|
||||
fmt.Printf("[pipeline] Passe 1 — terminée en %s : %d articles retenus\n", time.Since(t1).Round(time.Second), len(articles))
|
||||
} else if len(articles) > maxArticles {
|
||||
articles = articles[:maxArticles]
|
||||
@ -135,49 +147,84 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
||||
return p.repo.CreateSummary(userID, summary, &providerCfg.ID)
|
||||
}
|
||||
|
||||
// filterByRelevance demande à l'IA de sélectionner les articles les plus pertinents
|
||||
// en ne lui envoyant que les titres (prompt très court = rapide).
|
||||
// filterByRelevance splits articles into batches and asks the AI to select relevant
|
||||
// ones from each batch. Results are pooled then truncated to max.
|
||||
func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, symbols []string, articles []models.Article, max int) []models.Article {
|
||||
prompt := buildFilterPrompt(symbols, articles, max)
|
||||
// Passe 1 : pas de think, contexte réduit (titres seulement = prompt court)
|
||||
response, err := provider.Summarize(ctx, prompt, GenOptions{Think: false, NumCtx: 8192})
|
||||
if err != nil {
|
||||
fmt.Printf("[pipeline] Passe 1 — échec (%v), repli sur troncature\n", err)
|
||||
return articles[:max]
|
||||
batchSizeStr, _ := p.repo.GetSetting("filter_batch_size")
|
||||
batchSize, _ := strconv.Atoi(batchSizeStr)
|
||||
if batchSize <= 0 {
|
||||
batchSize = 20
|
||||
}
|
||||
|
||||
indices := parseIndexArray(response, len(articles))
|
||||
var selected []models.Article
|
||||
numBatches := (len(articles) + batchSize - 1) / batchSize
|
||||
|
||||
for b := 0; b < numBatches; b++ {
|
||||
start := b * batchSize
|
||||
end := start + batchSize
|
||||
if end > len(articles) {
|
||||
end = len(articles)
|
||||
}
|
||||
batch := articles[start:end]
|
||||
|
||||
fmt.Printf("[pipeline] Passe 1 — batch %d/%d (%d articles)…\n", b+1, numBatches, len(batch))
|
||||
t := time.Now()
|
||||
chosen := p.filterBatch(ctx, provider, symbols, batch)
|
||||
fmt.Printf("[pipeline] Passe 1 — batch %d/%d terminé en %s : %d retenus\n", b+1, numBatches, time.Since(t).Round(time.Second), len(chosen))
|
||||
|
||||
selected = append(selected, chosen...)
|
||||
|
||||
// Stop early if we have plenty of candidates
|
||||
if len(selected) >= max*2 {
|
||||
fmt.Printf("[pipeline] Passe 1 — suffisamment de candidats (%d), arrêt anticipé\n", len(selected))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(selected) <= max {
|
||||
return selected
|
||||
}
|
||||
return selected[:max]
|
||||
}
|
||||
|
||||
// filterBatch asks the AI to return all relevant articles from a single batch.
|
||||
func (p *Pipeline) filterBatch(ctx context.Context, provider Provider, symbols []string, batch []models.Article) []models.Article {
|
||||
prompt := buildFilterBatchPrompt(symbols, batch)
|
||||
response, err := provider.Summarize(ctx, prompt, GenOptions{Think: false, NumCtx: 4096})
|
||||
if err != nil {
|
||||
fmt.Printf("[pipeline] filterBatch — échec (%v), conservation du batch entier\n", err)
|
||||
return batch
|
||||
}
|
||||
|
||||
indices := parseIndexArray(response, len(batch))
|
||||
if len(indices) == 0 {
|
||||
fmt.Printf("[pipeline] Passe 1 — réponse non parseable, repli sur troncature\n")
|
||||
return articles[:max]
|
||||
return nil
|
||||
}
|
||||
|
||||
filtered := make([]models.Article, 0, len(indices))
|
||||
for _, i := range indices {
|
||||
filtered = append(filtered, articles[i])
|
||||
if len(filtered) >= max {
|
||||
break
|
||||
}
|
||||
filtered = append(filtered, batch[i])
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func buildFilterPrompt(symbols []string, articles []models.Article, max int) string {
|
||||
func buildFilterBatchPrompt(symbols []string, batch []models.Article) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Tu es un assistant de trading financier. ")
|
||||
sb.WriteString(fmt.Sprintf("Parmi les %d articles ci-dessous, sélectionne les %d plus pertinents pour un trader actif.\n", len(articles), max))
|
||||
sb.WriteString("Tu es un assistant de trading financier.\n")
|
||||
sb.WriteString(fmt.Sprintf("Parmi les %d articles ci-dessous, sélectionne TOUS ceux pertinents pour un trader actif.\n", len(batch)))
|
||||
|
||||
if len(symbols) > 0 {
|
||||
sb.WriteString("Actifs surveillés (priorité haute) : ")
|
||||
sb.WriteString("Actifs prioritaires : ")
|
||||
sb.WriteString(strings.Join(symbols, ", "))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("\nRéponds UNIQUEMENT avec un tableau JSON des indices sélectionnés (base 0), exemple : [0, 3, 7, 12]\n"))
|
||||
sb.WriteString("\nRéponds UNIQUEMENT avec un tableau JSON des indices retenus (base 0), exemple : [0, 2, 5]\n")
|
||||
sb.WriteString("Si aucun article n'est pertinent, réponds : []\n")
|
||||
sb.WriteString("N'ajoute aucun texte avant ou après le tableau JSON.\n\n")
|
||||
sb.WriteString("Articles :\n")
|
||||
|
||||
for i, a := range articles {
|
||||
for i, a := range batch {
|
||||
sb.WriteString(fmt.Sprintf("[%d] %s (%s)\n", i, a.Title, a.SourceName))
|
||||
}
|
||||
|
||||
@ -243,25 +290,9 @@ func (p *Pipeline) GenerateReportAsync(reportID, excerpt, question string, mgr *
|
||||
}
|
||||
|
||||
func (p *Pipeline) callProviderForReport(ctx context.Context, excerpt, question string) (string, error) {
|
||||
providerCfg, err := p.repo.GetActiveAIProvider()
|
||||
provider, _, err := p.buildProviderForRole("report")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get active provider: %w", err)
|
||||
}
|
||||
if providerCfg == nil {
|
||||
return "", fmt.Errorf("no active AI provider configured")
|
||||
}
|
||||
|
||||
apiKey := ""
|
||||
if providerCfg.APIKeyEncrypted != "" {
|
||||
apiKey, err = p.enc.Decrypt(providerCfg.APIKeyEncrypted)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt API key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
provider, err := NewProvider(providerCfg.Name, apiKey, providerCfg.Model, providerCfg.Endpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("build provider: %w", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(
|
||||
|
||||
Reference in New Issue
Block a user