// ai_job_chat_agent/internal/client/llm_client.go
package client

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"qd-sc/internal/config"
	"qd-sc/internal/model"
)
// LLMClient is an HTTP client for the LLM chat-completions API.
type LLMClient struct {
	baseURL    string
	apiKey     string
	httpClient *http.Client
	maxRetries int
}

// NewLLMClient creates an LLM client from the application configuration.
func NewLLMClient(cfg *config.Config) *LLMClient {
	return &LLMClient{
		baseURL:    cfg.LLM.BaseURL,
		apiKey:     cfg.LLM.APIKey,
		httpClient: NewLLMHTTPClient(cfg.LLM.Timeout),
		maxRetries: cfg.LLM.MaxRetries,
	}
}
// ChatCompletion sends a chat-completion request (non-streaming).
func (c *LLMClient) ChatCompletion(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("failed to serialize request: %w", err)
	}

	var resp *http.Response
	var lastErr error

	// Retry loop with linear backoff.
	for i := 0; i < c.maxRetries; i++ {
		// Note: the http.Request body is consumed by Do(), so the request and
		// its body reader must be rebuilt on every attempt.
		httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
		if err != nil {
			return nil, fmt.Errorf("failed to create HTTP request: %w", err)
		}
		httpReq.Header.Set("Content-Type", "application/json")
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)

		resp, lastErr = c.httpClient.Do(httpReq)
		// Network error / request never sent: retryable.
		if lastErr != nil {
			time.Sleep(time.Duration(i+1) * time.Second)
			continue
		}
		// Decide by status code: retry on 5xx and 429; return anything else directly.
		if resp.StatusCode < 500 && resp.StatusCode != http.StatusTooManyRequests {
			break
		}
		// On the final attempt, keep the response open so its status and body
		// can be reported below.
		if i == c.maxRetries-1 {
			break
		}
		resp.Body.Close()
		time.Sleep(time.Duration(i+1) * time.Second)
	}
	if lastErr != nil {
		return nil, fmt.Errorf("HTTP request failed: %w", lastErr)
	}
	if resp == nil {
		return nil, fmt.Errorf("HTTP request failed: no attempts were made (maxRetries=%d)", c.maxRetries)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, string(body))
	}

	var result model.ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return &result, nil
}
// ChatCompletionStream sends a chat-completion request (streaming). It returns
// a channel of chunks and an error channel; both are closed by the reader
// goroutine when the stream ends.
func (c *LLMClient) ChatCompletionStream(req *model.ChatCompletionRequest) (chan *model.ChatCompletionChunk, chan error, error) {
	req.Stream = true
	reqBody, err := json.Marshal(req)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to serialize request: %w", err)
	}

	httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	httpReq.Header.Set("Accept", "text/event-stream")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, nil, fmt.Errorf("HTTP request failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, nil, fmt.Errorf("API returned error status %d: %s", resp.StatusCode, string(body))
	}

	chunkChan := make(chan *model.ChatCompletionChunk, 100)
	errChan := make(chan error, 1)

	go func() {
		defer resp.Body.Close()
		defer close(chunkChan)
		defer close(errChan)

		scanner := bufio.NewScanner(resp.Body)
		// Set an explicit buffer: raise the 64KB default to a 1MB cap so long
		// SSE lines can be read while still bounding oversized responses.
		scanner.Buffer(make([]byte, 64*1024), 1024*1024)

		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := strings.TrimPrefix(line, "data: ")
			if data == "[DONE]" {
				return
			}
			var chunk model.ChatCompletionChunk
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				errChan <- fmt.Errorf("failed to parse streaming response: %w", err)
				return
			}
			chunkChan <- &chunk
		}
		if err := scanner.Err(); err != nil {
			errChan <- fmt.Errorf("failed to read stream: %w", err)
		}
	}()

	return chunkChan, errChan, nil
}
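
// The sketch below is illustrative and not part of the original file: one way
// a caller might drain the channels returned by ChatCompletionStream. The name
// consumeStreamExample is hypothetical, and no fields are assumed on
// model.ChatCompletionChunk beyond the type itself.
func consumeStreamExample(c *LLMClient, req *model.ChatCompletionRequest) error {
	chunkChan, errChan, err := c.ChatCompletionStream(req)
	if err != nil {
		return err
	}
	// chunkChan is closed by the reader goroutine after "[DONE]", on a parse
	// error, or when the stream ends, so ranging over it always terminates.
	for chunk := range chunkChan {
		_ = chunk // hand each chunk to the caller's aggregation logic
	}
	// errChan is buffered (size 1) and closed by the same goroutine; receiving
	// from it yields nil once it is closed with no error pending.
	if err := <-errChan; err != nil {
		return err
	}
	return nil
}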