自动获取 model 列表

This commit is contained in:
ken 2025-03-30 13:42:37 +08:00
parent 456498d5ce
commit e60fa21a0f
7 changed files with 103 additions and 29 deletions

View File

@ -2,6 +2,7 @@ package agent
import ( import (
"agent/src/common" "agent/src/common"
"agent/src/llm"
"agent/src/utils/log" "agent/src/utils/log"
"agent/src/workflow" "agent/src/workflow"
"context" "context"
@ -19,11 +20,18 @@ type Agent struct {
} }
var once sync.Once var once sync.Once
func NewAgent(ctx context.Context) *Agent { func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
if ollamaUrls == nil {
ollamaUrls = []string{llm.DefaultOllamaUrl}
}
err := llm.InitLocalModel(ollamaUrls)
if err != nil {
return nil, err
}
return &Agent{ return &Agent{
ctx : ctx, ctx : ctx,
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false), clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
} }, nil
} }
func (a *Agent) Start(port uint64, crtDir *string) { func (a *Agent) Start(port uint64, crtDir *string) {

View File

@ -7,7 +7,7 @@ import (
func TestAgent(t *testing.T) { func TestAgent(t *testing.T) {
ctx := context.Background() ctx := context.Background()
a := NewAgent(ctx) a, _ := NewAgent(ctx, nil)
// crt := "../../config" // crt := "../../config"
a.Start(8080, nil) a.Start(8080, nil)
<-make(chan bool) <-make(chan bool)

View File

@ -1,7 +1,9 @@
package llm package llm
import ( import (
"agent/src/utils/log"
"context" "context"
"strings"
"github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/llms/ollama"
@ -10,33 +12,51 @@ import (
var LLMGroups map[string]*LLMGroup var LLMGroups map[string]*LLMGroup
var LLMNames []string var LLMNames []string
// const ollamaurl = "http://192.168.1.5:11434" // const DefaultOllamaUrl = "http://192.168.1.5:11434"
const ollamaurl = "http://ishangsf.com:11434" const DefaultOllamaUrl = "http://ishangsf.com:11434"
func init() { func InitLocalModel(urls []string) error {
models := []string{"qwen2.5-coder:7b", "deepseek-coder-v2:16b"} model2url := make(map[string][]string)
LLMNames = make([]string, len(models)) for _, url := range urls {
LLMGroups = make(map[string]*LLMGroup, len(models)) models, err := listOllamaModels(url)
for i, model := range models {
llm, err := ollama.New(ollama.WithServerURL(ollamaurl), ollama.WithModel(model))
if err != nil { if err != nil {
panic(err) return err
} }
for _, m := range models {
model2url[m] = append(model2url[m], url)
}
}
LLMNames = make([]string, 0, len(model2url))
LLMGroups = make(map[string]*LLMGroup, len(model2url))
for model, urls := range model2url {
model = "local " + model model = "local " + model
g := &LLMGroup{ g := &LLMGroup{
name: model, name: model,
llm: make(chan llms.Model, 1), llm: make(chan llms.Model, len(urls)),
local: true, local: true,
} }
g.llm <- llm for _, url := range urls {
llm, err := ollama.New(ollama.WithServerURL(url), ollama.WithModel(model))
if err != nil {
return err
}
g.llm <- llm
}
LLMGroups[model] = g LLMGroups[model] = g
LLMNames[i] = model LLMNames = append(LLMNames, model)
} }
log.Info("load models", "models", strings.Join(LLMNames, ","))
return nil
} }
/*
多个LLM组成一个集群
*/
type LLMGroup struct { type LLMGroup struct {
name string name string
llm chan llms.Model llm chan llms.Model

View File

@ -1,6 +1,7 @@
package llm package llm
import ( import (
"agent/src/utils"
"context" "context"
"fmt" "fmt"
"testing" "testing"
@ -27,3 +28,10 @@ func TestOllama(t *testing.T) {
})) }))
} }
func TestInitLocalModel(t *testing.T) {
err := InitLocalModel([]string{"http://ishangsf.com:11434"})
if err != nil {
panic(err)
}
utils.PrintJson(LLMNames)
}

39
src/llm/ollama.go Normal file
View File

@ -0,0 +1,39 @@
package llm
import (
"encoding/json"
"io"
"net/http"
)
type ollamaListModels struct {
Models []ollamaModel
}
type ollamaModel struct {
Model string
}
func listOllamaModels(url string) ([]string, error) {
res, err := http.Get(url+"/api/tags")
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var m ollamaListModels
err = json.Unmarshal(resBody, &m)
if err != nil {
return nil, err
}
models := make([]string, len(m.Models))
for i, m := range m.Models {
models[i] = m.Model
}
return models, nil
}

View File

@ -2,9 +2,9 @@ package main
import ( import (
"agent/src/agent" "agent/src/agent"
"agent/src/llm"
"agent/src/utils/log" "agent/src/utils/log"
"context" "context"
"fmt"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@ -18,8 +18,9 @@ var (
main_data = struct { main_data = struct {
port uint64 port uint64
crtDir string crtDir string
serveDocs string ollamaUrls []string
saveRequestTo string saveRequestTo string
}{} }{}
) )
@ -29,6 +30,7 @@ func init() {
f := cmd_run.Flags() f := cmd_run.Flags()
f.Uint64Var(&main_data.port, "port", 8080, "http port") f.Uint64Var(&main_data.port, "port", 8080, "http port")
f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)") f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)")
f.StringSliceVar(&main_data.ollamaUrls, "ollama_urls", []string{llm.DefaultOllamaUrl}, "ollama urls")
// TODO // TODO
// f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)") // f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)")
@ -44,18 +46,15 @@ var cmd_run = &cobra.Command{
Short: "run agent service", Short: "run agent service",
Example: "agent run", Example: "agent run",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
var saveRequestTo, crtDir *string var crtDir *string
if len(main_data.saveRequestTo) > 0 {
saveRequestTo = &main_data.saveRequestTo
}
if len(main_data.crtDir) > 0 { if len(main_data.crtDir) > 0 {
crtDir = &main_data.crtDir crtDir = &main_data.crtDir
} }
fmt.Println(main_data.saveRequestTo) a, err := agent.NewAgent(cmd.Context(), main_data.ollamaUrls)
a := agent.NewAgent(cmd.Context()) if err != nil {
return err
}
a.Start(main_data.port, crtDir) a.Start(main_data.port, crtDir)
_ = saveRequestTo
return nil return nil
}, },
} }

View File

@ -20,7 +20,7 @@ type TaskPool struct {
ctx context.Context ctx context.Context
queue *taskQueue queue *taskQueue
llmGroup *llm.LLMGroup llmGroup *llm.LLMGroup
sem chan struct{} // task并发数量控制 sem chan struct{} // 控制task并发数量
} }
func InitTaskPool(ctx context.Context) { func InitTaskPool(ctx context.Context) {
@ -90,7 +90,7 @@ func (t *TaskPool) loop() {
case <-task.ctx.Done(): case <-task.ctx.Done():
task.pool.remove(task.client) task.pool.remove(task.client)
case t.sem <- struct{}{}: case t.sem <- struct{}{}:
task.do() go task.do()
} }
} }