init
This commit is contained in:
256
internal/client/policy_client.go
Normal file
256
internal/client/policy_client.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"qd-sc/internal/config"
|
||||
"qd-sc/internal/model"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PolicyClient 政策大模型客户端
|
||||
type PolicyClient struct {
|
||||
baseURL string
|
||||
loginName string
|
||||
userKey string
|
||||
serviceID string
|
||||
httpClient *http.Client
|
||||
|
||||
// ticket管理
|
||||
mu sync.RWMutex
|
||||
currentTicket *model.PolicyTicketData
|
||||
ticketExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewPolicyClient 创建政策大模型客户端
|
||||
func NewPolicyClient(cfg *config.Config) *PolicyClient {
|
||||
return &PolicyClient{
|
||||
baseURL: cfg.Policy.BaseURL,
|
||||
loginName: cfg.Policy.LoginName,
|
||||
userKey: cfg.Policy.UserKey,
|
||||
serviceID: cfg.Policy.ServiceID,
|
||||
httpClient: NewHTTPClient(HTTPClientConfig{
|
||||
Timeout: cfg.Policy.Timeout,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 50,
|
||||
MaxConnsPerHost: 0,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// GetTicket 获取ticket(会自动缓存和刷新)
|
||||
func (c *PolicyClient) GetTicket() (*model.PolicyTicketData, error) {
|
||||
c.mu.RLock()
|
||||
// 如果ticket存在且未过期(提前5分钟刷新)
|
||||
if c.currentTicket != nil && time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute)) {
|
||||
ticket := c.currentTicket
|
||||
c.mu.RUnlock()
|
||||
return ticket, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
// 需要获取新ticket
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// 双重检查,防止并发重复请求
|
||||
if c.currentTicket != nil && time.Now().Before(c.ticketExpiresAt.Add(-5*time.Minute)) {
|
||||
return c.currentTicket, nil
|
||||
}
|
||||
|
||||
// 发起请求获取ticket
|
||||
req := &model.PolicyTicketRequest{
|
||||
LoginName: c.loginName,
|
||||
UserKey: c.userKey,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/aiServer/getAccessUserInfo", c.baseURL)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result model.PolicyTicketResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if result.Code != 200 {
|
||||
return nil, fmt.Errorf("获取ticket失败: %s", result.Message)
|
||||
}
|
||||
|
||||
if result.Data == nil {
|
||||
return nil, fmt.Errorf("返回数据为空")
|
||||
}
|
||||
|
||||
// 缓存ticket,设置过期时间为1小时后
|
||||
c.currentTicket = result.Data
|
||||
c.ticketExpiresAt = time.Now().Add(1 * time.Hour)
|
||||
|
||||
return c.currentTicket, nil
|
||||
}
|
||||
|
||||
// Chat 发起政策咨询对话(非流式)
|
||||
func (c *PolicyClient) Chat(chatReq *model.PolicyChatData) (*model.PolicyChatResponse, error) {
|
||||
// 获取ticket
|
||||
ticketData, err := c.GetTicket()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取ticket失败: %w", err)
|
||||
}
|
||||
|
||||
// 构造请求
|
||||
req := &model.PolicyChatRequest{
|
||||
AppID: ticketData.AppID,
|
||||
Ticket: ticketData.Ticket,
|
||||
Data: chatReq,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result model.PolicyChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if result.Code != 200 {
|
||||
return nil, fmt.Errorf("对话请求失败: %s", result.Message)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ChatStream 发起政策咨询对话(流式)
|
||||
func (c *PolicyClient) ChatStream(chatReq *model.PolicyChatData) (chan string, chan error, error) {
|
||||
// 获取ticket
|
||||
ticketData, err := c.GetTicket()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("获取ticket失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置流式模式
|
||||
chatReq.Stream = true
|
||||
|
||||
// 构造请求
|
||||
req := &model.PolicyChatRequest{
|
||||
AppID: ticketData.AppID,
|
||||
Ticket: ticketData.Ticket,
|
||||
Data: chatReq,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/aiServer/aichat/stream-ai/%s", c.baseURL, c.serviceID)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
contentChan := make(chan string, 100)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer resp.Body.Close()
|
||||
defer close(contentChan)
|
||||
defer close(errChan)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// 跳过空行
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 尝试解析为JSON响应
|
||||
var chunkResp model.PolicyChatResponse
|
||||
if err := json.Unmarshal([]byte(line), &chunkResp); err != nil {
|
||||
// 如果不是JSON格式,可能是纯文本流,直接发送
|
||||
contentChan <- line
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查响应码
|
||||
if chunkResp.Code != 200 {
|
||||
errChan <- fmt.Errorf("对话请求失败: %s", chunkResp.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 发送消息内容
|
||||
if chunkResp.Data != nil && chunkResp.Data.Message != "" {
|
||||
contentChan <- chunkResp.Data.Message
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
errChan <- fmt.Errorf("读取流失败: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return contentChan, errChan, nil
|
||||
}
|
||||
Reference in New Issue
Block a user