157 lines
4.0 KiB
Go
157 lines
4.0 KiB
Go
package client
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"qd-sc/internal/config"
|
||
"qd-sc/internal/model"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// LLMClient LLM客户端
|
||
type LLMClient struct {
|
||
baseURL string
|
||
apiKey string
|
||
httpClient *http.Client
|
||
maxRetries int
|
||
}
|
||
|
||
// NewLLMClient 创建LLM客户端
|
||
func NewLLMClient(cfg *config.Config) *LLMClient {
|
||
return &LLMClient{
|
||
baseURL: cfg.LLM.BaseURL,
|
||
apiKey: cfg.LLM.APIKey,
|
||
httpClient: NewLLMHTTPClient(cfg.LLM.Timeout),
|
||
maxRetries: cfg.LLM.MaxRetries,
|
||
}
|
||
}
|
||
|
||
// ChatCompletion 发起聊天补全请求(非流式)
|
||
func (c *LLMClient) ChatCompletion(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
var resp *http.Response
|
||
var lastErr error
|
||
|
||
// 重试机制
|
||
for i := 0; i < c.maxRetries; i++ {
|
||
// 注意:http.Request 的 Body 在 Do() 后会被读取消耗,重试必须重建 request/body
|
||
httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||
|
||
resp, lastErr = c.httpClient.Do(httpReq)
|
||
|
||
// 请求未发出/网络错误:可重试
|
||
if lastErr != nil {
|
||
time.Sleep(time.Duration(i+1) * time.Second)
|
||
continue
|
||
}
|
||
|
||
// 根据状态码决定是否重试:5xx 和 429 重试;其他 4xx 直接返回
|
||
if resp.StatusCode < 500 && resp.StatusCode != http.StatusTooManyRequests {
|
||
break
|
||
}
|
||
|
||
resp.Body.Close()
|
||
time.Sleep(time.Duration(i+1) * time.Second)
|
||
}
|
||
|
||
if lastErr != nil {
|
||
return nil, fmt.Errorf("HTTP请求失败: %w", lastErr)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var result model.ChatCompletionResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
return &result, nil
|
||
}
|
||
|
||
// ChatCompletionStream 发起聊天补全请求(流式)
|
||
func (c *LLMClient) ChatCompletionStream(req *model.ChatCompletionRequest) (chan *model.ChatCompletionChunk, chan error, error) {
|
||
req.Stream = true
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||
httpReq.Header.Set("Accept", "text/event-stream")
|
||
|
||
resp, err := c.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
chunkChan := make(chan *model.ChatCompletionChunk, 100)
|
||
errChan := make(chan error, 1)
|
||
|
||
go func() {
|
||
defer resp.Body.Close()
|
||
defer close(chunkChan)
|
||
defer close(errChan)
|
||
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
// 设置最大buffer大小,防止被超大响应攻击(默认64KB,增加到1MB)
|
||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if !strings.HasPrefix(line, "data: ") {
|
||
continue
|
||
}
|
||
|
||
data := strings.TrimPrefix(line, "data: ")
|
||
if data == "[DONE]" {
|
||
return
|
||
}
|
||
|
||
var chunk model.ChatCompletionChunk
|
||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||
errChan <- fmt.Errorf("解析流式响应失败: %w", err)
|
||
return
|
||
}
|
||
|
||
chunkChan <- &chunk
|
||
}
|
||
|
||
if err := scanner.Err(); err != nil {
|
||
errChan <- fmt.Errorf("读取流失败: %w", err)
|
||
}
|
||
}()
|
||
|
||
return chunkChan, errChan, nil
|
||
}
|