257 lines
6.4 KiB
Go
257 lines
6.4 KiB
Go
|
|
package client
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bufio"
|
|||
|
|
"bytes"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"io"
|
|||
|
|
"net/http"
|
|||
|
|
"qd-sc/internal/config"
|
|||
|
|
"qd-sc/internal/model"
|
|||
|
|
"strings"
|
|||
|
|
"sync"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// PolicyClient 政策大模型客户端
|
|||
|
|
type PolicyClient struct {
|
|||
|
|
baseURL string
|
|||
|
|
loginName string
|
|||
|
|
userKey string
|
|||
|
|
serviceID string
|
|||
|
|
httpClient *http.Client
|
|||
|
|
|
|||
|
|
// ticket管理
|
|||
|
|
mu sync.RWMutex
|
|||
|
|
currentTicket *model.PolicyTicketData
|
|||
|
|
ticketExpiresAt time.Time
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewPolicyClient 创建政策大模型客户端
|
|||
|
|
func NewPolicyClient(cfg *config.Config) *PolicyClient {
|
|||
|
|
return &PolicyClient{
|
|||
|
|
baseURL: cfg.Policy.BaseURL,
|
|||
|
|
loginName: cfg.Policy.LoginName,
|
|||
|
|
userKey: cfg.Policy.UserKey,
|
|||
|
|
serviceID: cfg.Policy.ServiceID,
|
|||
|
|
httpClient: NewHTTPClient(HTTPClientConfig{
|
|||
|
|
Timeout: cfg.Policy.Timeout,
|
|||
|
|
MaxIdleConns: 100,
|
|||
|
|
MaxIdleConnsPerHost: 50,
|
|||
|
|
MaxConnsPerHost: 0,
|
|||
|
|
}),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetTicket 获取ticket(会自动缓存和刷新)
|
|||
|
|
func (c *PolicyClient) GetTicket() (*model.PolicyTicketData, error) {
|
|||
|
|
c.mu.RLock()
|
|||
|
|
// 如果ticket存在且未过期(提前5分钟刷新)
|
|||
|
|
if c.currentTicket != nil && time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute)) {
|
|||
|
|
ticket := c.currentTicket
|
|||
|
|
c.mu.RUnlock()
|
|||
|
|
return ticket, nil
|
|||
|
|
}
|
|||
|
|
c.mu.RUnlock()
|
|||
|
|
|
|||
|
|
// 需要获取新ticket
|
|||
|
|
c.mu.Lock()
|
|||
|
|
defer c.mu.Unlock()
|
|||
|
|
|
|||
|
|
// 双重检查,防止并发重复请求
|
|||
|
|
if c.currentTicket != nil && time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute)) {
|
|||
|
|
return c.currentTicket, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 发起请求获取ticket
|
|||
|
|
req := &model.PolicyTicketRequest{
|
|||
|
|
LoginName: c.loginName,
|
|||
|
|
UserKey: c.userKey,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reqBody, err := json.Marshal(req)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
url := fmt.Sprintf("%s/api/aiServer/getAccessUserInfo", c.baseURL)
|
|||
|
|
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|||
|
|
|
|||
|
|
resp, err := c.httpClient.Do(httpReq)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
if resp.StatusCode != http.StatusOK {
|
|||
|
|
body, _ := io.ReadAll(resp.Body)
|
|||
|
|
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var result model.PolicyTicketResponse
|
|||
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|||
|
|
return nil, fmt.Errorf("解析响应失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if result.Code != 200 {
|
|||
|
|
return nil, fmt.Errorf("获取ticket失败: %s", result.Message)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if result.Data == nil {
|
|||
|
|
return nil, fmt.Errorf("返回数据为空")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 缓存ticket,设置过期时间为1小时后
|
|||
|
|
c.currentTicket = result.Data
|
|||
|
|
c.ticketExpiresAt = time.Now().Add(1 * time.Hour)
|
|||
|
|
|
|||
|
|
return c.currentTicket, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Chat 发起政策咨询对话(非流式)
|
|||
|
|
func (c *PolicyClient) Chat(chatReq *model.PolicyChatData) (*model.PolicyChatResponse, error) {
|
|||
|
|
// 获取ticket
|
|||
|
|
ticketData, err := c.GetTicket()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("获取ticket失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 构造请求
|
|||
|
|
req := &model.PolicyChatRequest{
|
|||
|
|
AppID: ticketData.AppID,
|
|||
|
|
Ticket: ticketData.Ticket,
|
|||
|
|
Data: chatReq,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reqBody, err := json.Marshal(req)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
|
|||
|
|
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|||
|
|
|
|||
|
|
resp, err := c.httpClient.Do(httpReq)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
if resp.StatusCode != http.StatusOK {
|
|||
|
|
body, _ := io.ReadAll(resp.Body)
|
|||
|
|
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var result model.PolicyChatResponse
|
|||
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|||
|
|
return nil, fmt.Errorf("解析响应失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if result.Code != 200 {
|
|||
|
|
return nil, fmt.Errorf("对话请求失败: %s", result.Message)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ChatStream 发起政策咨询对话(流式)
|
|||
|
|
func (c *PolicyClient) ChatStream(chatReq *model.PolicyChatData) (chan string, chan error, error) {
|
|||
|
|
// 获取ticket
|
|||
|
|
ticketData, err := c.GetTicket()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("获取ticket失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 设置流式模式
|
|||
|
|
chatReq.Stream = true
|
|||
|
|
|
|||
|
|
// 构造请求
|
|||
|
|
req := &model.PolicyChatRequest{
|
|||
|
|
AppID: ticketData.AppID,
|
|||
|
|
Ticket: ticketData.Ticket,
|
|||
|
|
Data: chatReq,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
reqBody, err := json.Marshal(req)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
|
|||
|
|
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|||
|
|
httpReq.Header.Set("Accept", "text/event-stream")
|
|||
|
|
|
|||
|
|
resp, err := c.httpClient.Do(httpReq)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if resp.StatusCode != http.StatusOK {
|
|||
|
|
body, _ := io.ReadAll(resp.Body)
|
|||
|
|
resp.Body.Close()
|
|||
|
|
return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
contentChan := make(chan string, 100)
|
|||
|
|
errChan := make(chan error, 1)
|
|||
|
|
|
|||
|
|
go func() {
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
defer close(contentChan)
|
|||
|
|
defer close(errChan)
|
|||
|
|
|
|||
|
|
scanner := bufio.NewScanner(resp.Body)
|
|||
|
|
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
|||
|
|
|
|||
|
|
for scanner.Scan() {
|
|||
|
|
line := scanner.Text()
|
|||
|
|
|
|||
|
|
// 跳过空行
|
|||
|
|
if strings.TrimSpace(line) == "" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 尝试解析为JSON响应
|
|||
|
|
var chunkResp model.PolicyChatResponse
|
|||
|
|
if err := json.Unmarshal([]byte(line), &chunkResp); err != nil {
|
|||
|
|
// 如果不是JSON格式,可能是纯文本流,直接发送
|
|||
|
|
contentChan <- line
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查响应码
|
|||
|
|
if chunkResp.Code != 200 {
|
|||
|
|
errChan <- fmt.Errorf("对话请求失败: %s", chunkResp.Message)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 发送消息内容
|
|||
|
|
if chunkResp.Data != nil && chunkResp.Data.Message != "" {
|
|||
|
|
contentChan <- chunkResp.Data.Message
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := scanner.Err(); err != nil {
|
|||
|
|
errChan <- fmt.Errorf("读取流失败: %w", err)
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
|
|||
|
|
return contentChan, errChan, nil
|
|||
|
|
}
|