"""
API限流器
"""
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, Optional

from fastapi import HTTPException, Request
from loguru import logger


@dataclass
class RateLimitConfig:
    """Settings for a single rate-limit rule.

    Attributes:
        max_requests: Maximum number of requests admitted per window.
        window_seconds: Length of the window, in seconds.
        key_func: Optional function that produces the rate-limit key.
    """
    max_requests: int = 100  # maximum requests per window
    window_seconds: int = 60  # window length in seconds
    # Annotated with typing.Callable: the builtin ``callable`` is a function,
    # not a type, and static type checkers reject it in annotations.
    # NOTE(review): no code in this module calls key_func, so the argument
    # list is left open and the ``str`` return is an assumption based on
    # limiter keys being strings here -- confirm against the real caller.
    key_func: Optional[Callable[..., str]] = None


class TokenBucket:
    """Token-bucket rate limiter.

    The bucket starts full and is topped up lazily: every public call first
    credits ``elapsed_seconds * refill_rate`` tokens (capped at ``capacity``)
    and only then inspects or spends the balance.

    Attributes:
        capacity: Maximum number of tokens the bucket can hold.
        refill_rate: Tokens credited per second.
        tokens: Current balance; becomes a float once a refill has run.
        last_refill: ``time.time()`` reading taken at the most recent refill.
    """

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate  # tokens credited per second
        self.tokens = capacity
        self.last_refill = time.time()

    def consume(self, tokens: int = 1) -> bool:
        """Try to spend ``tokens`` from the bucket.

        Args:
            tokens: Number of tokens the caller needs.

        Returns:
            True if the balance covered the request (the tokens are deducted),
            False otherwise (nothing is deducted).
        """
        self._refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def _refill(self):
        """Credit the tokens earned since the previous refill."""
        now = time.time()
        # time.time() follows the wall clock, which can be stepped backwards
        # (NTP correction, manual change). Clamp at zero so a backwards step
        # never produces a negative credit that would drain the bucket.
        elapsed = max(0.0, now - self.last_refill)
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now

    def get_available_tokens(self) -> int:
        """Return the current balance, rounded down to a whole token."""
        self._refill()
        return int(self.tokens)


class SlidingWindow:
    """Sliding-window request counter with one-second buckets.

    Admitted requests are tallied per whole second. A request is admitted
    while the tally over the live buckets is below ``max_requests``.

    Because only the whole second of each request is stored, a bucket is kept
    until the *end* of its second has left the window. The limiter therefore
    errs on the strict side: a request is counted for between
    ``window_seconds`` and ``window_seconds + 1`` seconds, never for less.

    Attributes:
        max_requests: Maximum number of requests admitted per window.
        window_seconds: Length of the window, in seconds.
        requests: Maps a whole-second timestamp (``int(time.time())``) to the
            number of requests admitted during that second.
    """

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests: Dict[int, int] = {}  # whole-second timestamp -> count

    def _prune(self, now: float) -> int:
        """Drop buckets that are entirely outside the window; return the live total."""
        window_start = now - self.window_seconds
        # Bucket ``ts`` covers [ts, ts + 1). Comparing ``ts`` itself against
        # the window start would evict it up to one second before its
        # requests are really ``window_seconds`` old, letting a client exceed
        # the limit; compare the end of the bucket instead.
        self.requests = {
            ts: count for ts, count in self.requests.items()
            if ts + 1 > window_start
        }
        return sum(self.requests.values())

    def is_allowed(self) -> bool:
        """Admit and record one request if the window still has room.

        Returns:
            True if the request was admitted (and counted), False if the
            window is already full (nothing is recorded).
        """
        now = time.time()
        if self._prune(now) >= self.max_requests:
            return False

        bucket = int(now)
        self.requests[bucket] = self.requests.get(bucket, 0) + 1
        return True

    def get_remaining(self) -> int:
        """Return how many more requests the current window admits (never negative)."""
        return max(0, self.max_requests - self._prune(time.time()))


class APIRateLimiter:
    """Registry of sliding-window limiters, one per key.

    A key's window is either registered explicitly with ``add_limit`` or
    created on first use by ``check_limit``.
    """

    def __init__(self):
        self._default_config = RateLimitConfig()
        self._limiters: Dict[str, SlidingWindow] = {}

    def add_limit(self, key: str, max_requests: int, window_seconds: int):
        """Install a fresh window for ``key``, replacing any existing one, and log it."""
        window = SlidingWindow(max_requests, window_seconds)
        self._limiters[key] = window
        logger.info(f"添加限流规则: {key} - {max_requests}请求/{window_seconds}秒")

    def check_limit(self, key: str, config: RateLimitConfig = None) -> bool:
        """Count one request against ``key`` and report whether it is admitted.

        Args:
            key: Identifier of the bucket to charge.
            config: Limits to use if ``key`` has no window yet; falls back to
                the default config. Ignored once a window exists for ``key``.

        Returns:
            Whatever the key's window returns from ``is_allowed()``.
        """
        if not config:
            config = self._default_config

        try:
            window = self._limiters[key]
        except KeyError:
            # First sighting of this key: build its window from the config.
            window = SlidingWindow(config.max_requests, config.window_seconds)
            self._limiters[key] = window

        return window.is_allowed()

    def get_remaining(self, key: str) -> int:
        """Return the remaining quota for ``key``.

        A key without a window reports the default config's ``max_requests``.
        """
        if key in self._limiters:
            return self._limiters[key].get_remaining()
        return self._default_config.max_requests

    def reset_limit(self, key: str):
        """Forget the window for ``key`` (and log it); unknown keys are a no-op."""
        if key not in self._limiters:
            return
        del self._limiters[key]
        logger.info(f"重置限流规则: {key}")


class RateLimitMiddleware:
    """HTTP middleware that applies an ``APIRateLimiter`` per client.

    Instances are async callables taking ``(request, call_next)``. Each
    request is charged against a key derived from the client address.
    Rejected requests get a 429 response; admitted ones are passed on and
    their response is annotated with ``X-RateLimit-*`` headers.
    """

    def __init__(self, rate_limiter: APIRateLimiter):
        self.rate_limiter = rate_limiter

    def get_client_key(self, request: Request) -> str:
        """Return the rate-limit key for ``request``.

        The key is ``"ip:<host>"``, or ``"ip:unknown"`` when the request
        carries no client address.
        """
        # NOTE: only the client address is used. An authenticated deployment
        # could key on a user id instead (e.g. read from ``request.state``);
        # that is not implemented here.
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"

    def _rate_limit_headers(self, remaining: int) -> Dict[str, str]:
        """Build the ``X-RateLimit-*`` headers for the given remaining quota."""
        # The limit and window always come from the limiter's default config,
        # even for keys that were registered with custom limits.
        config = self.rate_limiter._default_config
        return {
            "X-RateLimit-Limit": str(config.max_requests),
            "X-RateLimit-Remaining": str(remaining),
            "X-RateLimit-Reset": str(int(time.time()) + config.window_seconds),
        }

    async def __call__(self, request: Request, call_next):
        """Charge the request against its client key and forward it if admitted.

        Args:
            request: Incoming request.
            call_next: Awaitable callable that runs the rest of the stack and
                returns the response.

        Returns:
            A 429 JSON response when the client is over its limit, otherwise
            the downstream response with rate-limit headers added.
        """
        client_key = self.get_client_key(request)

        if not self.rate_limiter.check_limit(client_key):
            # Imported here so the block stays self-contained.
            from fastapi.responses import JSONResponse

            remaining = self.rate_limiter.get_remaining(client_key)
            # Return the response instead of raising HTTPException: an
            # exception raised from HTTP middleware is not routed through
            # FastAPI's exception handlers (they wrap only the routing
            # layer), so the client would receive a 500 rather than a 429.
            # The body mirrors the {"detail": ...} shape of an HTTPException.
            return JSONResponse(
                status_code=429,
                content={"detail": f"请求过于频繁，请稍后再试。剩余请求数: {remaining}"},
                headers=self._rate_limit_headers(remaining),
            )

        response = await call_next(request)

        # Read the quota after the handler ran so the header reflects the
        # state at response time.
        remaining = self.rate_limiter.get_remaining(client_key)
        for name, value in self._rate_limit_headers(remaining).items():
            response.headers[name] = value

        return response

