161 lines
4.7 KiB
Go
161 lines
4.7 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/tradarr/backend/internal/crypto"
|
|
"github.com/tradarr/backend/internal/models"
|
|
)
|
|
|
|
const DefaultSystemPrompt = `Tu es un assistant spécialisé en trading financier. Analyse l'ensemble des actualités suivantes, toutes sources confondues, et crée un résumé global structuré en français, orienté trading.
|
|
|
|
Structure ton résumé ainsi :
|
|
1. **Vue macro** : tendances globales du marché (économie, géopolitique, secteurs)
|
|
2. **Actifs surveillés** : pour chaque actif de la watchlist mentionné dans les news :
|
|
- Sentiment (haussier/baissier/neutre)
|
|
- Faits clés et catalyseurs
|
|
- Risques et opportunités
|
|
3. **Autres mouvements notables** : actifs hors watchlist à surveiller
|
|
4. **Synthèse** : points d'attention prioritaires pour la journée`
|
|
|
|
type Pipeline struct {
|
|
repo *models.Repository
|
|
enc *crypto.Encryptor
|
|
}
|
|
|
|
func NewPipeline(repo *models.Repository, enc *crypto.Encryptor) *Pipeline {
|
|
return &Pipeline{repo: repo, enc: enc}
|
|
}
|
|
|
|
// BuildProvider instancie un provider à partir de ses paramètres
|
|
func (p *Pipeline) BuildProvider(name, apiKey, endpoint string) (Provider, error) {
|
|
provider, err := p.repo.GetActiveAIProvider()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
model := ""
|
|
if provider != nil {
|
|
model = provider.Model
|
|
}
|
|
return NewProvider(name, apiKey, model, endpoint)
|
|
}
|
|
|
|
// GenerateForUser génère un résumé personnalisé pour un utilisateur
|
|
func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.Summary, error) {
|
|
// Récupérer le provider actif
|
|
providerCfg, err := p.repo.GetActiveAIProvider()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get active provider: %w", err)
|
|
}
|
|
if providerCfg == nil {
|
|
return nil, fmt.Errorf("no active AI provider configured")
|
|
}
|
|
|
|
apiKey := ""
|
|
if providerCfg.APIKeyEncrypted != "" {
|
|
apiKey, err = p.enc.Decrypt(providerCfg.APIKeyEncrypted)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decrypt API key: %w", err)
|
|
}
|
|
}
|
|
|
|
provider, err := NewProvider(providerCfg.Name, apiKey, providerCfg.Model, providerCfg.Endpoint)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build provider: %w", err)
|
|
}
|
|
|
|
// Récupérer la watchlist de l'utilisateur (pour le contexte IA uniquement)
|
|
assets, err := p.repo.GetUserAssets(userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user assets: %w", err)
|
|
}
|
|
symbols := make([]string, len(assets))
|
|
for i, a := range assets {
|
|
symbols[i] = a.Symbol
|
|
}
|
|
|
|
// Récupérer TOUS les articles récents, toutes sources confondues
|
|
hoursStr, _ := p.repo.GetSetting("articles_lookback_hours")
|
|
hours, _ := strconv.Atoi(hoursStr)
|
|
if hours == 0 {
|
|
hours = 24
|
|
}
|
|
|
|
articles, err := p.repo.GetRecentArticles(hours)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get articles: %w", err)
|
|
}
|
|
if len(articles) == 0 {
|
|
return nil, fmt.Errorf("no recent articles found")
|
|
}
|
|
|
|
maxStr, _ := p.repo.GetSetting("summary_max_articles")
|
|
maxArticles, _ := strconv.Atoi(maxStr)
|
|
if maxArticles == 0 {
|
|
maxArticles = 50
|
|
}
|
|
if len(articles) > maxArticles {
|
|
articles = articles[:maxArticles]
|
|
}
|
|
|
|
systemPrompt, _ := p.repo.GetSetting("ai_system_prompt")
|
|
if systemPrompt == "" {
|
|
systemPrompt = DefaultSystemPrompt
|
|
}
|
|
prompt := buildPrompt(systemPrompt, symbols, articles)
|
|
|
|
summary, err := provider.Summarize(ctx, prompt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("AI summarize: %w", err)
|
|
}
|
|
|
|
return p.repo.CreateSummary(userID, summary, &providerCfg.ID)
|
|
}
|
|
|
|
// GenerateForAll génère les résumés pour tous les utilisateurs ayant une watchlist
|
|
func (p *Pipeline) GenerateForAll(ctx context.Context) error {
|
|
users, err := p.repo.ListUsers()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, user := range users {
|
|
if _, err := p.GenerateForUser(ctx, user.ID); err != nil {
|
|
fmt.Printf("summary for user %s: %v\n", user.Email, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func buildPrompt(systemPrompt string, symbols []string, articles []models.Article) string {
|
|
var sb strings.Builder
|
|
sb.WriteString(systemPrompt)
|
|
sb.WriteString("\n\n")
|
|
if len(symbols) > 0 {
|
|
sb.WriteString("Le trader surveille particulièrement ces actifs (sois attentif à toute mention) : ")
|
|
sb.WriteString(strings.Join(symbols, ", "))
|
|
sb.WriteString(".\n\n")
|
|
}
|
|
sb.WriteString(fmt.Sprintf("Date d'analyse : %s\n\n", time.Now().Format("02/01/2006 15:04")))
|
|
sb.WriteString("## Actualités\n\n")
|
|
|
|
for i, a := range articles {
|
|
sb.WriteString(fmt.Sprintf("### [%d] %s\n", i+1, a.Title))
|
|
sb.WriteString(fmt.Sprintf("Source : %s\n", a.SourceName))
|
|
if a.PublishedAt.Valid {
|
|
sb.WriteString(fmt.Sprintf("Date : %s\n", a.PublishedAt.Time.Format("02/01/2006 15:04")))
|
|
}
|
|
content := a.Content
|
|
if len(content) > 1000 {
|
|
content = content[:1000] + "..."
|
|
}
|
|
sb.WriteString(content)
|
|
sb.WriteString("\n\n")
|
|
}
|
|
|
|
return sb.String()
|
|
}
|