自动获取 model 列表
This commit is contained in:
parent
456498d5ce
commit
e60fa21a0f
@ -2,6 +2,7 @@ package agent
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"agent/src/common"
|
"agent/src/common"
|
||||||
|
"agent/src/llm"
|
||||||
"agent/src/utils/log"
|
"agent/src/utils/log"
|
||||||
"agent/src/workflow"
|
"agent/src/workflow"
|
||||||
"context"
|
"context"
|
||||||
@ -19,11 +20,18 @@ type Agent struct {
|
|||||||
}
|
}
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|
||||||
func NewAgent(ctx context.Context) *Agent {
|
func NewAgent(ctx context.Context, ollamaUrls []string) (*Agent, error) {
|
||||||
|
if ollamaUrls == nil {
|
||||||
|
ollamaUrls = []string{llm.DefaultOllamaUrl}
|
||||||
|
}
|
||||||
|
err := llm.InitLocalModel(ollamaUrls)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &Agent{
|
return &Agent{
|
||||||
ctx : ctx,
|
ctx : ctx,
|
||||||
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
|
clients: new(common.ThreadSafeMap[*client, struct{}]).Init(nil, false),
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) Start(port uint64, crtDir *string) {
|
func (a *Agent) Start(port uint64, crtDir *string) {
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
func TestAgent(t *testing.T) {
|
func TestAgent(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
a := NewAgent(ctx)
|
a, _ := NewAgent(ctx, nil)
|
||||||
// crt := "../../config"
|
// crt := "../../config"
|
||||||
a.Start(8080, nil)
|
a.Start(8080, nil)
|
||||||
<-make(chan bool)
|
<-make(chan bool)
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"agent/src/utils/log"
|
||||||
"context"
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/tmc/langchaingo/llms"
|
"github.com/tmc/langchaingo/llms"
|
||||||
"github.com/tmc/langchaingo/llms/ollama"
|
"github.com/tmc/langchaingo/llms/ollama"
|
||||||
@ -10,33 +12,51 @@ import (
|
|||||||
var LLMGroups map[string]*LLMGroup
|
var LLMGroups map[string]*LLMGroup
|
||||||
var LLMNames []string
|
var LLMNames []string
|
||||||
|
|
||||||
// const ollamaurl = "http://192.168.1.5:11434"
|
// const DefaultOllamaUrl = "http://192.168.1.5:11434"
|
||||||
const ollamaurl = "http://ishangsf.com:11434"
|
const DefaultOllamaUrl = "http://ishangsf.com:11434"
|
||||||
|
|
||||||
func init() {
|
func InitLocalModel(urls []string) error {
|
||||||
models := []string{"qwen2.5-coder:7b", "deepseek-coder-v2:16b"}
|
model2url := make(map[string][]string)
|
||||||
LLMNames = make([]string, len(models))
|
for _, url := range urls {
|
||||||
LLMGroups = make(map[string]*LLMGroup, len(models))
|
models, err := listOllamaModels(url)
|
||||||
|
|
||||||
for i, model := range models {
|
|
||||||
llm, err := ollama.New(ollama.WithServerURL(ollamaurl), ollama.WithModel(model))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return err
|
||||||
}
|
}
|
||||||
|
for _, m := range models {
|
||||||
|
model2url[m] = append(model2url[m], url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LLMNames = make([]string, 0, len(model2url))
|
||||||
|
LLMGroups = make(map[string]*LLMGroup, len(model2url))
|
||||||
|
|
||||||
|
for model, urls := range model2url {
|
||||||
model = "local " + model
|
model = "local " + model
|
||||||
g := &LLMGroup{
|
g := &LLMGroup{
|
||||||
name: model,
|
name: model,
|
||||||
llm: make(chan llms.Model, 1),
|
llm: make(chan llms.Model, len(urls)),
|
||||||
local: true,
|
local: true,
|
||||||
}
|
}
|
||||||
g.llm <- llm
|
for _, url := range urls {
|
||||||
|
llm, err := ollama.New(ollama.WithServerURL(url), ollama.WithModel(model))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
g.llm <- llm
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
LLMGroups[model] = g
|
LLMGroups[model] = g
|
||||||
LLMNames[i] = model
|
LLMNames = append(LLMNames, model)
|
||||||
}
|
}
|
||||||
|
log.Info("load models", "models", strings.Join(LLMNames, ","))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
多个LLM组成一个集群
|
||||||
|
*/
|
||||||
type LLMGroup struct {
|
type LLMGroup struct {
|
||||||
name string
|
name string
|
||||||
llm chan llms.Model
|
llm chan llms.Model
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"agent/src/utils"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
@ -27,3 +28,10 @@ func TestOllama(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInitLocalModel(t *testing.T) {
|
||||||
|
err := InitLocalModel([]string{"http://ishangsf.com:11434"})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
utils.PrintJson(LLMNames)
|
||||||
|
}
|
39
src/llm/ollama.go
Normal file
39
src/llm/ollama.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ollamaListModels struct {
|
||||||
|
Models []ollamaModel
|
||||||
|
}
|
||||||
|
|
||||||
|
type ollamaModel struct {
|
||||||
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
func listOllamaModels(url string) ([]string, error) {
|
||||||
|
res, err := http.Get(url+"/api/tags")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resBody, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var m ollamaListModels
|
||||||
|
err = json.Unmarshal(resBody, &m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, len(m.Models))
|
||||||
|
for i, m := range m.Models {
|
||||||
|
models[i] = m.Model
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
19
src/main.go
19
src/main.go
@ -2,9 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"agent/src/agent"
|
"agent/src/agent"
|
||||||
|
"agent/src/llm"
|
||||||
"agent/src/utils/log"
|
"agent/src/utils/log"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
@ -18,8 +18,9 @@ var (
|
|||||||
main_data = struct {
|
main_data = struct {
|
||||||
port uint64
|
port uint64
|
||||||
crtDir string
|
crtDir string
|
||||||
serveDocs string
|
ollamaUrls []string
|
||||||
saveRequestTo string
|
saveRequestTo string
|
||||||
|
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
)
|
)
|
||||||
@ -29,6 +30,7 @@ func init() {
|
|||||||
f := cmd_run.Flags()
|
f := cmd_run.Flags()
|
||||||
f.Uint64Var(&main_data.port, "port", 8080, "http port")
|
f.Uint64Var(&main_data.port, "port", 8080, "http port")
|
||||||
f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)")
|
f.StringVar(&main_data.crtDir, "crt", "", "where is server.crt, server.key (default no https)")
|
||||||
|
f.StringSliceVar(&main_data.ollamaUrls, "ollama_urls", []string{llm.DefaultOllamaUrl}, "ollama urls")
|
||||||
|
|
||||||
// TODO
|
// TODO
|
||||||
// f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)")
|
// f.StringVar(&main_data.saveRequestTo, "save_request_to", "", "save request to dir (default not save)")
|
||||||
@ -44,18 +46,15 @@ var cmd_run = &cobra.Command{
|
|||||||
Short: "run agent service",
|
Short: "run agent service",
|
||||||
Example: "agent run",
|
Example: "agent run",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
var saveRequestTo, crtDir *string
|
var crtDir *string
|
||||||
if len(main_data.saveRequestTo) > 0 {
|
|
||||||
saveRequestTo = &main_data.saveRequestTo
|
|
||||||
}
|
|
||||||
if len(main_data.crtDir) > 0 {
|
if len(main_data.crtDir) > 0 {
|
||||||
crtDir = &main_data.crtDir
|
crtDir = &main_data.crtDir
|
||||||
}
|
}
|
||||||
fmt.Println(main_data.saveRequestTo)
|
a, err := agent.NewAgent(cmd.Context(), main_data.ollamaUrls)
|
||||||
a := agent.NewAgent(cmd.Context())
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
a.Start(main_data.port, crtDir)
|
a.Start(main_data.port, crtDir)
|
||||||
|
|
||||||
_ = saveRequestTo
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ type TaskPool struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
queue *taskQueue
|
queue *taskQueue
|
||||||
llmGroup *llm.LLMGroup
|
llmGroup *llm.LLMGroup
|
||||||
sem chan struct{} // task并发数量控制
|
sem chan struct{} // 控制task并发数量
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitTaskPool(ctx context.Context) {
|
func InitTaskPool(ctx context.Context) {
|
||||||
@ -90,7 +90,7 @@ func (t *TaskPool) loop() {
|
|||||||
case <-task.ctx.Done():
|
case <-task.ctx.Done():
|
||||||
task.pool.remove(task.client)
|
task.pool.remove(task.client)
|
||||||
case t.sem <- struct{}{}:
|
case t.sem <- struct{}{}:
|
||||||
task.do()
|
go task.do()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user