自动获取 model 列表

This commit is contained in:
ken 2025-03-30 13:42:37 +08:00
parent 456498d5ce
commit e60fa21a0f
7 changed files with 103 additions and 29 deletions

View File

@ -2,6 +2,7 @@ package agent
import (
"agent/src/common"
"agent/src/llm"
"agent/src/utils/log"
"agent/src/workflow"
"context"
@ -19,11 +20,18 @@ type Agent struct {
}
var once sync.Once
func NewAgent(ctx context.Context) *Agent {
func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
if ollamaUrls == nil {
ollamaUrls = []string{llm.DefaultOllamaUrl}
}
err := llm.InitLocalModel(ollamaUrls)
if err != nil {
return nil, err
}
return &Agent{
ctx : ctx,
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
}
}, nil
}
func (a *Agent) Start(port uint64, crtDir *string) {

View File

@ -7,7 +7,7 @@ import (
func TestAgent(t *testing.T) {
ctx := context.Background()
a := NewAgent(ctx)
a, _ := NewAgent(ctx, nil)
// crt := "../../config"
a.Start(8080, nil)
<-make(chan bool)

View File

@ -1,7 +1,9 @@
package llm
import (
"agent/src/utils/log"
"context"
"strings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
@ -10,33 +12,51 @@ import (
var LLMGroups map[string]*LLMGroup
var LLMNames []string
// const ollamaurl = "http://192.168.1.5:11434"
const ollamaurl = "http://ishangsf.com:11434"
// const DefaultOllamaUrl = "http://192.168.1.5:11434"
const DefaultOllamaUrl = "http://ishangsf.com:11434"
func init() {
models := []string{"qwen2.5-coder:7b", "deepseek-coder-v2:16b"}
LLMNames = make([]string, len(models))
LLMGroups = make(map[string]*LLMGroup, len(models))
for i, model := range models {
llm, err := ollama.New(ollama.WithServerURL(ollamaurl), ollama.WithModel(model))
func InitLocalModel(urls []string) error {
model2url := make(map[string][]string)
for _, url := range urls {
models, err := listOllamaModels(url)
if err != nil {
panic(err)
return err
}
for _, m := range models {
model2url[m] = append(model2url[m], url)
}
}
LLMNames = make([]string, 0, len(model2url))
LLMGroups = make(map[string]*LLMGroup, len(model2url))
for model, urls := range model2url {
model = "local " + model
g := &LLMGroup{
name: model,
llm: make(chan llms.Model, 1),
llm: make(chan llms.Model, len(urls)),
local: true,
}
for _, url := range urls {
llm, err := ollama.New(ollama.WithServerURL(url), ollama.WithModel(model))
if err != nil {
return err
}
g.llm <- llm
}
LLMGroups[model] = g
LLMNames[i] = model
LLMNames = append(LLMNames, model)
}
log.Info("load models", "models", strings.Join(LLMNames, ","))
return nil
}
/*
多个LLM组成一个集群
*/
type LLMGroup struct {
name string
llm chan llms.Model

View File

@ -1,6 +1,7 @@
package llm
import (
"agent/src/utils"
"context"
"fmt"
"testing"
@ -27,3 +28,10 @@ func TestOllama(t *testing.T) {
}))
}
// TestInitLocalModel runs InitLocalModel against a single Ollama endpoint and
// hands the resulting LLMNames to utils.PrintJson for inspection.
//
// InitLocalModel queries the given URL over HTTP, so this test needs that
// host to be reachable.
func TestInitLocalModel(t *testing.T) {
	if err := InitLocalModel([]string{"http://ishangsf.com:11434"}); err != nil {
		// Fail through the testing framework rather than panicking, so the
		// failure is attributed to this test and other tests still run.
		t.Fatalf("InitLocalModel: %v", err)
	}
	utils.PrintJson(LLMNames)
}

39
src/llm/ollama.go Normal file
View File

@ -0,0 +1,39 @@
package llm
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
// ollamaListModels is the subset of the `GET /api/tags` response body that
// listOllamaModels decodes: the array of models reported by the server.
type ollamaListModels struct {
	Models []ollamaModel `json:"models"`
}

// ollamaModel is a single entry of the "models" array. Only the model
// identifier is decoded; any other fields in the response are ignored.
type ollamaModel struct {
	Model string `json:"model"`
}

// listOllamaModels asks the Ollama server at url for its model list
// (GET <url>/api/tags) and returns the model identifiers in response order.
//
// url is the server base URL, e.g. "http://host:11434"; trailing slashes are
// tolerated. The returned slice is non-nil (possibly empty) on success.
//
// An error is returned when the request fails or times out, when the server
// answers with a status other than 200 OK, or when the body is not valid JSON.
func listOllamaModels(url string) ([]string, error) {
	// Bound the request so an unresponsive server cannot block the caller
	// indefinitely (http.Get uses a client without any timeout).
	const requestTimeout = 10 * time.Second

	// Avoid producing "//api/tags" when the base URL ends with a slash.
	endpoint := strings.TrimRight(url, "/") + "/api/tags"

	client := &http.Client{Timeout: requestTimeout}
	res, err := client.Get(endpoint)
	if err != nil {
		return nil, fmt.Errorf("list ollama models from %q: %w", url, err)
	}
	// The body must always be closed, otherwise the connection leaks.
	defer res.Body.Close()

	// Do not try to interpret an error page as a model list.
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("list ollama models from %q: unexpected status %s", url, res.Status)
	}

	resBody, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("list ollama models from %q: read response: %w", url, err)
	}

	var list ollamaListModels
	if err := json.Unmarshal(resBody, &list); err != nil {
		return nil, fmt.Errorf("list ollama models from %q: decode response: %w", url, err)
	}

	models := make([]string, 0, len(list.Models))
	for _, entry := range list.Models {
		models = append(models, entry.Model)
	}
	return models, nil
}

View File

@ -2,9 +2,9 @@ package main
import (
"agent/src/agent"
"agent/src/llm"
"agent/src/utils/log"
"context"
"fmt"
"os"
"os/signal"
"syscall"
@ -18,8 +18,9 @@ var (
main_data = struct {
port uint64
crtDir string
serveDocs string
ollamaUrls []string
saveRequestTo string
}{}
)
@ -29,6 +30,7 @@ func init() {
f := cmd_run.Flags()
f.Uint64Var(&main_data.port, "port", 8080, "http port")
f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)")
f.StringSliceVar(&main_data.ollamaUrls, "ollama_urls", []string{llm.DefaultOllamaUrl}, "ollama urls")
// TODO
// f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)")
@ -44,18 +46,15 @@ var cmd_run = &cobra.Command{
Short: "run agent service",
Example: "agent run",
RunE: func(cmd *cobra.Command, args []string) error {
var saveRequestTo, crtDir *string
if len(main_data.saveRequestTo) > 0 {
saveRequestTo = &main_data.saveRequestTo
}
var crtDir *string
if len(main_data.crtDir) > 0 {
crtDir = &main_data.crtDir
}
fmt.Println(main_data.saveRequestTo)
a := agent.NewAgent(cmd.Context())
a, err := agent.NewAgent(cmd.Context(), main_data.ollamaUrls)
if err != nil {
return err
}
a.Start(main_data.port, crtDir)
_ = saveRequestTo
return nil
},
}

View File

@ -20,7 +20,7 @@ type TaskPool struct {
ctx context.Context
queue *taskQueue
llmGroup *llm.LLMGroup
sem chan struct{} // task并发数量控制
sem chan struct{} // 控制task并发数量
}
func InitTaskPool(ctx context.Context) {
@ -90,7 +90,7 @@ func (t *TaskPool) loop() {
case <-task.ctx.Done():
task.pool.remove(task.client)
case t.sem <- struct{}{}:
task.do()
go task.do()
}
}