feat: add feature to speak with the AI and create report from contexts
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tradarr/backend/internal/crypto"
|
||||
@ -24,14 +25,19 @@ Structure ton résumé ainsi :
|
||||
4. **Synthèse** : points d'attention prioritaires pour la journée`
|
||||
|
||||
type Pipeline struct {
|
||||
repo *models.Repository
|
||||
enc *crypto.Encryptor
|
||||
repo *models.Repository
|
||||
enc *crypto.Encryptor
|
||||
generating atomic.Bool
|
||||
}
|
||||
|
||||
func NewPipeline(repo *models.Repository, enc *crypto.Encryptor) *Pipeline {
|
||||
return &Pipeline{repo: repo, enc: enc}
|
||||
}
|
||||
|
||||
func (p *Pipeline) IsGenerating() bool {
|
||||
return p.generating.Load()
|
||||
}
|
||||
|
||||
func (p *Pipeline) BuildProvider(name, apiKey, endpoint string) (Provider, error) {
|
||||
provider, err := p.repo.GetActiveAIProvider()
|
||||
if err != nil {
|
||||
@ -45,6 +51,8 @@ func (p *Pipeline) BuildProvider(name, apiKey, endpoint string) (Provider, error
|
||||
}
|
||||
|
||||
func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.Summary, error) {
|
||||
p.generating.Store(true)
|
||||
defer p.generating.Store(false)
|
||||
providerCfg, err := p.repo.GetActiveAIProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get active provider: %w", err)
|
||||
@ -97,9 +105,10 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
||||
|
||||
// Passe 1 : filtrage par pertinence sur les titres si trop d'articles
|
||||
if len(articles) > maxArticles {
|
||||
fmt.Printf("pipeline: %d articles → filtering to %d via AI\n", len(articles), maxArticles)
|
||||
fmt.Printf("[pipeline] Passe 1 — filtrage : %d articles → sélection des %d plus pertinents…\n", len(articles), maxArticles)
|
||||
t1 := time.Now()
|
||||
articles = p.filterByRelevance(ctx, provider, symbols, articles, maxArticles)
|
||||
fmt.Printf("pipeline: %d articles retained after filtering\n", len(articles))
|
||||
fmt.Printf("[pipeline] Passe 1 — terminée en %s : %d articles retenus\n", time.Since(t1).Round(time.Second), len(articles))
|
||||
}
|
||||
|
||||
systemPrompt, _ := p.repo.GetSetting("ai_system_prompt")
|
||||
@ -108,11 +117,14 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
||||
}
|
||||
|
||||
// Passe 2 : résumé complet
|
||||
fmt.Printf("[pipeline] Passe 2 — résumé : génération sur %d articles…\n", len(articles))
|
||||
t2 := time.Now()
|
||||
prompt := buildPrompt(systemPrompt, symbols, articles)
|
||||
summary, err := provider.Summarize(ctx, prompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("AI summarize: %w", err)
|
||||
}
|
||||
fmt.Printf("[pipeline] Passe 2 — terminée en %s\n", time.Since(t2).Round(time.Second))
|
||||
|
||||
return p.repo.CreateSummary(userID, summary, &providerCfg.ID)
|
||||
}
|
||||
@ -123,13 +135,13 @@ func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, sym
|
||||
prompt := buildFilterPrompt(symbols, articles, max)
|
||||
response, err := provider.Summarize(ctx, prompt)
|
||||
if err != nil {
|
||||
fmt.Printf("pipeline: filter AI call failed (%v), falling back to truncation\n", err)
|
||||
fmt.Printf("[pipeline] Passe 1 — échec (%v), repli sur troncature\n", err)
|
||||
return articles[:max]
|
||||
}
|
||||
|
||||
indices := parseIndexArray(response, len(articles))
|
||||
if len(indices) == 0 {
|
||||
fmt.Printf("pipeline: could not parse filter response, falling back to truncation\n")
|
||||
fmt.Printf("[pipeline] Passe 1 — réponse non parseable, repli sur troncature\n")
|
||||
return articles[:max]
|
||||
}
|
||||
|
||||
@ -201,6 +213,58 @@ func (p *Pipeline) GenerateForAll(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateReportAsync crée le rapport en DB (status=generating) et lance la génération en arrière-plan.
|
||||
func (p *Pipeline) GenerateReportAsync(reportID, excerpt, question string, mgr *ReportManager) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
mgr.Register(reportID, cancel)
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
defer mgr.Remove(reportID)
|
||||
|
||||
answer, err := p.callProviderForReport(ctx, excerpt, question)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
// annulé volontairement — le rapport est supprimé par le handler
|
||||
return
|
||||
}
|
||||
_ = p.repo.UpdateReport(reportID, "error", "", err.Error())
|
||||
return
|
||||
}
|
||||
_ = p.repo.UpdateReport(reportID, "done", answer, "")
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Pipeline) callProviderForReport(ctx context.Context, excerpt, question string) (string, error) {
|
||||
providerCfg, err := p.repo.GetActiveAIProvider()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get active provider: %w", err)
|
||||
}
|
||||
if providerCfg == nil {
|
||||
return "", fmt.Errorf("no active AI provider configured")
|
||||
}
|
||||
|
||||
apiKey := ""
|
||||
if providerCfg.APIKeyEncrypted != "" {
|
||||
apiKey, err = p.enc.Decrypt(providerCfg.APIKeyEncrypted)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt API key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
provider, err := NewProvider(providerCfg.Name, apiKey, providerCfg.Model, providerCfg.Endpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("build provider: %w", err)
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(
|
||||
"Tu es un assistant financier expert. L'utilisateur a sélectionné les extraits suivants d'un résumé de marché :\n\n%s\n\nQuestion de l'utilisateur : %s\n\nRéponds en français, de façon précise et orientée trading.",
|
||||
excerpt, question,
|
||||
)
|
||||
|
||||
return provider.Summarize(ctx, prompt)
|
||||
}
|
||||
|
||||
func buildPrompt(systemPrompt string, symbols []string, articles []models.Article) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(systemPrompt)
|
||||
|
||||
37
backend/internal/ai/report_manager.go
Normal file
37
backend/internal/ai/report_manager.go
Normal file
@ -0,0 +1,37 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ReportManager tracks in-flight report goroutines so they can be cancelled.
|
||||
type ReportManager struct {
|
||||
mu sync.Mutex
|
||||
cancels map[string]context.CancelFunc
|
||||
}
|
||||
|
||||
func NewReportManager() *ReportManager {
|
||||
return &ReportManager{cancels: make(map[string]context.CancelFunc)}
|
||||
}
|
||||
|
||||
func (m *ReportManager) Register(id string, cancel context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancels[id] = cancel
|
||||
}
|
||||
|
||||
func (m *ReportManager) Cancel(id string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cancel, ok := m.cancels[id]; ok {
|
||||
cancel()
|
||||
delete(m.cancels, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ReportManager) Remove(id string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancels, id)
|
||||
}
|
||||
@ -10,12 +10,13 @@ import (
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
repo *models.Repository
|
||||
cfg *config.Config
|
||||
enc *crypto.Encryptor
|
||||
registry *scraper.Registry
|
||||
pipeline *ai.Pipeline
|
||||
scheduler *scheduler.Scheduler
|
||||
repo *models.Repository
|
||||
cfg *config.Config
|
||||
enc *crypto.Encryptor
|
||||
registry *scraper.Registry
|
||||
pipeline *ai.Pipeline
|
||||
scheduler *scheduler.Scheduler
|
||||
reportManager *ai.ReportManager
|
||||
}
|
||||
|
||||
func New(
|
||||
@ -27,11 +28,12 @@ func New(
|
||||
sched *scheduler.Scheduler,
|
||||
) *Handler {
|
||||
return &Handler{
|
||||
repo: repo,
|
||||
cfg: cfg,
|
||||
enc: enc,
|
||||
registry: registry,
|
||||
pipeline: pipeline,
|
||||
scheduler: sched,
|
||||
repo: repo,
|
||||
cfg: cfg,
|
||||
enc: enc,
|
||||
registry: registry,
|
||||
pipeline: pipeline,
|
||||
scheduler: sched,
|
||||
reportManager: ai.NewReportManager(),
|
||||
}
|
||||
}
|
||||
|
||||
84
backend/internal/api/handlers/reports.go
Normal file
84
backend/internal/api/handlers/reports.go
Normal file
@ -0,0 +1,84 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tradarr/backend/internal/httputil"
|
||||
)
|
||||
|
||||
type reportRequest struct {
|
||||
SummaryID string `json:"summary_id"`
|
||||
Excerpts []string `json:"excerpts" binding:"required,min=1"`
|
||||
Question string `json:"question" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *Handler) CreateReport(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
var req reportRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
httputil.BadRequest(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var summaryID *string
|
||||
if req.SummaryID != "" {
|
||||
summaryID = &req.SummaryID
|
||||
}
|
||||
|
||||
// Joindre les extraits avec un séparateur visuel
|
||||
excerpt := buildExcerptContext(req.Excerpts)
|
||||
|
||||
// Créer le rapport en DB avec status=generating, retourner immédiatement
|
||||
report, err := h.repo.CreatePendingReport(userID, summaryID, excerpt, req.Question)
|
||||
if err != nil {
|
||||
httputil.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Lancer la génération en arrière-plan
|
||||
h.pipeline.GenerateReportAsync(report.ID, excerpt, req.Question, h.reportManager)
|
||||
|
||||
c.JSON(http.StatusCreated, report)
|
||||
}
|
||||
|
||||
func (h *Handler) ListReports(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
reports, err := h.repo.ListReports(userID)
|
||||
if err != nil {
|
||||
httputil.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
httputil.OK(c, reports)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteReport(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
id := c.Param("id")
|
||||
// Annuler la goroutine si elle tourne encore
|
||||
h.reportManager.Cancel(id)
|
||||
if err := h.repo.DeleteReport(id, userID); err != nil {
|
||||
httputil.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *Handler) GetGeneratingStatus(c *gin.Context) {
|
||||
httputil.OK(c, gin.H{"generating": h.pipeline.IsGenerating()})
|
||||
}
|
||||
|
||||
func buildExcerptContext(excerpts []string) string {
|
||||
if len(excerpts) == 1 {
|
||||
return excerpts[0]
|
||||
}
|
||||
var sb strings.Builder
|
||||
for i, e := range excerpts {
|
||||
if i > 0 {
|
||||
sb.WriteString("\n\n---\n\n")
|
||||
}
|
||||
sb.WriteString(e)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@ -7,7 +7,11 @@ import (
|
||||
)
|
||||
|
||||
func SetupRouter(h *handlers.Handler, jwtSecret string) *gin.Engine {
|
||||
r := gin.Default()
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(gin.LoggerWithConfig(gin.LoggerConfig{
|
||||
SkipPaths: []string{"/api/summaries/status"},
|
||||
}))
|
||||
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
@ -39,8 +43,13 @@ func SetupRouter(h *handlers.Handler, jwtSecret string) *gin.Engine {
|
||||
authed.GET("/articles/:id", h.GetArticle)
|
||||
|
||||
authed.GET("/summaries", h.ListSummaries)
|
||||
authed.GET("/summaries/status", h.GetGeneratingStatus)
|
||||
authed.POST("/summaries/generate", h.GenerateSummary)
|
||||
|
||||
authed.GET("/reports", h.ListReports)
|
||||
authed.POST("/reports", h.CreateReport)
|
||||
authed.DELETE("/reports/:id", h.DeleteReport)
|
||||
|
||||
// Admin
|
||||
admin := authed.Group("/admin")
|
||||
admin.Use(auth.AdminOnly())
|
||||
|
||||
@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS reports;
|
||||
@ -0,0 +1,9 @@
|
||||
CREATE TABLE reports (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
summary_id UUID REFERENCES summaries(id) ON DELETE SET NULL,
|
||||
context_excerpt TEXT NOT NULL,
|
||||
question TEXT NOT NULL,
|
||||
answer TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
@ -0,0 +1,4 @@
|
||||
ALTER TABLE reports
|
||||
DROP COLUMN IF EXISTS status,
|
||||
DROP COLUMN IF EXISTS error_msg,
|
||||
ALTER COLUMN answer DROP DEFAULT;
|
||||
@ -0,0 +1,4 @@
|
||||
ALTER TABLE reports
|
||||
ALTER COLUMN answer SET DEFAULT '',
|
||||
ADD COLUMN status VARCHAR(20) NOT NULL DEFAULT 'done',
|
||||
ADD COLUMN error_msg TEXT NOT NULL DEFAULT '';
|
||||
@ -104,3 +104,15 @@ type ScheduleSlot struct {
|
||||
Hour int `json:"hour"`
|
||||
Minute int `json:"minute"`
|
||||
}
|
||||
|
||||
type Report struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
SummaryID *string `json:"summary_id"`
|
||||
ContextExcerpt string `json:"context_excerpt"`
|
||||
Question string `json:"question"`
|
||||
Answer string `json:"answer"`
|
||||
Status string `json:"status"` // generating | done | error
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
@ -187,20 +187,28 @@ func (r *Repository) UpdateSource(id string, enabled bool) error {
|
||||
|
||||
// ── Articles ───────────────────────────────────────────────────────────────
|
||||
|
||||
func (r *Repository) UpsertArticle(sourceID, title, content, url string, publishedAt *time.Time) (*Article, error) {
|
||||
a := &Article{}
|
||||
// InsertArticleIfNew insère l'article uniquement s'il n'existe pas déjà (par URL).
|
||||
// Retourne (article, true, nil) si inséré, (nil, false, nil) si déjà présent.
|
||||
func (r *Repository) InsertArticleIfNew(sourceID, title, content, url string, publishedAt *time.Time) (*Article, bool, error) {
|
||||
var pa sql.NullTime
|
||||
if publishedAt != nil {
|
||||
pa = sql.NullTime{Time: *publishedAt, Valid: true}
|
||||
}
|
||||
a := &Article{}
|
||||
err := r.db.QueryRow(`
|
||||
INSERT INTO articles (source_id, title, content, url, published_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (url) DO UPDATE SET title=EXCLUDED.title, content=EXCLUDED.content
|
||||
ON CONFLICT (url) DO NOTHING
|
||||
RETURNING id, source_id, title, content, url, published_at, created_at`,
|
||||
sourceID, title, content, url, pa,
|
||||
).Scan(&a.ID, &a.SourceID, &a.Title, &a.Content, &a.URL, &a.PublishedAt, &a.CreatedAt)
|
||||
return a, err
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, false, nil // déjà présent
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return a, true, nil
|
||||
}
|
||||
|
||||
func (r *Repository) AddArticleSymbol(articleID, symbol string) error {
|
||||
@ -260,7 +268,7 @@ func (r *Repository) GetRecentArticles(hours int) ([]Article, error) {
|
||||
SELECT a.id, a.source_id, s.name, a.title, a.content, a.url, a.published_at, a.created_at
|
||||
FROM articles a
|
||||
JOIN sources s ON s.id = a.source_id
|
||||
WHERE a.created_at > NOW() - ($1 * INTERVAL '1 hour')
|
||||
WHERE COALESCE(a.published_at, a.created_at) > NOW() - ($1 * INTERVAL '1 hour')
|
||||
ORDER BY a.published_at DESC NULLS LAST, a.created_at DESC`, hours)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -581,3 +589,48 @@ func (r *Repository) ListSettings() ([]Setting, error) {
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// ── Reports ────────────────────────────────────────────────────────────────
|
||||
|
||||
func (r *Repository) CreatePendingReport(userID string, summaryID *string, excerpt, question string) (*Report, error) {
|
||||
rep := &Report{}
|
||||
err := r.db.QueryRow(`
|
||||
INSERT INTO reports (user_id, summary_id, context_excerpt, question, answer, status)
|
||||
VALUES ($1, $2, $3, $4, '', 'generating')
|
||||
RETURNING id, user_id, summary_id, context_excerpt, question, answer, status, error_msg, created_at`,
|
||||
userID, summaryID, excerpt, question,
|
||||
).Scan(&rep.ID, &rep.UserID, &rep.SummaryID, &rep.ContextExcerpt, &rep.Question, &rep.Answer, &rep.Status, &rep.ErrorMsg, &rep.CreatedAt)
|
||||
return rep, err
|
||||
}
|
||||
|
||||
func (r *Repository) UpdateReport(id, status, answer, errorMsg string) error {
|
||||
_, err := r.db.Exec(`
|
||||
UPDATE reports SET status=$1, answer=$2, error_msg=$3 WHERE id=$4`,
|
||||
status, answer, errorMsg, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Repository) ListReports(userID string) ([]Report, error) {
|
||||
rows, err := r.db.Query(`
|
||||
SELECT id, user_id, summary_id, context_excerpt, question, answer, status, error_msg, created_at
|
||||
FROM reports WHERE user_id=$1
|
||||
ORDER BY created_at DESC`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var reports []Report
|
||||
for rows.Next() {
|
||||
var rep Report
|
||||
if err := rows.Scan(&rep.ID, &rep.UserID, &rep.SummaryID, &rep.ContextExcerpt, &rep.Question, &rep.Answer, &rep.Status, &rep.ErrorMsg, &rep.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reports = append(reports, rep)
|
||||
}
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
func (r *Repository) DeleteReport(id, userID string) error {
|
||||
_, err := r.db.Exec(`DELETE FROM reports WHERE id=$1 AND user_id=$2`, id, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -72,11 +72,11 @@ func (r *Registry) Run(sourceID string) error {
|
||||
return scrapeErr
|
||||
}
|
||||
|
||||
// Persister les articles
|
||||
// Persister uniquement les nouveaux articles
|
||||
count := 0
|
||||
for _, a := range articles {
|
||||
saved, err := r.repo.UpsertArticle(sourceID, a.Title, a.Content, a.URL, a.PublishedAt)
|
||||
if err != nil {
|
||||
saved, isNew, err := r.repo.InsertArticleIfNew(sourceID, a.Title, a.Content, a.URL, a.PublishedAt)
|
||||
if err != nil || !isNew {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
|
||||
Reference in New Issue
Block a user