fix: add /think and /no-think options to the API for running with ollama
This commit is contained in:
@ -28,7 +28,7 @@ func newAnthropic(apiKey, model string) *anthropicProvider {
|
|||||||
|
|
||||||
func (p *anthropicProvider) Name() string { return "anthropic" }
|
func (p *anthropicProvider) Name() string { return "anthropic" }
|
||||||
|
|
||||||
func (p *anthropicProvider) Summarize(ctx context.Context, prompt string) (string, error) {
|
func (p *anthropicProvider) Summarize(ctx context.Context, prompt string, _ GenOptions) (string, error) {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"model": p.model,
|
"model": p.model,
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
|||||||
@ -28,7 +28,7 @@ func newGemini(apiKey, model string) *geminiProvider {
|
|||||||
|
|
||||||
func (p *geminiProvider) Name() string { return "gemini" }
|
func (p *geminiProvider) Name() string { return "gemini" }
|
||||||
|
|
||||||
func (p *geminiProvider) Summarize(ctx context.Context, prompt string) (string, error) {
|
func (p *geminiProvider) Summarize(ctx context.Context, prompt string, _ GenOptions) (string, error) {
|
||||||
url := fmt.Sprintf(
|
url := fmt.Sprintf(
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s",
|
"https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s",
|
||||||
p.model, p.apiKey,
|
p.model, p.apiKey,
|
||||||
|
|||||||
@ -31,13 +31,18 @@ func newOllama(endpoint, model string) *ollamaProvider {
|
|||||||
|
|
||||||
func (p *ollamaProvider) Name() string { return "ollama" }
|
func (p *ollamaProvider) Name() string { return "ollama" }
|
||||||
|
|
||||||
func (p *ollamaProvider) Summarize(ctx context.Context, prompt string) (string, error) {
|
func (p *ollamaProvider) Summarize(ctx context.Context, prompt string, opts GenOptions) (string, error) {
|
||||||
|
numCtx := 32768
|
||||||
|
if opts.NumCtx > 0 {
|
||||||
|
numCtx = opts.NumCtx
|
||||||
|
}
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"model": p.model,
|
"model": p.model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"stream": false,
|
"stream": false,
|
||||||
|
"think": opts.Think,
|
||||||
"options": map[string]interface{}{
|
"options": map[string]interface{}{
|
||||||
"num_ctx": 32768,
|
"num_ctx": numCtx,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
b, _ := json.Marshal(body)
|
b, _ := json.Marshal(body)
|
||||||
|
|||||||
@ -23,7 +23,7 @@ func newOpenAI(apiKey, model string) *openAIProvider {
|
|||||||
|
|
||||||
func (p *openAIProvider) Name() string { return "openai" }
|
func (p *openAIProvider) Name() string { return "openai" }
|
||||||
|
|
||||||
func (p *openAIProvider) Summarize(ctx context.Context, prompt string) (string, error) {
|
func (p *openAIProvider) Summarize(ctx context.Context, prompt string, _ GenOptions) (string, error) {
|
||||||
resp, err := p.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
resp, err := p.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
|
||||||
Model: p.model,
|
Model: p.model,
|
||||||
Messages: []openai.ChatCompletionMessage{
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
|||||||
@ -103,12 +103,15 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
|||||||
maxArticles = 50
|
maxArticles = 50
|
||||||
}
|
}
|
||||||
|
|
||||||
// Passe 1 : filtrage par pertinence sur les titres si trop d'articles
|
// Passe 1 : filtrage par pertinence — seulement si nettement plus d'articles que le max
|
||||||
if len(articles) > maxArticles {
|
if len(articles) > maxArticles*2 {
|
||||||
fmt.Printf("[pipeline] Passe 1 — filtrage : %d articles → sélection des %d plus pertinents…\n", len(articles), maxArticles)
|
fmt.Printf("[pipeline] Passe 1 — filtrage : %d articles → sélection des %d plus pertinents…\n", len(articles), maxArticles)
|
||||||
t1 := time.Now()
|
t1 := time.Now()
|
||||||
articles = p.filterByRelevance(ctx, provider, symbols, articles, maxArticles)
|
articles = p.filterByRelevance(ctx, provider, symbols, articles, maxArticles)
|
||||||
fmt.Printf("[pipeline] Passe 1 — terminée en %s : %d articles retenus\n", time.Since(t1).Round(time.Second), len(articles))
|
fmt.Printf("[pipeline] Passe 1 — terminée en %s : %d articles retenus\n", time.Since(t1).Round(time.Second), len(articles))
|
||||||
|
} else if len(articles) > maxArticles {
|
||||||
|
articles = articles[:maxArticles]
|
||||||
|
fmt.Printf("[pipeline] troncature directe à %d articles (pas assez d'excédent pour justifier un appel IA)\n", maxArticles)
|
||||||
}
|
}
|
||||||
|
|
||||||
systemPrompt, _ := p.repo.GetSetting("ai_system_prompt")
|
systemPrompt, _ := p.repo.GetSetting("ai_system_prompt")
|
||||||
@ -122,7 +125,8 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
|||||||
fmt.Printf("[pipeline] Passe 2 — résumé : génération sur %d articles…\n", len(articles))
|
fmt.Printf("[pipeline] Passe 2 — résumé : génération sur %d articles…\n", len(articles))
|
||||||
t2 := time.Now()
|
t2 := time.Now()
|
||||||
prompt := buildPrompt(systemPrompt, symbols, articles, tz)
|
prompt := buildPrompt(systemPrompt, symbols, articles, tz)
|
||||||
summary, err := provider.Summarize(ctx, prompt)
|
// Passe 2 : think activé pour une meilleure qualité d'analyse
|
||||||
|
summary, err := provider.Summarize(ctx, prompt, GenOptions{Think: true, NumCtx: 32768})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("AI summarize: %w", err)
|
return nil, fmt.Errorf("AI summarize: %w", err)
|
||||||
}
|
}
|
||||||
@ -135,7 +139,8 @@ func (p *Pipeline) GenerateForUser(ctx context.Context, userID string) (*models.
|
|||||||
// en ne lui envoyant que les titres (prompt très court = rapide).
|
// en ne lui envoyant que les titres (prompt très court = rapide).
|
||||||
func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, symbols []string, articles []models.Article, max int) []models.Article {
|
func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, symbols []string, articles []models.Article, max int) []models.Article {
|
||||||
prompt := buildFilterPrompt(symbols, articles, max)
|
prompt := buildFilterPrompt(symbols, articles, max)
|
||||||
response, err := provider.Summarize(ctx, prompt)
|
// Passe 1 : pas de think, contexte réduit (titres seulement = prompt court)
|
||||||
|
response, err := provider.Summarize(ctx, prompt, GenOptions{Think: false, NumCtx: 8192})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[pipeline] Passe 1 — échec (%v), repli sur troncature\n", err)
|
fmt.Printf("[pipeline] Passe 1 — échec (%v), repli sur troncature\n", err)
|
||||||
return articles[:max]
|
return articles[:max]
|
||||||
@ -159,7 +164,6 @@ func (p *Pipeline) filterByRelevance(ctx context.Context, provider Provider, sym
|
|||||||
|
|
||||||
func buildFilterPrompt(symbols []string, articles []models.Article, max int) string {
|
func buildFilterPrompt(symbols []string, articles []models.Article, max int) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("/no_think\n")
|
|
||||||
sb.WriteString("Tu es un assistant de trading financier. ")
|
sb.WriteString("Tu es un assistant de trading financier. ")
|
||||||
sb.WriteString(fmt.Sprintf("Parmi les %d articles ci-dessous, sélectionne les %d plus pertinents pour un trader actif.\n", len(articles), max))
|
sb.WriteString(fmt.Sprintf("Parmi les %d articles ci-dessous, sélectionne les %d plus pertinents pour un trader actif.\n", len(articles), max))
|
||||||
|
|
||||||
@ -265,7 +269,7 @@ func (p *Pipeline) callProviderForReport(ctx context.Context, excerpt, question
|
|||||||
excerpt, question,
|
excerpt, question,
|
||||||
)
|
)
|
||||||
|
|
||||||
return provider.Summarize(ctx, prompt)
|
return provider.Summarize(ctx, prompt, GenOptions{Think: true, NumCtx: 16384})
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildPrompt(systemPrompt string, symbols []string, articles []models.Article, tz string) string {
|
func buildPrompt(systemPrompt string, symbols []string, articles []models.Article, tz string) string {
|
||||||
@ -282,7 +286,6 @@ func buildPrompt(systemPrompt string, symbols []string, articles []models.Articl
|
|||||||
loc = time.UTC
|
loc = time.UTC
|
||||||
}
|
}
|
||||||
sb.WriteString(fmt.Sprintf("Date d'analyse : %s\n\n", time.Now().In(loc).Format("02/01/2006 15:04")))
|
sb.WriteString(fmt.Sprintf("Date d'analyse : %s\n\n", time.Now().In(loc).Format("02/01/2006 15:04")))
|
||||||
sb.WriteString("/think\n\n")
|
|
||||||
sb.WriteString("## Actualités\n\n")
|
sb.WriteString("## Actualités\n\n")
|
||||||
|
|
||||||
for i, a := range articles {
|
for i, a := range articles {
|
||||||
|
|||||||
@ -5,9 +5,15 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GenOptions permet de contrôler le comportement de génération par appel.
|
||||||
|
type GenOptions struct {
|
||||||
|
Think bool // active le mode raisonnement (Qwen3 /think)
|
||||||
|
NumCtx int // taille du contexte KV (0 = valeur par défaut du provider)
|
||||||
|
}
|
||||||
|
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Name() string
|
Name() string
|
||||||
Summarize(ctx context.Context, prompt string) (string, error)
|
Summarize(ctx context.Context, prompt string, opts GenOptions) (string, error)
|
||||||
ListModels(ctx context.Context) ([]string, error)
|
ListModels(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user