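// Package client contains the HTTP client used to call the LLM
// /chat/completions API, with blocking and streaming (server-sent events)
// request helpers.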
package client

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"qd-sc/internal/config"
	"qd-sc/internal/model"
	"strings"
	"time"
)

// LLMClient is an HTTP client for the LLM chat completion API.
type LLMClient struct {
	baseURL    string
	apiKey     string
	httpClient *http.Client
	maxRetries int
}

// NewLLMClient creates an LLM client from the given configuration.
func NewLLMClient(cfg *config.Config) *LLMClient {
	return &LLMClient{
		baseURL:    cfg.LLM.BaseURL,
		apiKey:     cfg.LLM.APIKey,
		httpClient: NewLLMHTTPClient(cfg.LLM.Timeout),
		maxRetries: cfg.LLM.MaxRetries,
	}
}

// ChatCompletion issues a chat completion request in blocking (non-streaming)
// mode, retrying transient failures, and returns the decoded response.
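//
// A typical call looks like the following (illustrative sketch only; the
// request and response field names are whatever the model package defines):
//
//	resp, err := c.ChatCompletion(&model.ChatCompletionRequest{ /* ... */ })
//	if err != nil {
//		return err
//	}
//	// inspect resp (e.g. the first choice's message content)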
func (c *LLMClient) ChatCompletion(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize request: %w", err)
	}

	var resp *http.Response
	var lastErr error

	// Retry loop with a simple linear backoff of (attempt+1) seconds.
	for i := 0; i < c.maxRetries; i++ {
		// Note: an http.Request body is consumed by Do(), so the request and
		// its body must be rebuilt on every retry.
		httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
		if err != nil {
			return nil, fmt.Errorf("failed to create HTTP request: %w", err)
		}

		httpReq.Header.Set("Content-Type", "application/json")
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)

		resp, lastErr = c.httpClient.Do(httpReq)

		// Request never sent / network error: retryable.
		if lastErr != nil {
			time.Sleep(time.Duration(i+1) * time.Second)
			continue
		}

		// Decide from the status code whether to retry: 5xx and 429 are
		// retried; any other 4xx is returned immediately.
		if resp.StatusCode < 500 && resp.StatusCode != http.StatusTooManyRequests {
			break
		}

		resp.Body.Close()
		time.Sleep(time.Duration(i+1) * time.Second)
	}

	if lastErr != nil {
		return nil, fmt.Errorf("HTTP request failed: %w", lastErr)
	}
	// Guard against a misconfigured maxRetries <= 0, in which case no request
	// was ever attempted and resp is nil.
	if resp == nil {
		return nil, fmt.Errorf("HTTP request failed: no attempts were made (maxRetries=%d)", c.maxRetries)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, string(body))
	}

	var result model.ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return &result, nil
}

// ChatCompletionStream issues a chat completion request in streaming mode
// (server-sent events). It returns a channel of response chunks and an error
// channel; both are closed by the reader goroutine when the stream ends.
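//
// A typical consumption pattern (illustrative sketch only):
//
//	chunks, errs, err := c.ChatCompletionStream(req)
//	if err != nil {
//		return err
//	}
//	for chunk := range chunks {
//		// handle each streamed chunk as it arrives
//		_ = chunk
//	}
//	if err := <-errs; err != nil {
//		return err
//	}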
func (c *LLMClient) ChatCompletionStream(req *model.ChatCompletionRequest) (chan *model.ChatCompletionChunk, chan error, error) {
	req.Stream = true
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to serialize request: %w", err)
	}

	httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	httpReq.Header.Set("Accept", "text/event-stream")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, nil, fmt.Errorf("HTTP request failed: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, string(body))
	}

	chunkChan := make(chan *model.ChatCompletionChunk, 100)
	errChan := make(chan error, 1)

	go func() {
		defer resp.Body.Close()
		defer close(chunkChan)
		defer close(errChan)

		scanner := bufio.NewScanner(resp.Body)
		// Raise the scanner's maximum buffer size to protect against
		// oversized responses (default 64KB, increased here to 1MB).
		scanner.Buffer(make([]byte, 64*1024), 1024*1024)

		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}

			data := strings.TrimPrefix(line, "data: ")
			if data == "[DONE]" {
				return
			}

			var chunk model.ChatCompletionChunk
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				errChan <- fmt.Errorf("failed to parse streaming response: %w", err)
				return
			}

			chunkChan <- &chunk
		}

		if err := scanner.Err(); err != nil {
			errChan <- fmt.Errorf("failed to read stream: %w", err)
		}
	}()

	return chunkChan, errChan, nil
}