fix: fix gemini models list
This commit is contained in:
@ -74,11 +74,43 @@ func (p *geminiProvider) Summarize(ctx context.Context, prompt string, _ GenOpti
|
||||
return result.Candidates[0].Content.Parts[0].Text, nil
|
||||
}
|
||||
|
||||
func (p *geminiProvider) ListModels(_ context.Context) ([]string, error) {
|
||||
return []string{
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
}, nil
|
||||
func (p *geminiProvider) ListModels(ctx context.Context) ([]string, error) {
|
||||
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models?key=%s", p.apiKey)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("gemini list models error %d: %s", resp.StatusCode, raw)
|
||||
}
|
||||
var result struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
SupportedMethods []string `json:"supportedGenerationMethods"`
|
||||
} `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var names []string
|
||||
for _, m := range result.Models {
|
||||
for _, method := range m.SupportedMethods {
|
||||
if method == "generateContent" {
|
||||
// name is "models/gemini-xxx", strip prefix
|
||||
id := m.Name
|
||||
if len(id) > 7 {
|
||||
id = id[7:] // strip "models/"
|
||||
}
|
||||
names = append(names, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user