Files
ai_job_chat_agent/internal/client/policy_client.go
2026-01-12 11:33:43 +08:00

257 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package client
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"qd-sc/internal/config"
"qd-sc/internal/model"
"strings"
"sync"
"time"
)
// PolicyClient is a client for the policy LLM service. It stores the service
// endpoint and credentials, the HTTP client used for requests, and a cached
// access ticket.
type PolicyClient struct {
	// Service endpoint, account credentials and target service identifier.
	baseURL, loginName, userKey, serviceID string

	// HTTP client used for every outgoing request.
	httpClient *http.Client

	// Ticket cache: mu guards currentTicket and ticketExpiresAt.
	mu              sync.RWMutex
	currentTicket   *model.PolicyTicketData
	ticketExpiresAt time.Time
}
// NewPolicyClient builds a PolicyClient from the Policy section of cfg.
//
// The base URL, login name, user key and service ID are taken from
// cfg.Policy. The HTTP client is created through NewHTTPClient with
// cfg.Policy.Timeout and fixed connection-pool settings. The returned client
// has no cached ticket yet.
func NewPolicyClient(cfg *config.Config) *PolicyClient {
	httpCfg := HTTPClientConfig{
		Timeout:             cfg.Policy.Timeout,
		MaxIdleConns:        100,
		MaxIdleConnsPerHost: 50,
		// NOTE(review): 0 presumably means "no per-host connection limit",
		// as it does for http.Transport — confirm against NewHTTPClient.
		MaxConnsPerHost: 0,
	}

	pc := new(PolicyClient)
	pc.baseURL = cfg.Policy.BaseURL
	pc.loginName = cfg.Policy.LoginName
	pc.userKey = cfg.Policy.UserKey
	pc.serviceID = cfg.Policy.ServiceID
	pc.httpClient = NewHTTPClient(httpCfg)
	return pc
}
// GetTicket returns an access ticket for the policy service, serving it from
// the cache when possible and fetching a new one otherwise.
//
// A cached ticket is reused until five minutes before its recorded expiry.
// When a refresh is needed, the login name and user key are POSTed as JSON to
// <baseURL>/api/aiServer/getAccessUserInfo; a successful result is cached and
// treated as valid for one hour from the moment it is stored.
//
// An error is returned if the request cannot be built or sent, the HTTP
// status is not 200, the body cannot be decoded, the response code is not
// 200, or the response carries no data.
func (c *PolicyClient) GetTicket() (*model.PolicyTicketData, error) {
	// fresh reports whether the cached ticket may still be handed out.
	// It reads shared state, so it must be called with c.mu held.
	fresh := func() bool {
		return c.currentTicket != nil &&
			time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute))
	}

	// Fast path: look at the cache under the read lock only.
	c.mu.RLock()
	cached, usable := c.currentTicket, fresh()
	c.mu.RUnlock()
	if usable {
		return cached, nil
	}

	// Slow path: the write lock is held for the whole refresh, so concurrent
	// callers wait here instead of issuing duplicate ticket requests.
	c.mu.Lock()
	defer c.mu.Unlock()

	// Another goroutine may have refreshed the ticket while we were waiting.
	if fresh() {
		return c.currentTicket, nil
	}

	payload, err := json.Marshal(&model.PolicyTicketRequest{
		LoginName: c.loginName,
		UserKey:   c.userKey,
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	endpoint := c.baseURL + "/api/aiServer/getAccessUserInfo"
	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("HTTP请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: include whatever part of the body could be read.
		raw, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(raw))
	}

	var result model.PolicyTicketResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	switch {
	case result.Code != 200:
		return nil, fmt.Errorf("获取ticket失败: %s", result.Message)
	case result.Data == nil:
		return nil, fmt.Errorf("返回数据为空")
	}

	// Store the new ticket and consider it valid for the next hour.
	c.currentTicket, c.ticketExpiresAt = result.Data, time.Now().Add(time.Hour)
	return c.currentTicket, nil
}
// Chat sends one non-streaming policy-consultation request and returns the
// decoded response.
//
// A ticket is obtained through GetTicket and sent together with chatReq as a
// JSON POST to <baseURL>/api/aiServer/aichat/stream-ai/<serviceID>.
//
// An error is returned if the ticket cannot be obtained, the request cannot
// be built or sent, the HTTP status is not 200, the body cannot be decoded,
// or the response code is not 200.
func (c *PolicyClient) Chat(chatReq *model.PolicyChatData) (*model.PolicyChatResponse, error) {
	ticket, err := c.GetTicket()
	if err != nil {
		return nil, fmt.Errorf("获取ticket失败: %w", err)
	}

	// Wrap the caller's data with the application ID and ticket.
	payload, err := json.Marshal(&model.PolicyChatRequest{
		AppID:  ticket.AppID,
		Ticket: ticket.Ticket,
		Data:   chatReq,
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	endpoint := c.baseURL + "/api/aiServer/aichat/stream-ai/" + c.serviceID
	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("HTTP请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: include whatever part of the body could be read.
		raw, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(raw))
	}

	reply := new(model.PolicyChatResponse)
	if err := json.NewDecoder(resp.Body).Decode(reply); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	if reply.Code != 200 {
		return nil, fmt.Errorf("对话请求失败: %s", reply.Message)
	}
	return reply, nil
}
// ChatStream starts a streaming policy-consultation conversation.
//
// It obtains a ticket through GetTicket, POSTs a copy of chatReq with Stream
// set to true to <baseURL>/api/aiServer/aichat/stream-ai/<serviceID>, and,
// once the server answers with HTTP 200, returns two channels fed by a
// background goroutine that reads the response body line by line:
//
//   - The content channel (buffer 100) receives message fragments. Blank lines
//     are skipped. A line that unmarshals into model.PolicyChatResponse with
//     Code 200 contributes its Data.Message when that is non-empty. A line
//     that fails to unmarshal is forwarded verbatim.
//   - The error channel (buffer 1) receives at most one error: a chunk whose
//     Code is not 200, or a failure reported by the line scanner.
//
// Both channels are closed when the goroutine finishes. The third return
// value reports failures that occur before streaming starts (nil request,
// ticket, marshalling, transport or non-200 status); in that case both
// channels are nil.
//
// The caller's chatReq is not modified. Callers must keep receiving from the
// content channel until it is closed: the method takes no cancellation
// signal, so if the consumer stops reading and the buffer fills, the
// goroutine stays blocked on the send and the response body is never closed.
func (c *PolicyClient) ChatStream(chatReq *model.PolicyChatData) (chan string, chan error, error) {
	// Fail fast instead of panicking on a nil dereference after a ticket
	// round-trip has already been made.
	if chatReq == nil {
		return nil, nil, fmt.Errorf("对话请求数据为空")
	}

	ticketData, err := c.GetTicket()
	if err != nil {
		return nil, nil, fmt.Errorf("获取ticket失败: %w", err)
	}

	// Enable streaming on a shallow copy so the caller's request is left
	// untouched and can safely be reused or shared between goroutines.
	streamReq := *chatReq
	streamReq.Stream = true

	req := &model.PolicyChatRequest{
		AppID:  ticketData.AppID,
		Ticket: ticketData.Ticket,
		Data:   &streamReq,
	}
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
	httpReq, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(reqBody))
	if err != nil {
		return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "text/event-stream")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// No goroutine takes ownership of the body on this path, so it is
		// read (best effort) and closed here.
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
	}

	contentChan := make(chan string, 100)
	// Buffer of one: the reader goroutine sends at most one error and must
	// never block on it.
	errChan := make(chan error, 1)

	go func() {
		defer resp.Body.Close()
		defer close(contentChan)
		defer close(errChan)

		scanner := bufio.NewScanner(resp.Body)
		// Raise the per-line limit from the scanner default to 1 MiB.
		scanner.Buffer(make([]byte, 64*1024), 1024*1024)

		for scanner.Scan() {
			line := scanner.Text()

			if strings.TrimSpace(line) == "" {
				continue
			}

			var chunkResp model.PolicyChatResponse
			if err := json.Unmarshal([]byte(line), &chunkResp); err != nil {
				// Not a JSON chunk; treat it as plain streamed text.
				contentChan <- line
				continue
			}

			// A non-200 chunk ends the stream with an error.
			if chunkResp.Code != 200 {
				errChan <- fmt.Errorf("对话请求失败: %s", chunkResp.Message)
				return
			}

			if chunkResp.Data != nil && chunkResp.Data.Message != "" {
				contentChan <- chunkResp.Data.Message
			}
		}

		if err := scanner.Err(); err != nil {
			errChan <- fmt.Errorf("读取流失败: %w", err)
		}
	}()

	return contentChan, errChan, nil
}