This commit is contained in:
2026-01-12 11:33:43 +08:00
commit f07062dbd7
38 changed files with 6805 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.GetHeader("Origin")
if origin != "" {
// 当需要携带凭证Cookie/Authorization规范要求不能使用 "*"
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Vary", "Origin")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
} else {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,53 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestCORS_WithOrigin_EchoOriginAndCredentials(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(CORS())
r.GET("/ping", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.Header.Set("Origin", "https://example.com")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://example.com" {
t.Fatalf("expected allow-origin to echo origin, got %q", got)
}
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
t.Fatalf("expected allow-credentials true, got %q", got)
}
if got := w.Header().Get("Vary"); got != "Origin" {
t.Fatalf("expected Vary Origin, got %q", got)
}
}
func TestCORS_NoOrigin_AllowAnyOriginWithoutCredentials(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(CORS())
r.GET("/ping", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
t.Fatalf("expected allow-origin '*', got %q", got)
}
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "" {
t.Fatalf("expected allow-credentials empty, got %q", got)
}
}

View File

@@ -0,0 +1,41 @@
package middleware
import (
"qd-sc/pkg/metrics"
"time"
"github.com/gin-gonic/gin"
)
// Metrics 指标收集中间件
func Metrics() gin.HandlerFunc {
m := metrics.GetGlobalMetrics()
return func(c *gin.Context) {
// 记录请求开始时间
start := time.Now()
// 增加总请求数和活跃请求数
m.IncTotalRequests()
m.IncActiveRequests()
defer m.DecActiveRequests()
// 检查是否是流式请求
if c.GetHeader("Accept") == "text/event-stream" || c.Query("stream") == "true" {
m.IncStreamRequests()
}
// 处理请求
c.Next()
// 记录延迟
duration := time.Since(start)
endpoint := c.Request.Method + " " + c.FullPath()
m.RecordLatency(endpoint, duration)
// 如果请求失败,增加失败计数
if c.Writer.Status() >= 400 {
m.IncFailedRequests()
}
}
}

View File

@@ -0,0 +1,103 @@
package middleware
import (
"net/http"
"qd-sc/internal/model"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter 基于令牌桶的限流器(使用原子操作)
type RateLimiter struct {
tokens int64 // 当前令牌数
capacity int64 // 桶容量(最大突发请求数)
refillRate int64 // 每秒补充的令牌数持续QPS
lastRefill int64 // 上次补充时间(纳秒时间戳)
now func() int64 // 便于测试注入UnixNano
}
// NewRateLimiter 创建限流器
// capacity: 桶容量(最大突发请求数)
// refillRate: 每秒补充的令牌数持续QPS
func NewRateLimiter(capacity, refillRate int) *RateLimiter {
return &RateLimiter{
tokens: int64(capacity),
capacity: int64(capacity),
refillRate: int64(refillRate),
lastRefill: time.Now().UnixNano(),
now: func() int64 {
return time.Now().UnixNano()
},
}
}
// Allow 尝试消耗一个令牌使用CAS无锁算法
func (rl *RateLimiter) Allow() bool {
now := rl.now()
for {
// 读取当前状态
currentTokens := atomic.LoadInt64(&rl.tokens)
lastRefill := atomic.LoadInt64(&rl.lastRefill)
// 计算应该补充的令牌
elapsed := now - lastRefill
if elapsed < 0 {
// 时钟回拨等极端情况:不补充
elapsed = 0
}
// 安全计算:避免 elapsed * refillRate 直接相乘造成溢出
// tokensToAdd = floor(elapsed_ns * refillRate_per_sec / 1e9)
secPart := elapsed / int64(time.Second) // elapsed 秒
nsecPart := elapsed % int64(time.Second) // 剩余纳秒
tokensToAdd := secPart*rl.refillRate + (nsecPart*rl.refillRate)/int64(time.Second)
newTokens := currentTokens
if tokensToAdd > 0 {
newTokens = currentTokens + tokensToAdd
if newTokens > rl.capacity {
newTokens = rl.capacity
}
}
// 检查是否有令牌可用
if newTokens < 1 {
return false
}
// 尝试消耗一个令牌
if atomic.CompareAndSwapInt64(&rl.tokens, currentTokens, newTokens-1) {
// 更新最后补充时间
if tokensToAdd > 0 {
atomic.StoreInt64(&rl.lastRefill, now)
}
return true
}
// CAS失败重试
}
}
// RateLimit 限流中间件
func RateLimit(capacity, refillRate int) gin.HandlerFunc {
limiter := NewRateLimiter(capacity, refillRate)
return func(c *gin.Context) {
if !limiter.Allow() {
c.JSON(http.StatusTooManyRequests, model.ErrorResponse{
Error: model.ErrorDetail{
Message: "请求过于频繁,请稍后再试",
Type: "rate_limit_exceeded",
},
})
c.Abort()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,35 @@
package middleware
import (
"testing"
"time"
)
func TestRateLimiter_AllowAndRefill(t *testing.T) {
rl := NewRateLimiter(2, 1) // 容量2每秒补充1
// 注入可控时间
now := time.Unix(0, 0).UnixNano()
rl.now = func() int64 { return now }
rl.lastRefill = now
rl.tokens = 2
if !rl.Allow() {
t.Fatalf("expected first Allow() true")
}
if !rl.Allow() {
t.Fatalf("expected second Allow() true")
}
if rl.Allow() {
t.Fatalf("expected third Allow() false (no tokens)")
}
// 过 1 秒应补充 1 个令牌
now += int64(time.Second)
if !rl.Allow() {
t.Fatalf("expected Allow() true after refill")
}
if rl.Allow() {
t.Fatalf("expected Allow() false again (tokens should be 0)")
}
}

View File

@@ -0,0 +1,35 @@
package middleware
import (
"log"
"net/http"
"qd-sc/internal/model"
"runtime/debug"
"github.com/gin-gonic/gin"
)
// Recovery Panic恢复中间件
func Recovery() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
// 打印堆栈信息
stack := debug.Stack()
log.Printf("[PANIC] %v\n%s", err, string(stack))
// 返回500错误
c.JSON(http.StatusInternalServerError, model.ErrorResponse{
Error: model.ErrorDetail{
Message: "服务器内部错误",
Type: "internal_server_error",
},
})
c.Abort()
}
}()
c.Next()
}
}