diff --git a/src/agent/agent.go b/src/agent/agent.go
index 8bdf8fa..00ec4a7 100644
--- a/src/agent/agent.go
+++ b/src/agent/agent.go
@@ -2,6 +2,7 @@ package agent
 
 import (
 	"agent/src/common"
+	"agent/src/llm"
 	"agent/src/utils/log"
 	"agent/src/workflow"
 	"context"
@@ -19,11 +20,18 @@ type Agent struct {
 }
 
 var once sync.Once
-func NewAgent(ctx context.Context) *Agent {
+func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
+	if ollamaUrls == nil {
+		ollamaUrls = []string{llm.DefaultOllamaUrl}
+	}
+	err := llm.InitLocalModel(ollamaUrls)
+	if err != nil {
+		return nil, err
+	}
 	return &Agent{
 		ctx : ctx,
 		clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
-	}
+	}, nil
 }
 
 func (a *Agent) Start(port uint64, crtDir *string) {
diff --git a/src/agent/agent_test.go b/src/agent/agent_test.go
index 277f331..07b582d 100644
--- a/src/agent/agent_test.go
+++ b/src/agent/agent_test.go
@@ -7,7 +7,7 @@ import (
 
 func TestAgent(t *testing.T) {
 	ctx := context.Background()
-	a := NewAgent(ctx)
+	a, _ := NewAgent(ctx, nil)
 	// crt := "../../config"
 	a.Start(8080, nil)
 	<-make(chan bool)
diff --git a/src/llm/llm.go b/src/llm/llm.go
index c0ee71d..91531f1 100644
--- a/src/llm/llm.go
+++ b/src/llm/llm.go
@@ -1,7 +1,9 @@
 package llm
 
 import (
+	"agent/src/utils/log"
 	"context"
+	"strings"
 
 	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/llms/ollama"
@@ -10,33 +12,51 @@
 var LLMGroups map[string]*LLMGroup
 var LLMNames []string
 
-// const ollamaurl = "http://192.168.1.5:11434"
-const ollamaurl = "http://ishangsf.com:11434"
+// const DefaultOllamaUrl = "http://192.168.1.5:11434"
+const DefaultOllamaUrl = "http://ishangsf.com:11434"
 
-func init() {
-	models := []string{"qwen2.5-coder:7b", "deepseek-coder-v2:16b"}
-	LLMNames = make([]string, len(models))
-	LLMGroups = make(map[string]*LLMGroup, len(models))
-
-	for i, model := range models {
-		llm, err := ollama.New(ollama.WithServerURL(ollamaurl), ollama.WithModel(model))
+func InitLocalModel(urls []string) error {
+	model2url := make(map[string][]string)
+	for _, url := range urls {
+		models, err := listOllamaModels(url)
 		if err != nil {
-			panic(err)
+			return err
 		}
+		for _, m := range models {
+			model2url[m] = append(model2url[m], url)
+		}
+	}
+
+	LLMNames = make([]string, 0, len(model2url))
+	LLMGroups = make(map[string]*LLMGroup, len(model2url))
+
+	for model, urls := range model2url {
-		model = "local " + model
+		name := "local " + model // display prefix only; ollama.WithModel below needs the raw model id
 		g := &LLMGroup{
-			name: model,
-			llm: make(chan llms.Model, 1),
+			name: name,
+			llm: make(chan llms.Model, len(urls)),
 			local: true,
 		}
-		g.llm <- llm
+		for _, url := range urls {
+			llm, err := ollama.New(ollama.WithServerURL(url), ollama.WithModel(model))
+			if err != nil {
+				return err
+			}
+			g.llm <- llm
+
+		}
+
-		LLMGroups[model] = g
-		LLMNames[i] = model
+		LLMGroups[name] = g
+		LLMNames = append(LLMNames, name)
 	}
+	log.Info("load models", "models", strings.Join(LLMNames, ","))
+	return nil
 }
-
+/*
+LLMGroup bundles multiple LLM clients that serve the same model into one cluster.
+*/
 type LLMGroup struct {
 	name string
 	llm chan llms.Model
diff --git a/src/llm/llm_test.go b/src/llm/llm_test.go
index 67ac4b5..c6f931a 100644
--- a/src/llm/llm_test.go
+++ b/src/llm/llm_test.go
@@ -1,6 +1,7 @@
 package llm
 
 import (
+	"agent/src/utils"
 	"context"
 	"fmt"
 	"testing"
@@ -27,3 +28,10 @@ func TestOllama(t *testing.T) {
 	}))
 }
 
+func TestInitLocalModel(t *testing.T) {
+	err := InitLocalModel([]string{"http://ishangsf.com:11434"})
+	if err != nil {
+		t.Fatal(err)
+	}
+	utils.PrintJson(LLMNames)
+}
\ No newline at end of file
diff --git a/src/llm/ollama.go b/src/llm/ollama.go
new file mode 100644
index 0000000..61e4b6d
--- /dev/null
+++ b/src/llm/ollama.go
@@ -0,0 +1,39 @@
+package llm
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+)
+
+type ollamaListModels struct {
+	Models []ollamaModel
+}
+
+type ollamaModel struct {
+	Model string
+}
+
+// listOllamaModels calls GET {url}/api/tags and returns the names of the models the server provides.
+func listOllamaModels(url string) ([]string, error) {
+	res, err := http.Get(url + "/api/tags")
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	resBody, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+	var m ollamaListModels
+	err = json.Unmarshal(resBody, &m)
+	if err != nil {
+		return nil, err
+	}
+
+	models := make([]string, len(m.Models))
+	for i, m := range m.Models {
+		models[i] = m.Model
+	}
+	return models, nil
+}
diff --git a/src/main.go b/src/main.go
index 4a070c6..f30bceb 100644
--- a/src/main.go
+++ b/src/main.go
@@ -2,9 +2,9 @@ package main
 
 import (
 	"agent/src/agent"
+	"agent/src/llm"
 	"agent/src/utils/log"
 	"context"
-	"fmt"
 	"os"
 	"os/signal"
 	"syscall"
@@ -18,8 +18,9 @@
 var (
 	main_data = struct {
 		port          uint64
 		crtDir        string
-		serveDocs     string
+		ollamaUrls    []string
 		saveRequestTo string
+
 	}{}
 )
@@ -29,6 +30,7 @@
 func init() {
 	f := cmd_run.Flags()
 	f.Uint64Var(&main_data.port, "port", 8080, "http port")
 	f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)")
+	f.StringSliceVar(&main_data.ollamaUrls, "ollama_urls", []string{llm.DefaultOllamaUrl}, "ollama urls")
 	// TODO
 	// f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)")
@@ -44,18 +46,15 @@ var cmd_run = &cobra.Command{
 	Short:   "run agent service",
 	Example: "agent run",
 	RunE: func(cmd *cobra.Command, args []string) error {
-		var saveRequestTo, crtDir *string
-		if len(main_data.saveRequestTo) > 0 {
-			saveRequestTo = &main_data.saveRequestTo
-		}
+		var crtDir *string
 		if len(main_data.crtDir) > 0 {
 			crtDir = &main_data.crtDir
 		}
-		fmt.Println(main_data.saveRequestTo)
-		a := agent.NewAgent(cmd.Context())
+		a, err := agent.NewAgent(cmd.Context(), main_data.ollamaUrls)
+		if err != nil {
+			return err
+		}
 		a.Start(main_data.port, crtDir)
-
-		_ = saveRequestTo
 		return nil
 	},
 }
diff --git a/src/workflow/taskPool.go b/src/workflow/taskPool.go
index e2a8dc1..a0caa35 100644
--- a/src/workflow/taskPool.go
+++ b/src/workflow/taskPool.go
@@ -20,7 +20,7 @@ type TaskPool struct {
 	ctx      context.Context
 	queue    *taskQueue
 	llmGroup *llm.LLMGroup
-	sem      chan struct{} // task并发数量控制
+	sem      chan struct{} // limits how many tasks run concurrently
 }
 
 func InitTaskPool(ctx context.Context) {
@@ -90,7 +90,7 @@ func (t *TaskPool) loop() {
 		case <-task.ctx.Done():
 			task.pool.remove(task.client)
 		case t.sem <- struct{}{}:
-			task.do()
+			go task.do()
 		}
 	}
 }
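
A minimal usage sketch (not part of the patch), showing how the new InitLocalModel entry point and a per-model LLMGroup are expected to be used from inside package llm. It assumes the group's buffered llm channel acts as a borrow/return pool of clients (receive one to use it, send it back when done), which is how InitLocalModel fills it above; the two URLs and the helper name exampleUseLocalModels are placeholders, and GenerateFromSinglePrompt comes from github.com/tmc/langchaingo/llms.

package llm

import (
	"context"
	"fmt"

	"github.com/tmc/langchaingo/llms"
)

// exampleUseLocalModels discovers the models served by the given Ollama URLs,
// then borrows a client from each group, runs a prompt, and returns the client.
func exampleUseLocalModels() error {
	urls := []string{"http://127.0.0.1:11434", "http://127.0.0.1:11435"} // placeholder servers
	if err := InitLocalModel(urls); err != nil {
		return err
	}

	for _, name := range LLMNames {
		g := LLMGroups[name]
		client := <-g.llm // borrow one backend from the cluster
		out, err := llms.GenerateFromSinglePrompt(context.Background(), client, "say hi")
		g.llm <- client // return the backend so other tasks can reuse it
		if err != nil {
			return err
		}
		fmt.Println(name, out)
	}
	return nil
}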