"""
速率限制器
"""
import time
from typing import Dict, Optional
from collections import defaultdict
from dataclasses import dataclass
from loguru import logger


@dataclass
class RateLimit:
    """Rate-limit configuration: at most ``max_requests`` per ``window_seconds``."""

    # Maximum number of requests accepted within one window.
    max_requests: int
    # Length of the sliding time window, in seconds.
    window_seconds: int


class RateLimiter:
    """In-memory sliding-window rate limiter keyed by arbitrary strings.

    Each key may be given its own limit via ``set_limit``. A request for a
    limited key is allowed when fewer than ``max_requests`` requests have been
    accepted for that key within the last ``window_seconds`` seconds. Keys with
    no configured limit are always allowed.
    """

    def __init__(self):
        self._limits: Dict[str, RateLimit] = {}
        # key -> time.time() timestamps of requests accepted for that key
        self._counters: Dict[str, list] = defaultdict(list)

    def set_limit(self, key: str, max_requests: int, window_seconds: int):
        """Configure (or replace) the rate limit for ``key``.

        Timestamps already recorded for ``key`` are kept, so a changed limit is
        applied immediately to the requests still inside the window.
        """
        self._limits[key] = RateLimit(max_requests=max_requests, window_seconds=window_seconds)
        logger.info(f"设置速率限制: {key} = {max_requests}/{window_seconds}s")

    def remove_limit(self, key: str):
        """Remove the limit for ``key`` and discard its recorded timestamps.

        Does nothing when ``key`` has no configured limit.
        """
        if key not in self._limits:
            return
        del self._limits[key]
        self._counters.pop(key, None)
        logger.info(f"移除速率限制: {key}")

    def _prune(self, key: str, window_seconds: float, now: float) -> list:
        """Drop timestamps for ``key`` that fell out of the window; return the kept list.

        The returned list is the same object stored in ``self._counters[key]``,
        so appending to it records a request.
        """
        fresh = [ts for ts in self._counters[key] if now - ts < window_seconds]
        self._counters[key] = fresh
        return fresh

    @staticmethod
    def _reset_at(timestamps: list, window_seconds: float, now: float) -> float:
        """Return when the oldest in-window request expires, or ``now`` if there is none."""
        # min() rather than timestamps[0]: time.time() is not guaranteed to be
        # monotonic, so insertion order need not match timestamp order.
        return min(timestamps) + window_seconds if timestamps else now

    def is_allowed(self, key: str) -> tuple[bool, Optional[dict]]:
        """Check whether a request for ``key`` is allowed, recording it if so.

        Returns:
            ``(allowed, info)``. ``info`` is ``None`` when ``key`` has no limit;
            otherwise it is a dict with ``limit``, ``window``, ``remaining``
            (never negative) and ``reset_at`` (a ``time.time()``-based
            timestamp at which the oldest counted request leaves the window).
        """
        if key not in self._limits:
            return True, None

        limit = self._limits[key]
        now = time.time()
        timestamps = self._prune(key, limit.window_seconds, now)

        allowed = len(timestamps) < limit.max_requests
        if allowed:
            # Only accepted requests are counted; rejected ones do not extend the block.
            timestamps.append(now)

        return allowed, {
            'limit': limit.max_requests,
            'window': limit.window_seconds,
            'remaining': max(0, limit.max_requests - len(timestamps)),
            'reset_at': self._reset_at(timestamps, limit.window_seconds, now),
        }

    def get_stats(self, key: str) -> Optional[dict]:
        """Return usage statistics for ``key`` without recording a request.

        Returns ``None`` when ``key`` has no limit; otherwise a dict with
        ``limit``, ``window``, ``used`` and ``remaining``. ``remaining`` is
        clamped at 0, because a lowered limit can leave more requests inside the
        window than the new maximum.
        """
        if key not in self._limits:
            return None

        limit = self._limits[key]
        used = len(self._prune(key, limit.window_seconds, time.time()))

        return {
            'limit': limit.max_requests,
            'window': limit.window_seconds,
            'used': used,
            'remaining': max(0, limit.max_requests - used),
        }

    def reset(self, key: str):
        """Clear all recorded requests for ``key`` while keeping its limit."""
        if key in self._counters:
            self._counters[key].clear()
            logger.info(f"重置速率限制计数器: {key}")

