SSE 改造 调用workflow.Handle方法

This commit is contained in:
xingchen 2025-04-27 15:38:12 +08:00
parent 67cfd763c3
commit a002e598f5

View File

@ -1,6 +1,8 @@
package agent package agent
import ( import (
"agent/src/message"
"agent/src/workflow"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@ -17,21 +19,25 @@ const (
writeWait = 15 * time.Second writeWait = 15 * time.Second
) )
// client 用于管理每个与客户端的连接
type client struct { type client struct {
ctx context.Context ctx context.Context
writer http.ResponseWriter writer http.ResponseWriter
flusher http.Flusher flusher http.Flusher
outMsg chan []byte outMsg chan []byte
authenticated bool authenticated bool
apikey string apikey string
agent *Agent agent *Agent
fnCancel context.CancelFunc fnCancel context.CancelFunc
rate *rate.Limiter rate *rate.Limiter
wantResponseId uint64 // 新增字段存储当前请求的ID
wantResponse chan []byte // 新增字段:用于存储响应的通道
lock sync.Mutex lock sync.Mutex
once sync.Once once sync.Once
} }
// 创建新的客户端对象
func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client { func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
ctx, fnCancel := context.WithCancel(a.ctx) ctx, fnCancel := context.WithCancel(a.ctx)
return &client{ return &client{
@ -45,15 +51,18 @@ func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
} }
} }
// 设置API Key并标记为已认证
func (c *client) setApiKey(apikey string) { func (c *client) setApiKey(apikey string) {
c.apikey = apikey c.apikey = apikey
c.authenticated = true c.authenticated = true
} }
// 启动客户端的处理流程
func (c *client) run() { func (c *client) run() {
go c.writePump() go c.writePump()
} }
// writePump 处理SSE的消息推送
func (c *client) writePump() { func (c *client) writePump() {
defer c.Close() defer c.Close()
@ -64,6 +73,7 @@ func (c *client) writePump() {
if !ok { if !ok {
return return
} }
// 向客户端发送消息SSE格式
c.writer.Write([]byte("data: ")) c.writer.Write([]byte("data: "))
c.writer.Write(message) c.writer.Write(message)
c.writer.Write([]byte("\n\n")) c.writer.Write([]byte("\n\n"))
@ -74,14 +84,16 @@ func (c *client) writePump() {
} }
} }
// 关闭客户端连接
func (c *client) Close() { func (c *client) Close() {
c.once.Do(func() { c.once.Do(func() {
c.fnCancel() c.fnCancel()
c.agent.removeClient(c) c.agent.removeClient(c)
log.Info("client close", "connected", c.agent.clients.Size()) log.Info("client closed", "connected", c.agent.clients.Size())
}) })
} }
// 将JSON数据发送到客户端
func (c *client) WriteJson(data interface{}, block bool) error { func (c *client) WriteJson(data interface{}, block bool) error {
b, err := json.Marshal(data) b, err := json.Marshal(data)
if err != nil { if err != nil {
@ -90,6 +102,7 @@ func (c *client) WriteJson(data interface{}, block bool) error {
return c.WriteText(string(b), block) return c.WriteText(string(b), block)
} }
// 将文本消息发送到客户端
func (c *client) WriteText(msg string, block bool) error { func (c *client) WriteText(msg string, block bool) error {
if block { if block {
select { select {
@ -107,14 +120,59 @@ func (c *client) WriteText(msg string, block bool) error {
return nil return nil
} }
// 发送纯文本消息
func (c *client) SendText(msg string) error { func (c *client) SendText(msg string) error {
return c.WriteText(msg, true) return c.WriteText(msg, true)
} }
// 获取客户端的上下文
func (c *client) GetCtx() context.Context { func (c *client) GetCtx() context.Context {
return c.ctx return c.ctx
} }
// 获取客户端的消息通道
func (c *client) GetChannel() chan<- []byte { func (c *client) GetChannel() chan<- []byte {
return c.outMsg return c.outMsg
} }
// 处理接收到的消息请求
func (c *client) processSSERequest(msg []byte) {
// 检查速率限制
if !c.rate.Allow() {
err := c.WriteJson(&message.ReponseError{
Error: "maximum 5 requests per second",
}, true)
if err != nil {
log.Error("Failed to send rate limit response", "error", err)
}
return
}
// 调用workflow.Handle来处理请求
err := workflow.Handle(c, msg)
if err != nil {
log.Error("Failed to handle workflow", "error", err)
}
}
// 设置希望接收响应的请求ID及对应的通道
func (c *client) WantResponse(requestId uint64, ch chan []byte) {
c.lock.Lock()
c.wantResponseId, c.wantResponse = requestId, ch
c.lock.Unlock()
// 确保在设置 wantResponse 之后关闭旧的通道
close(c.wantResponse)
}
// 根据请求ID发送响应消息
func (c *client) NewWantResponse(requestId uint64, msg []byte) error {
c.lock.Lock()
defer c.lock.Unlock()
// 检查请求ID是否匹配确保没有不匹配的消息
if requestId != c.wantResponseId || c.wantResponse == nil {
return errors.New("unexpected msg")
}
// 将消息发送到对应的通道
c.wantResponse <- msg
return nil
}