init
This commit is contained in:
157
internal/api/handler/chat.go
Normal file
157
internal/api/handler/chat.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"qd-sc/internal/model"
|
||||
"qd-sc/internal/service"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 使用服务层定义的固定模型名称
|
||||
|
||||
// ChatHandler 聊天处理器
|
||||
type ChatHandler struct {
|
||||
chatService *service.ChatService
|
||||
response *Response
|
||||
}
|
||||
|
||||
// NewChatHandler 创建聊天处理器
|
||||
func NewChatHandler(chatService *service.ChatService) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
chatService: chatService,
|
||||
response: DefaultResponse,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions 处理聊天completions请求
|
||||
func (h *ChatHandler) ChatCompletions(c *gin.Context) {
|
||||
var req model.ChatCompletionRequest
|
||||
|
||||
// 只支持 JSON 请求(文件通过 image_url 字段以 URL 方式传递)
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.response.Error(c, http.StatusBadRequest, "invalid_request", "无效的请求格式: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求
|
||||
if req.Model == "" {
|
||||
h.response.Error(c, http.StatusBadRequest, "invalid_request", "缺少model参数")
|
||||
return
|
||||
}
|
||||
// 验证模型名称(只接受固定模型名)
|
||||
if req.Model != service.ExposedModelName {
|
||||
h.response.Error(c, http.StatusBadRequest, "invalid_request", fmt.Sprintf("不支持的模型,请使用: %s", service.ExposedModelName))
|
||||
return
|
||||
}
|
||||
if len(req.Messages) == 0 {
|
||||
h.response.Error(c, http.StatusBadRequest, "invalid_request", "messages不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 根据stream参数决定返回方式
|
||||
if req.Stream {
|
||||
h.handleStreamResponse(c, &req)
|
||||
} else {
|
||||
h.handleNonStreamResponse(c, &req)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamResponse 处理非流式响应
|
||||
func (h *ChatHandler) handleNonStreamResponse(c *gin.Context, req *model.ChatCompletionRequest) {
|
||||
resp, err := h.chatService.ProcessChatRequest(req)
|
||||
if err != nil {
|
||||
log.Printf("处理聊天请求失败: %v", err)
|
||||
h.response.Error(c, http.StatusInternalServerError, "internal_error", "处理请求失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, resp)
|
||||
}
|
||||
|
||||
// handleStreamResponse 处理流式响应
|
||||
func (h *ChatHandler) handleStreamResponse(c *gin.Context, req *model.ChatCompletionRequest) {
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Transfer-Encoding", "chunked")
|
||||
|
||||
// 传递context以支持取消
|
||||
ctx := c.Request.Context()
|
||||
chunkChan, errChan := h.chatService.ProcessChatRequestStream(ctx, req)
|
||||
|
||||
// 持续发送流式数据
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-chunkChan:
|
||||
if !ok {
|
||||
// 通道已关闭,发送[DONE]标记(OpenAI标准格式)
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: [DONE]\n\n"); err != nil {
|
||||
log.Printf("写入[DONE]标记失败: %v", err)
|
||||
}
|
||||
c.Writer.Flush()
|
||||
log.Printf("SSE流已结束,已发送[DONE]标记")
|
||||
return
|
||||
}
|
||||
|
||||
// 发送chunk
|
||||
chunkJSON, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
log.Printf("序列化chunk失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入SSE格式
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunkJSON)); err != nil {
|
||||
log.Printf("写入SSE数据失败: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
|
||||
// 如果这个chunk包含finish_reason,记录日志
|
||||
if len(chunk.Choices) > 0 && chunk.Choices[0].FinishReason != "" {
|
||||
log.Printf("已发送finish_reason=%s的chunk", chunk.Choices[0].FinishReason)
|
||||
}
|
||||
|
||||
case err, ok := <-errChan:
|
||||
if ok && err != nil {
|
||||
log.Printf("流式处理错误: %v", err)
|
||||
// 发送错误信息
|
||||
errChunk := model.ChatCompletionChunk{
|
||||
ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: service.ExposedModelName,
|
||||
Choices: []model.ChunkChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: model.Message{
|
||||
Role: "assistant",
|
||||
Content: fmt.Sprintf("\n\n错误:%s", err.Error()),
|
||||
},
|
||||
FinishReason: "error",
|
||||
},
|
||||
},
|
||||
}
|
||||
chunkJSON, _ := json.Marshal(errChunk)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunkJSON))
|
||||
c.Writer.Flush()
|
||||
|
||||
// 发送DONE
|
||||
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
case <-c.Request.Context().Done():
|
||||
// 客户端断开连接
|
||||
log.Printf("客户端断开连接")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
23
internal/api/handler/health.go
Normal file
23
internal/api/handler/health.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HealthHandler 健康检查处理器
|
||||
type HealthHandler struct{}
|
||||
|
||||
// NewHealthHandler 创建健康检查处理器
|
||||
func NewHealthHandler() *HealthHandler {
|
||||
return &HealthHandler{}
|
||||
}
|
||||
|
||||
// Check 健康检查
|
||||
func (h *HealthHandler) Check(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"service": "qd-sc-server",
|
||||
})
|
||||
}
|
||||
26
internal/api/handler/metrics.go
Normal file
26
internal/api/handler/metrics.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"qd-sc/pkg/metrics"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MetricsHandler 指标处理器
|
||||
type MetricsHandler struct {
|
||||
metrics *metrics.Metrics
|
||||
}
|
||||
|
||||
// NewMetricsHandler 创建指标处理器
|
||||
func NewMetricsHandler() *MetricsHandler {
|
||||
return &MetricsHandler{
|
||||
metrics: metrics.GetGlobalMetrics(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics 获取性能指标
|
||||
func (h *MetricsHandler) GetMetrics(c *gin.Context) {
|
||||
stats := h.metrics.GetStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
33
internal/api/handler/response.go
Normal file
33
internal/api/handler/response.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"qd-sc/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 统一响应处理器
|
||||
type Response struct{}
|
||||
|
||||
// Error 发送错误响应
|
||||
func (r *Response) Error(c *gin.Context, statusCode int, errorType, message string) {
|
||||
c.JSON(statusCode, model.ErrorResponse{
|
||||
Error: model.ErrorDetail{
|
||||
Message: message,
|
||||
Type: errorType,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Success 发送成功响应
|
||||
func (r *Response) Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(200, data)
|
||||
}
|
||||
|
||||
// NewResponse 创建响应处理器
|
||||
func NewResponse() *Response {
|
||||
return &Response{}
|
||||
}
|
||||
|
||||
// 全局响应处理器实例
|
||||
var DefaultResponse = NewResponse()
|
||||
33
internal/api/middleware/cors.go
Normal file
33
internal/api/middleware/cors.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.GetHeader("Origin")
|
||||
if origin != "" {
|
||||
// 当需要携带凭证(Cookie/Authorization)时,规范要求不能使用 "*"
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Set("Vary", "Origin")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
} else {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
53
internal/api/middleware/cors_test.go
Normal file
53
internal/api/middleware/cors_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestCORS_WithOrigin_EchoOriginAndCredentials(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(CORS())
|
||||
r.GET("/ping", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://example.com" {
|
||||
t.Fatalf("expected allow-origin to echo origin, got %q", got)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
||||
t.Fatalf("expected allow-credentials true, got %q", got)
|
||||
}
|
||||
if got := w.Header().Get("Vary"); got != "Origin" {
|
||||
t.Fatalf("expected Vary Origin, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_NoOrigin_AllowAnyOriginWithoutCredentials(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(CORS())
|
||||
r.GET("/ping", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Fatalf("expected allow-origin '*', got %q", got)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "" {
|
||||
t.Fatalf("expected allow-credentials empty, got %q", got)
|
||||
}
|
||||
}
|
||||
41
internal/api/middleware/metrics.go
Normal file
41
internal/api/middleware/metrics.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"qd-sc/pkg/metrics"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Metrics 指标收集中间件
|
||||
func Metrics() gin.HandlerFunc {
|
||||
m := metrics.GetGlobalMetrics()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 记录请求开始时间
|
||||
start := time.Now()
|
||||
|
||||
// 增加总请求数和活跃请求数
|
||||
m.IncTotalRequests()
|
||||
m.IncActiveRequests()
|
||||
defer m.DecActiveRequests()
|
||||
|
||||
// 检查是否是流式请求
|
||||
if c.GetHeader("Accept") == "text/event-stream" || c.Query("stream") == "true" {
|
||||
m.IncStreamRequests()
|
||||
}
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录延迟
|
||||
duration := time.Since(start)
|
||||
endpoint := c.Request.Method + " " + c.FullPath()
|
||||
m.RecordLatency(endpoint, duration)
|
||||
|
||||
// 如果请求失败,增加失败计数
|
||||
if c.Writer.Status() >= 400 {
|
||||
m.IncFailedRequests()
|
||||
}
|
||||
}
|
||||
}
|
||||
103
internal/api/middleware/ratelimit.go
Normal file
103
internal/api/middleware/ratelimit.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"qd-sc/internal/model"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter 基于令牌桶的限流器(使用原子操作)
|
||||
type RateLimiter struct {
|
||||
tokens int64 // 当前令牌数
|
||||
capacity int64 // 桶容量(最大突发请求数)
|
||||
refillRate int64 // 每秒补充的令牌数(持续QPS)
|
||||
lastRefill int64 // 上次补充时间(纳秒时间戳)
|
||||
|
||||
now func() int64 // 便于测试注入(UnixNano)
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建限流器
|
||||
// capacity: 桶容量(最大突发请求数)
|
||||
// refillRate: 每秒补充的令牌数(持续QPS)
|
||||
func NewRateLimiter(capacity, refillRate int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
tokens: int64(capacity),
|
||||
capacity: int64(capacity),
|
||||
refillRate: int64(refillRate),
|
||||
lastRefill: time.Now().UnixNano(),
|
||||
now: func() int64 {
|
||||
return time.Now().UnixNano()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 尝试消耗一个令牌(使用CAS无锁算法)
|
||||
func (rl *RateLimiter) Allow() bool {
|
||||
now := rl.now()
|
||||
|
||||
for {
|
||||
// 读取当前状态
|
||||
currentTokens := atomic.LoadInt64(&rl.tokens)
|
||||
lastRefill := atomic.LoadInt64(&rl.lastRefill)
|
||||
|
||||
// 计算应该补充的令牌
|
||||
elapsed := now - lastRefill
|
||||
if elapsed < 0 {
|
||||
// 时钟回拨等极端情况:不补充
|
||||
elapsed = 0
|
||||
}
|
||||
|
||||
// 安全计算:避免 elapsed * refillRate 直接相乘造成溢出
|
||||
// tokensToAdd = floor(elapsed_ns * refillRate_per_sec / 1e9)
|
||||
secPart := elapsed / int64(time.Second) // elapsed 秒
|
||||
nsecPart := elapsed % int64(time.Second) // 剩余纳秒
|
||||
tokensToAdd := secPart*rl.refillRate + (nsecPart*rl.refillRate)/int64(time.Second)
|
||||
|
||||
newTokens := currentTokens
|
||||
if tokensToAdd > 0 {
|
||||
newTokens = currentTokens + tokensToAdd
|
||||
if newTokens > rl.capacity {
|
||||
newTokens = rl.capacity
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有令牌可用
|
||||
if newTokens < 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 尝试消耗一个令牌
|
||||
if atomic.CompareAndSwapInt64(&rl.tokens, currentTokens, newTokens-1) {
|
||||
// 更新最后补充时间
|
||||
if tokensToAdd > 0 {
|
||||
atomic.StoreInt64(&rl.lastRefill, now)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CAS失败,重试
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit 限流中间件
|
||||
func RateLimit(capacity, refillRate int) gin.HandlerFunc {
|
||||
limiter := NewRateLimiter(capacity, refillRate)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
if !limiter.Allow() {
|
||||
c.JSON(http.StatusTooManyRequests, model.ErrorResponse{
|
||||
Error: model.ErrorDetail{
|
||||
Message: "请求过于频繁,请稍后再试",
|
||||
Type: "rate_limit_exceeded",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
35
internal/api/middleware/ratelimit_test.go
Normal file
35
internal/api/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter_AllowAndRefill(t *testing.T) {
|
||||
rl := NewRateLimiter(2, 1) // 容量2,每秒补充1
|
||||
|
||||
// 注入可控时间
|
||||
now := time.Unix(0, 0).UnixNano()
|
||||
rl.now = func() int64 { return now }
|
||||
rl.lastRefill = now
|
||||
rl.tokens = 2
|
||||
|
||||
if !rl.Allow() {
|
||||
t.Fatalf("expected first Allow() true")
|
||||
}
|
||||
if !rl.Allow() {
|
||||
t.Fatalf("expected second Allow() true")
|
||||
}
|
||||
if rl.Allow() {
|
||||
t.Fatalf("expected third Allow() false (no tokens)")
|
||||
}
|
||||
|
||||
// 过 1 秒应补充 1 个令牌
|
||||
now += int64(time.Second)
|
||||
if !rl.Allow() {
|
||||
t.Fatalf("expected Allow() true after refill")
|
||||
}
|
||||
if rl.Allow() {
|
||||
t.Fatalf("expected Allow() false again (tokens should be 0)")
|
||||
}
|
||||
}
|
||||
35
internal/api/middleware/recovery.go
Normal file
35
internal/api/middleware/recovery.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"qd-sc/internal/model"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery Panic恢复中间件
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// 打印堆栈信息
|
||||
stack := debug.Stack()
|
||||
log.Printf("[PANIC] %v\n%s", err, string(stack))
|
||||
|
||||
// 返回500错误
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse{
|
||||
Error: model.ErrorDetail{
|
||||
Message: "服务器内部错误",
|
||||
Type: "internal_server_error",
|
||||
},
|
||||
})
|
||||
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user