257 lines
6.4 KiB
Go
257 lines
6.4 KiB
Go
package client
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"qd-sc/internal/config"
|
||
"qd-sc/internal/model"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// PolicyClient 政策大模型客户端
|
||
type PolicyClient struct {
|
||
baseURL string
|
||
loginName string
|
||
userKey string
|
||
serviceID string
|
||
httpClient *http.Client
|
||
|
||
// ticket管理
|
||
mu sync.RWMutex
|
||
currentTicket *model.PolicyTicketData
|
||
ticketExpiresAt time.Time
|
||
}
|
||
|
||
// NewPolicyClient 创建政策大模型客户端
|
||
func NewPolicyClient(cfg *config.Config) *PolicyClient {
|
||
return &PolicyClient{
|
||
baseURL: cfg.Policy.BaseURL,
|
||
loginName: cfg.Policy.LoginName,
|
||
userKey: cfg.Policy.UserKey,
|
||
serviceID: cfg.Policy.ServiceID,
|
||
httpClient: NewHTTPClient(HTTPClientConfig{
|
||
Timeout: cfg.Policy.Timeout,
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 50,
|
||
MaxConnsPerHost: 0,
|
||
}),
|
||
}
|
||
}
|
||
|
||
// GetTicket 获取ticket(会自动缓存和刷新)
|
||
func (c *PolicyClient) GetTicket() (*model.PolicyTicketData, error) {
|
||
c.mu.RLock()
|
||
// 如果ticket存在且未过期(提前5分钟刷新)
|
||
if c.currentTicket != nil && time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute)) {
|
||
ticket := c.currentTicket
|
||
c.mu.RUnlock()
|
||
return ticket, nil
|
||
}
|
||
c.mu.RUnlock()
|
||
|
||
// 需要获取新ticket
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
// 双重检查,防止并发重复请求
|
||
if c.currentTicket != nil && time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute)) {
|
||
return c.currentTicket, nil
|
||
}
|
||
|
||
// 发起请求获取ticket
|
||
req := &model.PolicyTicketRequest{
|
||
LoginName: c.loginName,
|
||
UserKey: c.userKey,
|
||
}
|
||
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/api/aiServer/getAccessUserInfo", c.baseURL)
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := c.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var result model.PolicyTicketResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if result.Code != 200 {
|
||
return nil, fmt.Errorf("获取ticket失败: %s", result.Message)
|
||
}
|
||
|
||
if result.Data == nil {
|
||
return nil, fmt.Errorf("返回数据为空")
|
||
}
|
||
|
||
// 缓存ticket,设置过期时间为1小时后
|
||
c.currentTicket = result.Data
|
||
c.ticketExpiresAt = time.Now().Add(1 * time.Hour)
|
||
|
||
return c.currentTicket, nil
|
||
}
|
||
|
||
// Chat 发起政策咨询对话(非流式)
|
||
func (c *PolicyClient) Chat(chatReq *model.PolicyChatData) (*model.PolicyChatResponse, error) {
|
||
// 获取ticket
|
||
ticketData, err := c.GetTicket()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取ticket失败: %w", err)
|
||
}
|
||
|
||
// 构造请求
|
||
req := &model.PolicyChatRequest{
|
||
AppID: ticketData.AppID,
|
||
Ticket: ticketData.Ticket,
|
||
Data: chatReq,
|
||
}
|
||
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := c.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var result model.PolicyChatResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if result.Code != 200 {
|
||
return nil, fmt.Errorf("对话请求失败: %s", result.Message)
|
||
}
|
||
|
||
return &result, nil
|
||
}
|
||
|
||
// ChatStream 发起政策咨询对话(流式)
|
||
func (c *PolicyClient) ChatStream(chatReq *model.PolicyChatData) (chan string, chan error, error) {
|
||
// 获取ticket
|
||
ticketData, err := c.GetTicket()
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("获取ticket失败: %w", err)
|
||
}
|
||
|
||
// 设置流式模式
|
||
chatReq.Stream = true
|
||
|
||
// 构造请求
|
||
req := &model.PolicyChatRequest{
|
||
AppID: ticketData.AppID,
|
||
Ticket: ticketData.Ticket,
|
||
Data: chatReq,
|
||
}
|
||
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Accept", "text/event-stream")
|
||
|
||
resp, err := c.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
contentChan := make(chan string, 100)
|
||
errChan := make(chan error, 1)
|
||
|
||
go func() {
|
||
defer resp.Body.Close()
|
||
defer close(contentChan)
|
||
defer close(errChan)
|
||
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
|
||
// 跳过空行
|
||
if strings.TrimSpace(line) == "" {
|
||
continue
|
||
}
|
||
|
||
// 尝试解析为JSON响应
|
||
var chunkResp model.PolicyChatResponse
|
||
if err := json.Unmarshal([]byte(line), &chunkResp); err != nil {
|
||
// 如果不是JSON格式,可能是纯文本流,直接发送
|
||
contentChan <- line
|
||
continue
|
||
}
|
||
|
||
// 检查响应码
|
||
if chunkResp.Code != 200 {
|
||
errChan <- fmt.Errorf("对话请求失败: %s", chunkResp.Message)
|
||
return
|
||
}
|
||
|
||
// 发送消息内容
|
||
if chunkResp.Data != nil && chunkResp.Data.Message != "" {
|
||
contentChan <- chunkResp.Data.Message
|
||
}
|
||
}
|
||
|
||
if err := scanner.Err(); err != nil {
|
||
errChan <- fmt.Errorf("读取流失败: %w", err)
|
||
}
|
||
}()
|
||
|
||
return contentChan, errChan, nil
|
||
}
|