init
This commit is contained in:
156
internal/client/llm_client.go
Normal file
156
internal/client/llm_client.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"qd-sc/internal/config"
|
||||
"qd-sc/internal/model"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LLMClient LLM客户端
|
||||
type LLMClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
maxRetries int
|
||||
}
|
||||
|
||||
// NewLLMClient 创建LLM客户端
|
||||
func NewLLMClient(cfg *config.Config) *LLMClient {
|
||||
return &LLMClient{
|
||||
baseURL: cfg.LLM.BaseURL,
|
||||
apiKey: cfg.LLM.APIKey,
|
||||
httpClient: NewLLMHTTPClient(cfg.LLM.Timeout),
|
||||
maxRetries: cfg.LLM.MaxRetries,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletion 发起聊天补全请求(非流式)
|
||||
func (c *LLMClient) ChatCompletion(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
var lastErr error
|
||||
|
||||
// 重试机制
|
||||
for i := 0; i < c.maxRetries; i++ {
|
||||
// 注意:http.Request 的 Body 在 Do() 后会被读取消耗,重试必须重建 request/body
|
||||
httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
|
||||
resp, lastErr = c.httpClient.Do(httpReq)
|
||||
|
||||
// 请求未发出/网络错误:可重试
|
||||
if lastErr != nil {
|
||||
time.Sleep(time.Duration(i+1) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
// 根据状态码决定是否重试:5xx 和 429 重试;其他 4xx 直接返回
|
||||
if resp.StatusCode < 500 && resp.StatusCode != http.StatusTooManyRequests {
|
||||
break
|
||||
}
|
||||
|
||||
resp.Body.Close()
|
||||
time.Sleep(time.Duration(i+1) * time.Second)
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, fmt.Errorf("HTTP请求失败: %w", lastErr)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result model.ChatCompletionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ChatCompletionStream 发起聊天补全请求(流式)
|
||||
func (c *LLMClient) ChatCompletionStream(req *model.ChatCompletionRequest) (chan *model.ChatCompletionChunk, chan error, error) {
|
||||
req.Stream = true
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.baseURL+"/chat/completions", bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("HTTP请求失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, nil, fmt.Errorf("API返回错误状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
chunkChan := make(chan *model.ChatCompletionChunk, 100)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer resp.Body.Close()
|
||||
defer close(chunkChan)
|
||||
defer close(errChan)
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// 设置最大buffer大小,防止被超大响应攻击(默认64KB,增加到1MB)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
var chunk model.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
errChan <- fmt.Errorf("解析流式响应失败: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
chunkChan <- &chunk
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
errChan <- fmt.Errorf("读取流失败: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return chunkChan, errChan, nil
|
||||
}
|
||||
Reference in New Issue
Block a user