// Source file: ai_job_chat_agent/internal/client/llm_client.go
// (157 lines, 4.0 KiB, Go — retrieved 2026-01-12 11:33:43 +08:00)
package client
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"qd-sc/internal/config"
"qd-sc/internal/model"
"strings"
"time"
)
// LLMClient is a client for an OpenAI-compatible LLM chat-completion API.
type LLMClient struct {
	baseURL    string       // API base URL; requests go to baseURL+"/chat/completions"
	apiKey     string       // bearer token sent in the Authorization header
	httpClient *http.Client // underlying HTTP client, built by NewLLMHTTPClient with the configured timeout
	maxRetries int          // attempt budget for non-stream requests (retries on network error, 5xx, 429)
}
// NewLLMClient builds an LLMClient from the application configuration,
// wiring in the base URL, API key, retry budget, and a timeout-bounded
// HTTP client.
func NewLLMClient(cfg *config.Config) *LLMClient {
	llm := cfg.LLM
	client := &LLMClient{
		baseURL:    llm.BaseURL,
		apiKey:     llm.APIKey,
		httpClient: NewLLMHTTPClient(llm.Timeout),
		maxRetries: llm.MaxRetries,
	}
	return client
}
// ChatCompletion performs a non-streaming chat-completion request.
//
// It retries on network errors, 5xx responses, and 429 (Too Many Requests)
// with a linear backoff of (attempt+1) seconds, up to maxRetries attempts
// (at least one attempt is always made, even if maxRetries <= 0). Other 4xx
// responses are returned immediately without retrying. On a retryable
// status the response body is captured for the final error message before
// being closed, so the last failure is never reported with an empty body.
func (c *LLMClient) ChatCompletion(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	// Guarantee at least one attempt; otherwise the loop below would never
	// run and resp would stay nil (the original code dereferenced it anyway).
	attempts := c.maxRetries
	if attempts < 1 {
		attempts = 1
	}

	var resp *http.Response
	var lastErr error
	for i := 0; i < attempts; i++ {
		// http.Request bodies are consumed by Do(), so the request must be
		// rebuilt from the marshalled bytes on every attempt.
		httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
		if err != nil {
			return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
		}
		httpReq.Header.Set("Content-Type", "application/json")
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)

		resp, lastErr = c.httpClient.Do(httpReq)
		if lastErr != nil {
			// Network-level failure: retryable. Don't sleep after the final attempt.
			resp = nil
			if i < attempts-1 {
				time.Sleep(time.Duration(i+1) * time.Second)
			}
			continue
		}

		// Retry 5xx and 429; every other status is handled after the loop.
		if resp.StatusCode < 500 && resp.StatusCode != http.StatusTooManyRequests {
			break
		}

		// Retryable status: preserve the body for the error message before
		// closing, so a final failure still reports what the server said.
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		lastErr = fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
		resp = nil
		if i < attempts-1 {
			time.Sleep(time.Duration(i+1) * time.Second)
		}
	}

	// resp is nil exactly when every attempt failed, and every such path
	// sets lastErr, so it is safe to wrap here.
	if resp == nil {
		return nil, fmt.Errorf("HTTP请求失败: %w", lastErr)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
	}

	var result model.ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return &result, nil
}
// ChatCompletionStream performs a streaming chat-completion request and
// returns a channel of parsed SSE chunks plus a one-slot error channel.
//
// The request is forced into stream mode (req.Stream is set to true, which
// mutates the caller's request value). A reader goroutine owns the response
// body and both channels: it closes all of them when the server sends the
// terminal "[DONE]" event, when a chunk fails to parse, or when the stream
// read fails. NOTE(review): the goroutine has no cancellation hook — if the
// consumer stops draining chunks it can block on send; confirm callers
// always read until the channel closes.
func (c *LLMClient) ChatCompletionStream(req *model.ChatCompletionRequest) (chan *model.ChatCompletionChunk, chan error, error) {
	req.Stream = true

	payload, err := json.Marshal(req)
	if err != nil {
		return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(payload))
	if err != nil {
		return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	httpReq.Header.Set("Accept", "text/event-stream")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
	}

	chunks := make(chan *model.ChatCompletionChunk, 100)
	errs := make(chan error, 1)

	go func() {
		defer resp.Body.Close()
		defer close(chunks)
		defer close(errs)

		sc := bufio.NewScanner(resp.Body)
		// Raise the scanner's line limit from the 64KB default to 1MB so a
		// single oversized SSE line cannot abort the stream.
		sc.Buffer(make([]byte, 64*1024), 1024*1024)

		for sc.Scan() {
			raw := sc.Text()
			if !strings.HasPrefix(raw, "data: ") {
				// Skip blank keep-alives, comments, and non-data SSE fields.
				continue
			}
			event := strings.TrimPrefix(raw, "data: ")
			if event == "[DONE]" {
				// Server signalled end of stream.
				return
			}
			var chunk model.ChatCompletionChunk
			if err := json.Unmarshal([]byte(event), &chunk); err != nil {
				errs <- fmt.Errorf("解析流式响应失败: %w", err)
				return
			}
			chunks <- &chunk
		}
		if err := sc.Err(); err != nil {
			errs <- fmt.Errorf("读取流失败: %w", err)
		}
	}()

	return chunks, errs, nil
}