package ai import ( "context" "fmt" "regexp" "strconv" "strings" "sync/atomic" "time" "github.com/tradarr/backend/internal/crypto" "github.com/tradarr/backend/internal/models" ) const DefaultSystemPrompt = `Tu es un assistant spécialisé en trading financier. Analyse l'ensemble des actualités suivantes, toutes sources confondues, et crée un résumé global structuré en français, orienté trading. Structure ton résumé ainsi : 1. **Vue macro** : tendances globales du marché (économie, géopolitique, secteurs) 2. **Actifs surveillés** : pour chaque actif de la watchlist mentionné dans les news : - Sentiment (haussier/baissier/neutre) - Faits clés et catalyseurs - Risques et opportunités 3. **Autres mouvements notables** : actifs hors watchlist à surveiller 4. **Synthèse** : points d'attention prioritaires pour la journée` type Pipeline struct { repo *models.Repository enc *crypto.Encryptor generating atomic.Bool } func NewPipeline(repo *models.Repository, enc *crypto.Encryptor) *Pipeline { return &Pipeline{repo: repo, enc: enc} } func (p *Pipeline) IsGenerating() bool { return p.generating.Load() } func (p *Pipeline) BuildProvider(name, apiKey, endpoint string) (Provider, error) { return NewProvider(name, apiKey, "", endpoint) } // buildProviderForRole resolves and builds the AI provider for a given task role. func (p *Pipeline) buildProviderForRole(role string) (Provider, *models.AIProvider, error) { cfg, model, err := p.repo.GetRoleProvider(role) if err != nil { return nil, nil, fmt.Errorf("get provider for role %s: %w", role, err) } if cfg == nil { return nil, nil, fmt.Errorf("no AI provider configured for role %s", role) } apiKey := "" if cfg.APIKeyEncrypted != "" { apiKey, err = p.enc.Decrypt(cfg.APIKeyEncrypted) if err != nil { return nil, nil, fmt.Errorf("decrypt API key for role %s: %w", role, err) } } provider, err := NewProvider(cfg.Name, apiKey, model, cfg.Endpoint) if err != nil { return nil, nil, fmt.Errorf("build provider for role %s: %w", role, err) } return provider, cfg, nil } func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.Summary, error) { p.generating.Store(true) defer p.generating.Store(false) provider, providerCfg, err := p.buildProviderForRole("summary") if err != nil { return nil, err } assets, err := p.repo.GetUserAssets(userID) if err != nil { return nil, fmt.Errorf("get user assets: %w", err) } symbols := make([]string, len(assets)) for i, a := range assets { symbols[i] = a.Symbol } hoursStr, _ := p.repo.GetSetting("articles_lookback_hours") hours, _ := strconv.Atoi(hoursStr) if hours == 0 { hours = 24 } articles, err := p.repo.GetRecentArticles(hours) if err != nil { return nil, fmt.Errorf("get articles: %w", err) } if len(articles) == 0 { return nil, fmt.Errorf("no recent articles found") } maxStr, _ := p.repo.GetSetting("summary_max_articles") maxArticles, _ := strconv.Atoi(maxStr) if maxArticles == 0 { maxArticles = 50 } // Passe 1 : filtrage par pertinence — seulement si nettement plus d'articles que le max if len(articles) > maxArticles*2 { filterProvider, _, filterErr := p.buildProviderForRole("filter") if filterErr != nil { filterProvider = provider // fallback to summary provider } fmt.Printf("[pipeline] Passe 1 — filtrage : %d articles → sélection des %d plus pertinents…\n", len(articles), maxArticles) t1 := time.Now() articles = p.filterByRelevance(ctx, filterProvider, symbols, articles, maxArticles) fmt.Printf("[pipeline] Passe 1 — terminée en %s : %d articles retenus\n", time.Since(t1).Round(time.Second), len(articles)) } else if len(articles) > maxArticles { articles = articles[:maxArticles] fmt.Printf("[pipeline] troncature directe à %d articles (pas assez d'excédent pour justifier un appel IA)\n", maxArticles) } systemPrompt, _ := p.repo.GetSetting("ai_system_prompt") if systemPrompt == "" { systemPrompt = DefaultSystemPrompt } tz, _ := p.repo.GetSetting("timezone") // Passe 2 : résumé complet fmt.Printf("[pipeline] Passe 2 — résumé : génération sur %d articles…\n", len(articles)) t2 := time.Now() prompt := buildPrompt(systemPrompt, symbols, articles, tz) // Passe 2 : think activé pour une meilleure qualité d'analyse summary, err := provider.Summarize(ctx, prompt, GenOptions{Think: true, NumCtx: 32768}) if err != nil { return nil, fmt.Errorf("AI summarize: %w", err) } fmt.Printf("[pipeline] Passe 2 — terminée en %s\n", time.Since(t2).Round(time.Second)) return p.repo.CreateSummary(userID, summary, &providerCfg.ID) } // filterByRelevance splits articles into batches and asks the AI to select relevant // ones from each batch. Results are pooled then truncated to max. func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, symbols []string, articles []models.Article, max int) []models.Article { batchSizeStr, _ := p.repo.GetSetting("filter_batch_size") batchSize, _ := strconv.Atoi(batchSizeStr) if batchSize <= 0 { batchSize = 20 } var selected []models.Article numBatches := (len(articles) + batchSize - 1) / batchSize for b := 0; b < numBatches; b++ { start := b * batchSize end := start + batchSize if end > len(articles) { end = len(articles) } batch := articles[start:end] fmt.Printf("[pipeline] Passe 1 — batch %d/%d (%d articles)…\n", b+1, numBatches, len(batch)) t := time.Now() chosen := p.filterBatch(ctx, provider, symbols, batch) fmt.Printf("[pipeline] Passe 1 — batch %d/%d terminé en %s : %d retenus\n", b+1, numBatches, time.Since(t).Round(time.Second), len(chosen)) selected = append(selected, chosen...) // Stop early if we have plenty of candidates if len(selected) >= max*2 { fmt.Printf("[pipeline] Passe 1 — suffisamment de candidats (%d), arrêt anticipé\n", len(selected)) break } } if len(selected) <= max { return selected } return selected[:max] } // filterBatch asks the AI to return all relevant articles from a single batch. func (p *Pipeline) filterBatch(ctx context.Context, provider Provider, symbols []string, batch []models.Article) []models.Article { prompt := buildFilterBatchPrompt(symbols, batch) response, err := provider.Summarize(ctx, prompt, GenOptions{Think: false, NumCtx: 4096}) if err != nil { fmt.Printf("[pipeline] filterBatch — échec (%v), conservation du batch entier\n", err) return batch } indices := parseIndexArray(response, len(batch)) if len(indices) == 0 { return nil } filtered := make([]models.Article, 0, len(indices)) for _, i := range indices { filtered = append(filtered, batch[i]) } return filtered } func buildFilterBatchPrompt(symbols []string, batch []models.Article) string { var sb strings.Builder sb.WriteString("Tu es un assistant de trading financier.\n") sb.WriteString(fmt.Sprintf("Parmi les %d articles ci-dessous, sélectionne TOUS ceux pertinents pour un trader actif.\n", len(batch))) if len(symbols) > 0 { sb.WriteString("Actifs prioritaires : ") sb.WriteString(strings.Join(symbols, ", ")) sb.WriteString("\n") } sb.WriteString("\nRéponds UNIQUEMENT avec un tableau JSON des indices retenus (base 0), exemple : [0, 2, 5]\n") sb.WriteString("Si aucun article n'est pertinent, réponds : []\n") sb.WriteString("N'ajoute aucun texte avant ou après le tableau JSON.\n\n") sb.WriteString("Articles :\n") for i, a := range batch { sb.WriteString(fmt.Sprintf("[%d] %s (%s)\n", i, a.Title, a.SourceName)) } return sb.String() } var jsonArrayRe = regexp.MustCompile(`\[[\d\s,]+\]`) func parseIndexArray(response string, maxIndex int) []int { match := jsonArrayRe.FindString(response) if match == "" { return nil } match = strings.Trim(match, "[]") parts := strings.Split(match, ",") seen := make(map[int]bool) var indices []int for _, p := range parts { n, err := strconv.Atoi(strings.TrimSpace(p)) if err != nil || n < 0 || n >= maxIndex || seen[n] { continue } seen[n] = true indices = append(indices, n) } return indices } func (p *Pipeline) GenerateForAll(ctx context.Context) error { users, err := p.repo.ListUsers() if err != nil { return err } for _, user := range users { if _, err := p.GenerateForUser(ctx, user.ID); err != nil { fmt.Printf("summary for user %s: %v\n", user.Email, err) } } return nil } // GenerateReportAsync crée le rapport en DB (status=generating) et lance la génération en arrière-plan. func (p *Pipeline) GenerateReportAsync(reportID, excerpt, question string, mgr *ReportManager) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) mgr.Register(reportID, cancel) go func() { defer cancel() defer mgr.Remove(reportID) answer, err := p.callProviderForReport(ctx, excerpt, question) if err != nil { if ctx.Err() != nil { // annulé volontairement — le rapport est supprimé par le handler return } _ = p.repo.UpdateReport(reportID, "error", "", err.Error()) return } _ = p.repo.UpdateReport(reportID, "done", answer, "") }() } func (p *Pipeline) callProviderForReport(ctx context.Context, excerpt, question string) (string, error) { provider, _, err := p.buildProviderForRole("report") if err != nil { return "", err } prompt := fmt.Sprintf( "Tu es un assistant financier expert. L'utilisateur a sélectionné les extraits suivants d'un résumé de marché :\n\n%s\n\nQuestion de l'utilisateur : %s\n\nRéponds en français, de façon précise et orientée trading.", excerpt, question, ) return provider.Summarize(ctx, prompt, GenOptions{Think: true, NumCtx: 16384}) } func buildPrompt(systemPrompt string, symbols []string, articles []models.Article, tz string) string { var sb strings.Builder sb.WriteString(systemPrompt) sb.WriteString("\n\n") if len(symbols) > 0 { sb.WriteString("Le trader surveille particulièrement ces actifs (sois attentif à toute mention) : ") sb.WriteString(strings.Join(symbols, ", ")) sb.WriteString(".\n\n") } loc, err := time.LoadLocation(tz) if err != nil || tz == "" { loc = time.UTC } sb.WriteString(fmt.Sprintf("Date d'analyse : %s\n\n", time.Now().In(loc).Format("02/01/2006 15:04"))) sb.WriteString("## Actualités\n\n") for i, a := range articles { sb.WriteString(fmt.Sprintf("### [%d] %s\n", i+1, a.Title)) sb.WriteString(fmt.Sprintf("Source : %s\n", a.SourceName)) if a.PublishedAt.Valid { sb.WriteString(fmt.Sprintf("Date : %s\n", a.PublishedAt.Time.Format("02/01/2006 15:04"))) } content := a.Content if len(content) > 1000 { content = content[:1000] + "..." } sb.WriteString(content) sb.WriteString("\n\n") } return sb.String() }