"""
熔断器实现
"""
import time
from enum import Enum
from typing import Callable, Optional, Dict
from dataclasses import dataclass
from loguru import logger


class CircuitState(str, Enum):
    """Possible states of a circuit breaker.

    Mixes in ``str`` so each member compares equal to its plain string value.
    """

    # Closed: normal operation, calls are allowed through.
    CLOSED = "closed"
    # Open: the breaker has tripped and calls are rejected.
    OPEN = "open"
    # Half-open: trial calls are allowed to probe whether to recover.
    HALF_OPEN = "half_open"


@dataclass
class CircuitBreakerConfig:
    """Tuning parameters for a circuit breaker.

    Field order defines the positional constructor order; keep it stable.
    """

    # Number of consecutive failures (while closed) that trips the breaker.
    failure_threshold: int = 5
    # Number of successful trial calls (while half-open) needed to close it.
    success_threshold: int = 2
    # Seconds the breaker stays open before a trial call is permitted.
    timeout: float = 60.0
    # Exception type counted as a failure; other exceptions are not counted.
    expected_exception: type = Exception


class CircuitBreakerOpenError(Exception):
    """Raised when a circuit breaker rejects a call without executing it.

    Subclasses ``Exception`` so existing ``except Exception`` handlers keep
    working, while letting callers distinguish a rejection from an error
    raised by the protected function itself.
    """


class CircuitBreaker:
    """Circuit breaker guarding calls to an unreliable function.

    State machine:

    * CLOSED -> OPEN after ``config.failure_threshold`` consecutive failures
      (a success while CLOSED resets the failure counter).
    * OPEN rejects every call until ``config.timeout`` seconds have passed
      since it tripped; the next call then moves it to HALF_OPEN.
    * HALF_OPEN -> CLOSED after ``config.success_threshold`` successes;
      any failure while HALF_OPEN trips it straight back to OPEN.

    Only exceptions matching ``config.expected_exception`` count as failures;
    other exceptions propagate without touching the counters.

    NOTE(review): no locking is performed. While HALF_OPEN every caller is
    let through, including threads or coroutines interleaving at ``await``.
    NOTE(review): timeouts use wall-clock ``time.time()``, so system clock
    adjustments shift the OPEN window; kept because ``get_state()`` exposes
    these values as timestamps.
    """

    def __init__(self, name: str, config: Optional[CircuitBreakerConfig] = None):
        """Create a breaker in the CLOSED state.

        Args:
            name: Identifier used in log and rejection messages.
            config: Tuning parameters; defaults to ``CircuitBreakerConfig()``.
        """
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        # Wall-clock timestamps (seconds since the epoch), None until set.
        self.last_failure_time: Optional[float] = None
        self.next_attempt_time: Optional[float] = None

    def call(self, func: Callable, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` under breaker protection.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: If the breaker is rejecting calls; ``func``
                is not invoked in that case.
            Any exception raised by ``func`` is re-raised unchanged, after being
            recorded as a failure if it matches ``config.expected_exception``.
        """
        self._ensure_can_execute()
        try:
            result = func(*args, **kwargs)
        except self.config.expected_exception:
            self._on_failure()
            raise  # bare raise re-raises with the original traceback
        else:
            # Outside the try so bookkeeping errors are never counted as
            # failures of ``func``.
            self._on_success()
            return result

    async def call_async(self, func: Callable, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` under breaker protection.

        ``func`` must return an awaitable. Returns the awaited result.

        Raises:
            CircuitBreakerOpenError: If the breaker is rejecting calls; ``func``
                is not invoked in that case.
            Any exception raised while awaiting is re-raised unchanged, after
            being recorded as a failure if it matches
            ``config.expected_exception``.
        """
        self._ensure_can_execute()
        try:
            result = await func(*args, **kwargs)
        except self.config.expected_exception:
            self._on_failure()
            raise
        else:
            self._on_success()
            return result

    def _ensure_can_execute(self) -> None:
        """Raise CircuitBreakerOpenError if the breaker currently rejects calls."""
        if not self._can_execute():
            raise CircuitBreakerOpenError(
                f"熔断器 {self.name} 处于 {self.state.value} 状态，拒绝执行"
            )

    def _can_execute(self) -> bool:
        """Return True if a call may proceed now.

        Side effect: an OPEN breaker whose ``next_attempt_time`` has been
        reached is moved to HALF_OPEN (success counter cleared) first.
        """
        if self.state == CircuitState.CLOSED:
            return True

        if self.state == CircuitState.OPEN:
            # Timeout elapsed: allow a trial call to probe for recovery.
            if (self.next_attempt_time is not None
                    and time.time() >= self.next_attempt_time):
                self.state = CircuitState.HALF_OPEN
                self.success_count = 0
                logger.info(f"熔断器 {self.name} 进入半开状态")
                return True
            return False

        return self.state == CircuitState.HALF_OPEN

    def _on_success(self):
        """Record a successful call, closing the breaker if enough trials passed."""
        if self.state == CircuitState.HALF_OPEN:
            self.success_count += 1
            if self.success_count >= self.config.success_threshold:
                self.state = CircuitState.CLOSED
                self.failure_count = 0
                self.success_count = 0
                logger.info(f"熔断器 {self.name} 恢复正常（关闭状态）")
        elif self.state == CircuitState.CLOSED:
            # Only consecutive failures should trip the breaker.
            self.failure_count = 0

    def _on_failure(self):
        """Record a failed call and trip the breaker to OPEN when warranted."""
        now = time.time()
        self.failure_count += 1
        self.last_failure_time = now

        if self.state == CircuitState.HALF_OPEN:
            # A single failed trial call re-opens the breaker immediately.
            self._trip(now)
            logger.warning(f"熔断器 {self.name} 在半开状态下失败，重新熔断")
        elif (self.state == CircuitState.CLOSED
              and self.failure_count >= self.config.failure_threshold):
            self._trip(now)
            logger.warning(f"熔断器 {self.name} 达到失败阈值，进入熔断状态")

    def _trip(self, now: float) -> None:
        """Enter OPEN and allow the next trial ``config.timeout`` seconds after ``now``."""
        self.state = CircuitState.OPEN
        self.next_attempt_time = now + self.config.timeout

    def reset(self):
        """Force the breaker back to CLOSED and clear all counters and timestamps."""
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        self.last_failure_time = None
        self.next_attempt_time = None
        logger.info(f"熔断器 {self.name} 已重置")

    def get_state(self) -> Dict:
        """Return a snapshot of the breaker's state as a plain dict."""
        return {
            'name': self.name,
            'state': self.state.value,
            'failure_count': self.failure_count,
            'success_count': self.success_count,
            'last_failure_time': self.last_failure_time,
            'next_attempt_time': self.next_attempt_time,
        }


class CircuitBreakerManager:
    """Registry of named circuit breakers, created lazily on first request."""

    def __init__(self):
        self._breakers: Dict[str, CircuitBreaker] = {}

    def get_breaker(self, name: str, config: Optional[CircuitBreakerConfig] = None) -> CircuitBreaker:
        """Return the breaker registered under ``name``, creating it if needed.

        ``config`` is only applied when the breaker is first created. If a
        breaker already exists and a different ``config`` is passed, the
        existing breaker is returned unchanged and a warning is logged,
        rather than silently discarding the requested settings.
        """
        breaker = self._breakers.get(name)
        if breaker is None:
            breaker = CircuitBreaker(name, config)
            self._breakers[name] = breaker
        elif config is not None and config != breaker.config:
            logger.warning(f"熔断器 {name} 已存在，忽略传入的新配置")
        return breaker

    def reset_breaker(self, name: str):
        """Reset the breaker registered under ``name``; unknown names are ignored."""
        breaker = self._breakers.get(name)
        if breaker is not None:
            breaker.reset()

    def get_all_breakers(self) -> Dict[str, Dict]:
        """Return a state snapshot for every registered breaker, keyed by name."""
        return {
            name: breaker.get_state()
            for name, breaker in self._breakers.items()
        }

