SSE 改造 调用workflow.Handle方法
This commit is contained in:
parent
67cfd763c3
commit
a002e598f5
@ -1,6 +1,8 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"agent/src/message"
|
||||||
|
"agent/src/workflow"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -17,21 +19,25 @@ const (
|
|||||||
writeWait = 15 * time.Second
|
writeWait = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// client 用于管理每个与客户端的连接
|
||||||
type client struct {
|
type client struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
writer http.ResponseWriter
|
writer http.ResponseWriter
|
||||||
flusher http.Flusher
|
flusher http.Flusher
|
||||||
outMsg chan []byte
|
outMsg chan []byte
|
||||||
authenticated bool
|
authenticated bool
|
||||||
apikey string
|
apikey string
|
||||||
agent *Agent
|
agent *Agent
|
||||||
fnCancel context.CancelFunc
|
fnCancel context.CancelFunc
|
||||||
rate *rate.Limiter
|
rate *rate.Limiter
|
||||||
|
wantResponseId uint64 // 新增字段:存储当前请求的ID
|
||||||
|
wantResponse chan []byte // 新增字段:用于存储响应的通道
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
once sync.Once
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建新的客户端对象
|
||||||
func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
|
func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
|
||||||
ctx, fnCancel := context.WithCancel(a.ctx)
|
ctx, fnCancel := context.WithCancel(a.ctx)
|
||||||
return &client{
|
return &client{
|
||||||
@ -45,15 +51,18 @@ func newClient(w http.ResponseWriter, f http.Flusher, a *Agent) *client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置API Key并标记为已认证
|
||||||
func (c *client) setApiKey(apikey string) {
|
func (c *client) setApiKey(apikey string) {
|
||||||
c.apikey = apikey
|
c.apikey = apikey
|
||||||
c.authenticated = true
|
c.authenticated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 启动客户端的处理流程
|
||||||
func (c *client) run() {
|
func (c *client) run() {
|
||||||
go c.writePump()
|
go c.writePump()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writePump 处理SSE的消息推送
|
||||||
func (c *client) writePump() {
|
func (c *client) writePump() {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
@ -64,6 +73,7 @@ func (c *client) writePump() {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 向客户端发送消息,SSE格式
|
||||||
c.writer.Write([]byte("data: "))
|
c.writer.Write([]byte("data: "))
|
||||||
c.writer.Write(message)
|
c.writer.Write(message)
|
||||||
c.writer.Write([]byte("\n\n"))
|
c.writer.Write([]byte("\n\n"))
|
||||||
@ -74,14 +84,16 @@ func (c *client) writePump() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 关闭客户端连接
|
||||||
func (c *client) Close() {
|
func (c *client) Close() {
|
||||||
c.once.Do(func() {
|
c.once.Do(func() {
|
||||||
c.fnCancel()
|
c.fnCancel()
|
||||||
c.agent.removeClient(c)
|
c.agent.removeClient(c)
|
||||||
log.Info("client close", "connected", c.agent.clients.Size())
|
log.Info("client closed", "connected", c.agent.clients.Size())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将JSON数据发送到客户端
|
||||||
func (c *client) WriteJson(data interface{}, block bool) error {
|
func (c *client) WriteJson(data interface{}, block bool) error {
|
||||||
b, err := json.Marshal(data)
|
b, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -90,6 +102,7 @@ func (c *client) WriteJson(data interface{}, block bool) error {
|
|||||||
return c.WriteText(string(b), block)
|
return c.WriteText(string(b), block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将文本消息发送到客户端
|
||||||
func (c *client) WriteText(msg string, block bool) error {
|
func (c *client) WriteText(msg string, block bool) error {
|
||||||
if block {
|
if block {
|
||||||
select {
|
select {
|
||||||
@ -107,14 +120,59 @@ func (c *client) WriteText(msg string, block bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 发送纯文本消息
|
||||||
func (c *client) SendText(msg string) error {
|
func (c *client) SendText(msg string) error {
|
||||||
return c.WriteText(msg, true)
|
return c.WriteText(msg, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取客户端的上下文
|
||||||
func (c *client) GetCtx() context.Context {
|
func (c *client) GetCtx() context.Context {
|
||||||
return c.ctx
|
return c.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取客户端的消息通道
|
||||||
func (c *client) GetChannel() chan<- []byte {
|
func (c *client) GetChannel() chan<- []byte {
|
||||||
return c.outMsg
|
return c.outMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理接收到的消息请求
|
||||||
|
func (c *client) processSSERequest(msg []byte) {
|
||||||
|
// 检查速率限制
|
||||||
|
if !c.rate.Allow() {
|
||||||
|
err := c.WriteJson(&message.ReponseError{
|
||||||
|
Error: "maximum 5 requests per second",
|
||||||
|
}, true)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to send rate limit response", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用workflow.Handle来处理请求
|
||||||
|
err := workflow.Handle(c, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to handle workflow", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置希望接收响应的请求ID及对应的通道
|
||||||
|
func (c *client) WantResponse(requestId uint64, ch chan []byte) {
|
||||||
|
c.lock.Lock()
|
||||||
|
c.wantResponseId, c.wantResponse = requestId, ch
|
||||||
|
c.lock.Unlock()
|
||||||
|
// 确保在设置 wantResponse 之后关闭旧的通道
|
||||||
|
close(c.wantResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据请求ID发送响应消息
|
||||||
|
func (c *client) NewWantResponse(requestId uint64, msg []byte) error {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
// 检查请求ID是否匹配,确保没有不匹配的消息
|
||||||
|
if requestId != c.wantResponseId || c.wantResponse == nil {
|
||||||
|
return errors.New("unexpected msg")
|
||||||
|
}
|
||||||
|
// 将消息发送到对应的通道
|
||||||
|
c.wantResponse <- msg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user