diff --git a/src/agent/client.go b/src/agent/client.go index ce338eb..fc483b9 100644 --- a/src/agent/client.go +++ b/src/agent/client.go @@ -1,6 +1,8 @@ package agent import ( + "agent/src/message" + "agent/src/workflow" "context" "encoding/json" "errors" @@ -17,21 +19,25 @@ const ( writeWait = 15 * time.Second ) +// client 用于管理每个与客户端的连接 type client struct { - ctx context.Context - writer http.ResponseWriter - flusher http.Flusher - outMsg chan []byte - authenticated bool - apikey string - agent *Agent - fnCancel context.CancelFunc - rate *rate.Limiter + ctx context.Context + writer http.ResponseWriter + flusher http.Flusher + outMsg chan []byte + authenticated bool + apikey string + agent *Agent + fnCancel context.CancelFunc + rate *rate.Limiter + wantResponseId uint64 // 新增字段:存储当前请求的ID + wantResponse chan []byte // 新增字段:用于存储响应的通道 lock sync.Mutex once sync.Once } +// 创建新的客户端对象 func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client { ctx, fnCancel := context.WithCancel(a.ctx) return &client{ @@ -45,15 +51,18 @@ func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client { } } +// 设置API Key并标记为已认证 func (c *client) setApiKey(apikey string) { c.apikey = apikey c.authenticated = true } +// 启动客户端的处理流程 func (c *client) run() { go c.writePump() } +// writePump 处理SSE的消息推送 func (c *client) writePump() { defer c.Close() @@ -64,6 +73,7 @@ func (c *client) writePump() { if !ok { return } + // 向客户端发送消息,SSE格式 c.writer.Write([]byte("data: ")) c.writer.Write(message) c.writer.Write([]byte("\n\n")) @@ -74,14 +84,16 @@ func (c *client) writePump() { } } +// 关闭客户端连接 func (c *client) Close() { c.once.Do(func() { c.fnCancel() c.agent.removeClient(c) - log.Info("client close", "connected", c.agent.clients.Size()) + log.Info("client closed", "connected", c.agent.clients.Size()) }) } +// 将JSON数据发送到客户端 func (c *client) WriteJson(data interface{}, block bool) error { b, err := json.Marshal(data) if err != nil { @@ -90,6 +102,7 @@ func (c *client) WriteJson(data interface{}, block bool) error { return c.WriteText(string(b), block) } +// 将文本消息发送到客户端 func (c *client) WriteText(msg string, block bool) error { if block { select { @@ -107,14 +120,59 @@ func (c *client) WriteText(msg string, block bool) error { return nil } +// 发送纯文本消息 func (c *client) SendText(msg string) error { return c.WriteText(msg, true) } +// 获取客户端的上下文 func (c *client) GetCtx() context.Context { return c.ctx } +// 获取客户端的消息通道 func (c *client) GetChannel() chan<- []byte { return c.outMsg } + +// 处理接收到的消息请求 +func (c *client) processSSERequest(msg []byte) { + // 检查速率限制 + if !c.rate.Allow() { + err := c.WriteJson(&message.ReponseError{ + Error: "maximum 5 requests per second", + }, true) + if err != nil { + log.Error("Failed to send rate limit response", "error", err) + } + return + } + + // 调用workflow.Handle来处理请求 + err := workflow.Handle(c, msg) + if err != nil { + log.Error("Failed to handle workflow", "error", err) + } +} + +// 设置希望接收响应的请求ID及对应的通道 +func (c *client) WantResponse(requestId uint64, ch chan []byte) { + c.lock.Lock() + c.wantResponseId, c.wantResponse = requestId, ch + c.lock.Unlock() + // 确保在设置 wantResponse 之后关闭旧的通道 + close(c.wantResponse) +} + +// 根据请求ID发送响应消息 +func (c *client) NewWantResponse(requestId uint64, msg []byte) error { + c.lock.Lock() + defer c.lock.Unlock() + // 检查请求ID是否匹配,确保没有不匹配的消息 + if requestId != c.wantResponseId || c.wantResponse == nil { + return errors.New("unexpected msg") + } + // 将消息发送到对应的通道 + c.wantResponse <- msg + return nil +}