"""
熔断器单元测试
"""
import pytest
import time
from server.core.security.circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerConfig,
    CircuitState,
    CircuitBreakerManager
)


class TestCircuitBreaker:
    """Unit tests for ``CircuitBreaker`` state transitions and counters."""

    @staticmethod
    def _fail_calls(breaker, func, times, expected=Exception):
        """Route ``func`` through ``breaker`` ``times`` times, requiring a raise.

        Using ``pytest.raises`` instead of ``try/except: pass`` makes the test
        fail loudly if a call unexpectedly succeeds rather than silently
        skipping over it. Not collected by pytest (no ``test_`` prefix).

        Args:
            breaker: The circuit breaker under test.
            func: Callable expected to fail when invoked via ``breaker.call``.
            times: Number of invocations to perform.
            expected: Exception type (or tuple of types) each call must raise.
        """
        for _ in range(times):
            with pytest.raises(expected):
                breaker.call(func)

    def test_initial_state(self):
        """A new breaker starts CLOSED with no recorded failures."""
        breaker = CircuitBreaker("test")
        assert breaker.state == CircuitState.CLOSED
        assert breaker.failure_count == 0

    def test_success_call(self):
        """A successful call returns the function's result and stays CLOSED."""
        breaker = CircuitBreaker("test")

        def success_func():
            return "success"

        result = breaker.call(success_func)
        assert result == "success"
        assert breaker.state == CircuitState.CLOSED

    def test_failure_threshold(self):
        """Reaching the failure threshold opens the circuit."""
        config = CircuitBreakerConfig(failure_threshold=3)
        breaker = CircuitBreaker("test", config)

        def fail_func():
            raise ValueError("test error")

        # Each failing call must surface the original ValueError to the caller.
        self._fail_calls(breaker, fail_func, 3, expected=ValueError)

        assert breaker.state == CircuitState.OPEN
        assert breaker.failure_count == 3

    def test_circuit_open_rejects_calls(self):
        """An OPEN circuit rejects new calls without running the function."""
        config = CircuitBreakerConfig(failure_threshold=2, timeout=10.0)
        breaker = CircuitBreaker("test", config)

        def fail_func():
            raise Exception("error")

        # Trip the breaker.
        self._fail_calls(breaker, fail_func, 2)
        assert breaker.state == CircuitState.OPEN

        executed = []

        def guarded_func():
            executed.append(True)
            return "should not execute"

        # The breaker must refuse the call with its rejection message ...
        with pytest.raises(Exception, match="拒绝执行"):
            breaker.call(guarded_func)

        # ... and must not have invoked the protected function at all.
        assert executed == []

    def test_half_open_recovery(self):
        """After the timeout, enough successes close the circuit again."""
        config = CircuitBreakerConfig(
            failure_threshold=2,
            success_threshold=2,
            timeout=0.1,
        )
        breaker = CircuitBreaker("test", config)

        def fail_func():
            raise Exception("error")

        def success_func():
            return "success"

        # Trip the breaker.
        self._fail_calls(breaker, fail_func, 2)
        assert breaker.state == CircuitState.OPEN

        # Sleep longer than the configured 0.1 timeout so that the next
        # call is let through instead of being rejected.
        time.sleep(0.2)

        # First success: one short of success_threshold, so still HALF_OPEN.
        result1 = breaker.call(success_func)
        assert result1 == "success"
        assert breaker.state == CircuitState.HALF_OPEN

        # Second success reaches success_threshold and closes the circuit.
        result2 = breaker.call(success_func)
        assert result2 == "success"
        assert breaker.state == CircuitState.CLOSED

    def test_reset(self):
        """``reset()`` returns the breaker to CLOSED and clears the failure count."""
        breaker = CircuitBreaker("test")

        def fail_func():
            raise Exception("error")

        # Record failures. Depending on the default threshold, later calls may
        # be rejected by an already-open circuit; either way each call raises.
        self._fail_calls(breaker, fail_func, 5)

        # Guard against a vacuous pass: there must be something to reset.
        assert breaker.failure_count > 0

        breaker.reset()

        assert breaker.state == CircuitState.CLOSED
        assert breaker.failure_count == 0


class TestCircuitBreakerManager:
    """Unit tests for ``CircuitBreakerManager`` lookup behaviour."""

    def test_get_breaker(self):
        """Asking twice for the same name yields the very same instance."""
        registry = CircuitBreakerManager()

        first = registry.get_breaker("test1")
        second = registry.get_breaker("test1")

        # Identity, not mere equality: the manager must hand back one object.
        assert first is second

    def test_multiple_breakers(self):
        """Different names map to distinct breakers carrying their own name."""
        registry = CircuitBreakerManager()

        first, second = (
            registry.get_breaker(service) for service in ("service1", "service2")
        )

        assert first.name == "service1"
        assert second.name == "service2"
        assert first is not second

