330 lines
10 KiB
Go
330 lines
10 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/tradarr/backend/internal/crypto"
|
|
"github.com/tradarr/backend/internal/models"
|
|
)
|
|
|
|
// DefaultSystemPrompt is the system prompt used for summary generation when
// the "ai_system_prompt" setting is empty. It is French text that is placed
// verbatim at the top of the prompt sent to the AI provider, so every
// character inside the raw string is runtime data.
const DefaultSystemPrompt = `Tu es un assistant spécialisé en trading financier. Analyse l'ensemble des actualités suivantes, toutes sources confondues, et crée un résumé global structuré en français, orienté trading.
|
|
|
|
Structure ton résumé ainsi :
|
|
1. **Vue macro** : tendances globales du marché (économie, géopolitique, secteurs)
|
|
2. **Actifs surveillés** : pour chaque actif de la watchlist mentionné dans les news :
|
|
- Sentiment (haussier/baissier/neutre)
|
|
- Faits clés et catalyseurs
|
|
- Risques et opportunités
|
|
3. **Autres mouvements notables** : actifs hors watchlist à surveiller
|
|
4. **Synthèse** : points d'attention prioritaires pour la journée`
|
|
|
|
type Pipeline struct {
|
|
repo *models.Repository
|
|
enc *crypto.Encryptor
|
|
generating atomic.Bool
|
|
}
|
|
|
|
func NewPipeline(repo *models.Repository, enc *crypto.Encryptor) *Pipeline {
|
|
return &Pipeline{repo: repo, enc: enc}
|
|
}
|
|
|
|
func (p *Pipeline) IsGenerating() bool {
|
|
return p.generating.Load()
|
|
}
|
|
|
|
func (p *Pipeline) BuildProvider(name, apiKey, endpoint string) (Provider, error) {
|
|
return NewProvider(name, apiKey, "", endpoint)
|
|
}
|
|
|
|
// buildProviderForRole resolves and builds the AI provider for a given task role.
|
|
func (p *Pipeline) buildProviderForRole(role string) (Provider, *models.AIProvider, error) {
|
|
cfg, model, err := p.repo.GetRoleProvider(role)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("get provider for role %s: %w", role, err)
|
|
}
|
|
if cfg == nil {
|
|
return nil, nil, fmt.Errorf("no AI provider configured for role %s", role)
|
|
}
|
|
apiKey := ""
|
|
if cfg.APIKeyEncrypted != "" {
|
|
apiKey, err = p.enc.Decrypt(cfg.APIKeyEncrypted)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("decrypt API key for role %s: %w", role, err)
|
|
}
|
|
}
|
|
provider, err := NewProvider(cfg.Name, apiKey, model, cfg.Endpoint)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("build provider for role %s: %w", role, err)
|
|
}
|
|
return provider, cfg, nil
|
|
}
|
|
|
|
func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.Summary, error) {
|
|
p.generating.Store(true)
|
|
defer p.generating.Store(false)
|
|
|
|
provider, providerCfg, err := p.buildProviderForRole("summary")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
assets, err := p.repo.GetUserAssets(userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user assets: %w", err)
|
|
}
|
|
symbols := make([]string, len(assets))
|
|
for i, a := range assets {
|
|
symbols[i] = a.Symbol
|
|
}
|
|
|
|
hoursStr, _ := p.repo.GetSetting("articles_lookback_hours")
|
|
hours, _ := strconv.Atoi(hoursStr)
|
|
if hours == 0 {
|
|
hours = 24
|
|
}
|
|
|
|
articles, err := p.repo.GetRecentArticles(hours)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get articles: %w", err)
|
|
}
|
|
if len(articles) == 0 {
|
|
return nil, fmt.Errorf("no recent articles found")
|
|
}
|
|
|
|
maxStr, _ := p.repo.GetSetting("summary_max_articles")
|
|
maxArticles, _ := strconv.Atoi(maxStr)
|
|
if maxArticles == 0 {
|
|
maxArticles = 50
|
|
}
|
|
|
|
// Passe 1 : filtrage par pertinence — seulement si nettement plus d'articles que le max
|
|
if len(articles) > maxArticles*2 {
|
|
filterProvider, _, filterErr := p.buildProviderForRole("filter")
|
|
if filterErr != nil {
|
|
filterProvider = provider // fallback to summary provider
|
|
}
|
|
fmt.Printf("[pipeline] Passe 1 — filtrage : %d articles → sélection des %d plus pertinents…\n", len(articles), maxArticles)
|
|
t1 := time.Now()
|
|
articles = p.filterByRelevance(ctx, filterProvider, symbols, articles, maxArticles)
|
|
fmt.Printf("[pipeline] Passe 1 — terminée en %s : %d articles retenus\n", time.Since(t1).Round(time.Second), len(articles))
|
|
} else if len(articles) > maxArticles {
|
|
articles = articles[:maxArticles]
|
|
fmt.Printf("[pipeline] troncature directe à %d articles (pas assez d'excédent pour justifier un appel IA)\n", maxArticles)
|
|
}
|
|
|
|
systemPrompt, _ := p.repo.GetSetting("ai_system_prompt")
|
|
if systemPrompt == "" {
|
|
systemPrompt = DefaultSystemPrompt
|
|
}
|
|
|
|
tz, _ := p.repo.GetSetting("timezone")
|
|
|
|
// Passe 2 : résumé complet
|
|
fmt.Printf("[pipeline] Passe 2 — résumé : génération sur %d articles…\n", len(articles))
|
|
t2 := time.Now()
|
|
prompt := buildPrompt(systemPrompt, symbols, articles, tz)
|
|
// Passe 2 : think activé pour une meilleure qualité d'analyse
|
|
summary, err := provider.Summarize(ctx, prompt, GenOptions{Think: true, NumCtx: 32768})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("AI summarize: %w", err)
|
|
}
|
|
fmt.Printf("[pipeline] Passe 2 — terminée en %s\n", time.Since(t2).Round(time.Second))
|
|
|
|
return p.repo.CreateSummary(userID, summary, &providerCfg.ID)
|
|
}
|
|
|
|
// filterByRelevance splits articles into batches and asks the AI to select relevant
|
|
// ones from each batch. Results are pooled then truncated to max.
|
|
func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, symbols []string, articles []models.Article, max int) []models.Article {
|
|
batchSizeStr, _ := p.repo.GetSetting("filter_batch_size")
|
|
batchSize, _ := strconv.Atoi(batchSizeStr)
|
|
if batchSize <= 0 {
|
|
batchSize = 20
|
|
}
|
|
|
|
var selected []models.Article
|
|
numBatches := (len(articles) + batchSize - 1) / batchSize
|
|
|
|
for b := 0; b < numBatches; b++ {
|
|
start := b * batchSize
|
|
end := start + batchSize
|
|
if end > len(articles) {
|
|
end = len(articles)
|
|
}
|
|
batch := articles[start:end]
|
|
|
|
fmt.Printf("[pipeline] Passe 1 — batch %d/%d (%d articles)…\n", b+1, numBatches, len(batch))
|
|
t := time.Now()
|
|
chosen := p.filterBatch(ctx, provider, symbols, batch)
|
|
fmt.Printf("[pipeline] Passe 1 — batch %d/%d terminé en %s : %d retenus\n", b+1, numBatches, time.Since(t).Round(time.Second), len(chosen))
|
|
|
|
selected = append(selected, chosen...)
|
|
|
|
// Stop early if we have plenty of candidates
|
|
if len(selected) >= max*2 {
|
|
fmt.Printf("[pipeline] Passe 1 — suffisamment de candidats (%d), arrêt anticipé\n", len(selected))
|
|
break
|
|
}
|
|
}
|
|
|
|
if len(selected) <= max {
|
|
return selected
|
|
}
|
|
return selected[:max]
|
|
}
|
|
|
|
// filterBatch asks the AI to return all relevant articles from a single batch.
|
|
func (p *Pipeline) filterBatch(ctx context.Context, provider Provider, symbols []string, batch []models.Article) []models.Article {
|
|
prompt := buildFilterBatchPrompt(symbols, batch)
|
|
response, err := provider.Summarize(ctx, prompt, GenOptions{Think: false, NumCtx: 4096})
|
|
if err != nil {
|
|
fmt.Printf("[pipeline] filterBatch — échec (%v), conservation du batch entier\n", err)
|
|
return batch
|
|
}
|
|
|
|
indices := parseIndexArray(response, len(batch))
|
|
if len(indices) == 0 {
|
|
return nil
|
|
}
|
|
|
|
filtered := make([]models.Article, 0, len(indices))
|
|
for _, i := range indices {
|
|
filtered = append(filtered, batch[i])
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
func buildFilterBatchPrompt(symbols []string, batch []models.Article) string {
|
|
var sb strings.Builder
|
|
sb.WriteString("Tu es un assistant de trading financier.\n")
|
|
sb.WriteString(fmt.Sprintf("Parmi les %d articles ci-dessous, sélectionne TOUS ceux pertinents pour un trader actif.\n", len(batch)))
|
|
|
|
if len(symbols) > 0 {
|
|
sb.WriteString("Actifs prioritaires : ")
|
|
sb.WriteString(strings.Join(symbols, ", "))
|
|
sb.WriteString("\n")
|
|
}
|
|
|
|
sb.WriteString("\nRéponds UNIQUEMENT avec un tableau JSON des indices retenus (base 0), exemple : [0, 2, 5]\n")
|
|
sb.WriteString("Si aucun article n'est pertinent, réponds : []\n")
|
|
sb.WriteString("N'ajoute aucun texte avant ou après le tableau JSON.\n\n")
|
|
sb.WriteString("Articles :\n")
|
|
|
|
for i, a := range batch {
|
|
sb.WriteString(fmt.Sprintf("[%d] %s (%s)\n", i, a.Title, a.SourceName))
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// jsonArrayRe matches a bracketed list made only of digits, whitespace and
// commas, such as "[0, 2, 5]". An empty "[]" does not match.
var jsonArrayRe = regexp.MustCompile(`\[[\d\s,]+\]`)

// parseIndexArray extracts a list of indices from a model response.
//
// Only the first bracketed list found in response is considered. Its entries
// are returned in order of appearance, skipping anything that is not an
// integer, lies outside [0, maxIndex), or repeats an earlier entry. The
// result is nil when no list is found or no entry is usable.
func parseIndexArray(response string, maxIndex int) []int {
	found := jsonArrayRe.FindString(response)
	if found == "" {
		return nil
	}

	// A match always starts with '[' and ends with ']', with no bracket in
	// between, so dropping the first and last byte leaves the bare list.
	inner := found[1 : len(found)-1]

	var result []int
	taken := map[int]struct{}{}
	for _, field := range strings.Split(inner, ",") {
		value, err := strconv.Atoi(strings.TrimSpace(field))
		if err != nil {
			continue
		}
		if value < 0 || value >= maxIndex {
			continue
		}
		if _, dup := taken[value]; dup {
			continue
		}
		taken[value] = struct{}{}
		result = append(result, value)
	}
	return result
}
|
|
|
|
func (p *Pipeline) GenerateForAll(ctx context.Context) error {
|
|
users, err := p.repo.ListUsers()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, user := range users {
|
|
if _, err := p.GenerateForUser(ctx, user.ID); err != nil {
|
|
fmt.Printf("summary for user %s: %v\n", user.Email, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GenerateReportAsync crée le rapport en DB (status=generating) et lance la génération en arrière-plan.
|
|
func (p *Pipeline) GenerateReportAsync(reportID, excerpt, question string, mgr *ReportManager) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
|
mgr.Register(reportID, cancel)
|
|
|
|
go func() {
|
|
defer cancel()
|
|
defer mgr.Remove(reportID)
|
|
|
|
answer, err := p.callProviderForReport(ctx, excerpt, question)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
// annulé volontairement — le rapport est supprimé par le handler
|
|
return
|
|
}
|
|
_ = p.repo.UpdateReport(reportID, "error", "", err.Error())
|
|
return
|
|
}
|
|
_ = p.repo.UpdateReport(reportID, "done", answer, "")
|
|
}()
|
|
}
|
|
|
|
func (p *Pipeline) callProviderForReport(ctx context.Context, excerpt, question string) (string, error) {
|
|
provider, _, err := p.buildProviderForRole("report")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
prompt := fmt.Sprintf(
|
|
"Tu es un assistant financier expert. L'utilisateur a sélectionné les extraits suivants d'un résumé de marché :\n\n%s\n\nQuestion de l'utilisateur : %s\n\nRéponds en français, de façon précise et orientée trading.",
|
|
excerpt, question,
|
|
)
|
|
|
|
return provider.Summarize(ctx, prompt, GenOptions{Think: true, NumCtx: 16384})
|
|
}
|
|
|
|
func buildPrompt(systemPrompt string, symbols []string, articles []models.Article, tz string) string {
|
|
var sb strings.Builder
|
|
sb.WriteString(systemPrompt)
|
|
sb.WriteString("\n\n")
|
|
if len(symbols) > 0 {
|
|
sb.WriteString("Le trader surveille particulièrement ces actifs (sois attentif à toute mention) : ")
|
|
sb.WriteString(strings.Join(symbols, ", "))
|
|
sb.WriteString(".\n\n")
|
|
}
|
|
loc, err := time.LoadLocation(tz)
|
|
if err != nil || tz == "" {
|
|
loc = time.UTC
|
|
}
|
|
sb.WriteString(fmt.Sprintf("Date d'analyse : %s\n\n", time.Now().In(loc).Format("02/01/2006 15:04")))
|
|
sb.WriteString("## Actualités\n\n")
|
|
|
|
for i, a := range articles {
|
|
sb.WriteString(fmt.Sprintf("### [%d] %s\n", i+1, a.Title))
|
|
sb.WriteString(fmt.Sprintf("Source : %s\n", a.SourceName))
|
|
if a.PublishedAt.Valid {
|
|
sb.WriteString(fmt.Sprintf("Date : %s\n", a.PublishedAt.Time.Format("02/01/2006 15:04")))
|
|
}
|
|
content := a.Content
|
|
if len(content) > 1000 {
|
|
content = content[:1000] + "..."
|
|
}
|
|
sb.WriteString(content)
|
|
sb.WriteString("\n\n")
|
|
}
|
|
|
|
return sb.String()
|
|
}
|