"""
限流器单元测试
"""
import pytest
import time
from server.core.security.api_rate_limiter import (
    APIRateLimiter,
    RateLimitConfig,
    SlidingWindow,
    TokenBucket
)


class TestSlidingWindow:
    """滑动窗口测试"""
    
    def test_basic_limit(self):
        """测试基本限流"""
        window = SlidingWindow(max_requests=10, window_seconds=60)
        
        # 前10次应该允许
        for i in range(10):
            assert window.is_allowed() is True
        
        # 第11次应该被拒绝
        assert window.is_allowed() is False
    
    def test_window_reset(self):
        """测试窗口重置"""
        window = SlidingWindow(max_requests=5, window_seconds=1)
        
        # 发送5个请求
        for i in range(5):
            assert window.is_allowed() is True
        
        # 应该被拒绝
        assert window.is_allowed() is False
        
        # 等待窗口重置
        time.sleep(1.1)
        
        # 应该可以继续请求
        assert window.is_allowed() is True
    
    def test_get_remaining(self):
        """测试获取剩余请求数"""
        window = SlidingWindow(max_requests=10, window_seconds=60)
        
        # 发送3个请求
        for i in range(3):
            window.is_allowed()
        
        assert window.get_remaining() == 7


class TestTokenBucket:
    """令牌桶测试"""
    
    def test_basic_consume(self):
        """测试基本消费"""
        bucket = TokenBucket(capacity=10, refill_rate=1.0)
        
        # 应该可以消费10个令牌
        for i in range(10):
            assert bucket.consume() is True
        
        # 第11次应该被拒绝
        assert bucket.consume() is False
    
    def test_refill(self):
        """测试令牌补充"""
        bucket = TokenBucket(capacity=10, refill_rate=10.0)  # 每秒10个
        
        # 消费所有令牌
        for i in range(10):
            bucket.consume()
        
        assert bucket.consume() is False
        
        # 等待补充
        time.sleep(0.2)
        
        # 应该可以消费
        assert bucket.consume() is True


class TestAPIRateLimiter:
    """API限流器测试"""
    
    def test_add_limit(self):
        """测试添加限流规则"""
        limiter = APIRateLimiter()
        
        limiter.add_limit("api/test", max_requests=5, window_seconds=60)
        
        # 应该可以请求5次
        for i in range(5):
            assert limiter.check_limit("api/test") is True
        
        # 第6次应该被拒绝
        assert limiter.check_limit("api/test") is False
    
    def test_default_limit(self):
        """测试默认限流"""
        limiter = APIRateLimiter()
        
        # 使用默认配置（100请求/60秒）
        for i in range(100):
            assert limiter.check_limit("api/new") is True
        
        assert limiter.check_limit("api/new") is False
    
    def test_get_remaining(self):
        """测试获取剩余请求数"""
        limiter = APIRateLimiter()
        limiter.add_limit("api/test", max_requests=10, window_seconds=60)
        
        # 发送3个请求
        for i in range(3):
            limiter.check_limit("api/test")
        
        assert limiter.get_remaining("api/test") == 7
    
    def test_reset_limit(self):
        """测试重置限流"""
        limiter = APIRateLimiter()
        limiter.add_limit("api/test", max_requests=5, window_seconds=60)
        
        # 触发限流
        for i in range(5):
            limiter.check_limit("api/test")
        
        assert limiter.check_limit("api/test") is False
        
        # 重置
        limiter.reset_limit("api/test")
        
        # 应该可以继续请求
        assert limiter.check_limit("api/test") is True

