自动获取 model 列表
This commit is contained in:
parent
456498d5ce
commit
e60fa21a0f
@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"agent/src/common"
|
||||
"agent/src/llm"
|
||||
"agent/src/utils/log"
|
||||
"agent/src/workflow"
|
||||
"context"
|
||||
@ -19,11 +20,18 @@ type Agent struct {
|
||||
}
|
||||
var once sync.Once
|
||||
|
||||
func NewAgent(ctx context.Context) *Agent {
|
||||
func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
|
||||
if ollamaUrls == nil {
|
||||
ollamaUrls = []string{llm.DefaultOllamaUrl}
|
||||
}
|
||||
err := llm.InitLocalModel(ollamaUrls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Agent{
|
||||
ctx : ctx,
|
||||
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Agent) Start(port uint64, crtDir *string) {
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
func TestAgent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
a := NewAgent(ctx)
|
||||
a, _ := NewAgent(ctx, nil)
|
||||
// crt := "../../config"
|
||||
a.Start(8080, nil)
|
||||
<-make(chan bool)
|
||||
|
@ -1,7 +1,9 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"agent/src/utils/log"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/tmc/langchaingo/llms"
|
||||
"github.com/tmc/langchaingo/llms/ollama"
|
||||
@ -10,33 +12,51 @@ import (
|
||||
var LLMGroups map[string]*LLMGroup
|
||||
var LLMNames []string
|
||||
|
||||
// const ollamaurl = "http://192.168.1.5:11434"
|
||||
const ollamaurl = "http://ishangsf.com:11434"
|
||||
// const DefaultOllamaUrl = "http://192.168.1.5:11434"
|
||||
const DefaultOllamaUrl = "http://ishangsf.com:11434"
|
||||
|
||||
func init() {
|
||||
models := []string{"qwen2.5-coder:7b", "deepseek-coder-v2:16b"}
|
||||
LLMNames = make([]string, len(models))
|
||||
LLMGroups = make(map[string]*LLMGroup, len(models))
|
||||
|
||||
for i, model := range models {
|
||||
llm, err := ollama.New(ollama.WithServerURL(ollamaurl), ollama.WithModel(model))
|
||||
func InitLocalModel(urls []string) error {
|
||||
model2url := make(map[string][]string)
|
||||
for _, url := range urls {
|
||||
models, err := listOllamaModels(url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
for _, m := range models {
|
||||
model2url[m] = append(model2url[m], url)
|
||||
}
|
||||
}
|
||||
|
||||
LLMNames = make([]string, 0, len(model2url))
|
||||
LLMGroups = make(map[string]*LLMGroup, len(model2url))
|
||||
|
||||
for model, urls := range model2url {
|
||||
model = "local " + model
|
||||
g := &LLMGroup{
|
||||
name: model,
|
||||
llm: make(chan llms.Model, 1),
|
||||
llm: make(chan llms.Model, len(urls)),
|
||||
local: true,
|
||||
}
|
||||
for _, url := range urls {
|
||||
llm, err := ollama.New(ollama.WithServerURL(url), ollama.WithModel(model))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.llm <- llm
|
||||
|
||||
}
|
||||
|
||||
LLMGroups[model] = g
|
||||
LLMNames[i] = model
|
||||
LLMNames = append(LLMNames, model)
|
||||
}
|
||||
log.Info("load models", "models", strings.Join(LLMNames, ","))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
/*
|
||||
多个LLM组成一个集群
|
||||
*/
|
||||
type LLMGroup struct {
|
||||
name string
|
||||
llm chan llms.Model
|
||||
|
@ -1,6 +1,7 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"agent/src/utils"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
@ -27,3 +28,10 @@ func TestOllama(t *testing.T) {
|
||||
}))
|
||||
}
|
||||
|
||||
func TestInitLocalModel(t *testing.T) {
|
||||
err := InitLocalModel([]string{"http://ishangsf.com:11434"})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
utils.PrintJson(LLMNames)
|
||||
}
|
39
src/llm/ollama.go
Normal file
39
src/llm/ollama.go
Normal file
@ -0,0 +1,39 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ollamaListModels struct {
|
||||
Models []ollamaModel
|
||||
}
|
||||
|
||||
type ollamaModel struct {
|
||||
Model string
|
||||
}
|
||||
|
||||
func listOllamaModels(url string) ([]string, error) {
|
||||
res, err := http.Get(url+"/api/tags")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resBody, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var m ollamaListModels
|
||||
err = json.Unmarshal(resBody, &m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
models := make([]string, len(m.Models))
|
||||
for i, m := range m.Models {
|
||||
models[i] = m.Model
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
|
19
src/main.go
19
src/main.go
@ -2,9 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"agent/src/agent"
|
||||
"agent/src/llm"
|
||||
"agent/src/utils/log"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@ -18,8 +18,9 @@ var (
|
||||
main_data = struct {
|
||||
port uint64
|
||||
crtDir string
|
||||
serveDocs string
|
||||
ollamaUrls []string
|
||||
saveRequestTo string
|
||||
|
||||
}{}
|
||||
|
||||
)
|
||||
@ -29,6 +30,7 @@ func init() {
|
||||
f := cmd_run.Flags()
|
||||
f.Uint64Var(&main_data.port, "port", 8080, "http port")
|
||||
f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)")
|
||||
f.StringSliceVar(&main_data.ollamaUrls, "ollama_urls", []string{llm.DefaultOllamaUrl}, "ollama urls")
|
||||
|
||||
// TODO
|
||||
// f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)")
|
||||
@ -44,18 +46,15 @@ var cmd_run = &cobra.Command{
|
||||
Short: "run agent service",
|
||||
Example: "agent run",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var saveRequestTo, crtDir *string
|
||||
if len(main_data.saveRequestTo) > 0 {
|
||||
saveRequestTo = &main_data.saveRequestTo
|
||||
}
|
||||
var crtDir *string
|
||||
if len(main_data.crtDir) > 0 {
|
||||
crtDir = &main_data.crtDir
|
||||
}
|
||||
fmt.Println(main_data.saveRequestTo)
|
||||
a := agent.NewAgent(cmd.Context())
|
||||
a, err := agent.NewAgent(cmd.Context(), main_data.ollamaUrls)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.Start(main_data.port, crtDir)
|
||||
|
||||
_ = saveRequestTo
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ type TaskPool struct {
|
||||
ctx context.Context
|
||||
queue *taskQueue
|
||||
llmGroup *llm.LLMGroup
|
||||
sem chan struct{} // task并发数量控制
|
||||
sem chan struct{} // 控制task并发数量
|
||||
}
|
||||
|
||||
func InitTaskPool(ctx context.Context) {
|
||||
@ -90,7 +90,7 @@ func (t *TaskPool) loop() {
|
||||
case <-task.ctx.Done():
|
||||
task.pool.remove(task.client)
|
||||
case t.sem <- struct{}{}:
|
||||
task.do()
|
||||
go task.do()
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user