"""
故障转移 - 自动故障检测和切换
"""
import asyncio
import inspect
import time
from typing import Dict, List, Optional, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum

from loguru import logger
import aiohttp


class FailoverStrategy(str, Enum):
    """Strategy FailoverManager uses when choosing a replacement node.

    Because the enum mixes in ``str``, every member compares equal to its raw
    value string, e.g. ``FailoverStrategy.ROUND_ROBIN == "round_robin"``.
    """

    ACTIVE_PASSIVE = "active_passive"        # active/passive (primary + standby)
    ACTIVE_ACTIVE = "active_active"          # active/active (dual-active)
    ROUND_ROBIN = "round_robin"              # round robin
    LEAST_CONNECTIONS = "least_connections"  # least connections


@dataclass
class NodeHealth:
    """Mutable health record kept for one monitored node.

    Constructor order: node_id, status, last_check, then the optional
    counters, which default to zero.
    """

    # Identifier of the node this record describes.
    node_id: str
    # "healthy", "degraded" or "unhealthy"; FailoverManager.add_node starts
    # every node at "unknown" before its first probe.
    status: str
    # Timestamp of the most recent probe, successful or not.
    last_check: datetime
    # Probes failed in a row; FailoverManager resets it to 0 on success.
    consecutive_failures: int = 0
    # Elapsed seconds of the last successful probe.
    response_time: float = 0.0
    # NOTE(review): not written by FailoverManager's health checks — confirm
    # whether another component maintains it.
    error_rate: float = 0.0


class FailoverManager:
    """Failover manager: health-checks registered nodes and switches the
    active node when it becomes unhealthy.

    Each node is probed over HTTP at ``http://<address>:<port>/api/health``;
    an HTTP 200 answer marks it "healthy". After ``failure_threshold``
    consecutive failed probes a node is marked "unhealthy", and if it is the
    current active node a replacement is chosen according to ``strategy``
    and every registered failover callback is invoked.

    The active node is only set by ``add_node(..., is_primary=True)`` or by a
    successful failover.
    """

    def __init__(self, strategy: FailoverStrategy = FailoverStrategy.ACTIVE_PASSIVE):
        self.strategy = strategy
        # node_id -> {'address', 'port', 'priority', 'is_primary', 'added_at'}
        self.nodes: Dict[str, Dict] = {}
        self.node_health: Dict[str, NodeHealth] = {}
        self.active_node: Optional[str] = None
        # Kept for backward compatibility; node selection does not use it.
        self.backup_nodes: List[str] = []
        self._running = False
        self._monitor_task: Optional[asyncio.Task] = None
        self.failover_callbacks: List[Callable] = []
        self.check_interval = 10  # seconds between monitoring rounds
        self.failure_threshold = 3  # consecutive failures before "unhealthy"
        self.health_check_timeout = 5  # total timeout of one probe (seconds)
        # Rotating cursor for the ROUND_ROBIN strategy.
        self._rr_index = 0

    def add_node(self, node_id: str, address: str, port: int,
                 priority: int = 0, is_primary: bool = False):
        """Register (or re-register) a node and reset its health record.

        Args:
            node_id: Unique node identifier.
            address: Host used to build the health-check URL.
            port: Port used to build the health-check URL.
            priority: Higher values are preferred by priority-based selection.
            is_primary: If True, the node immediately becomes the active node.
        """
        self.nodes[node_id] = {
            'address': address,
            'port': port,
            'priority': priority,
            'is_primary': is_primary,
            'added_at': datetime.now()
        }

        self.node_health[node_id] = NodeHealth(
            node_id=node_id,
            status="unknown",
            last_check=datetime.now()
        )

        if is_primary:
            self.active_node = node_id

        logger.info(f"添加节点: {node_id} ({address}:{port})")

    async def check_node_health(self, node_id: str) -> bool:
        """Probe one node's health endpoint and update its health record.

        Returns:
            True if the endpoint answered HTTP 200. False for an unknown
            node id, a non-200 answer, a timeout or any request error. Every
            failure is passed to ``_record_failure``, which may trigger a
            failover.
        """
        node = self.nodes.get(node_id)
        if not node:
            return False

        url = f"http://{node['address']}:{node['port']}/api/health"
        timeout = aiohttp.ClientTimeout(total=self.health_check_timeout)
        # Monotonic clock: elapsed time is immune to wall-clock adjustments.
        start_time = time.monotonic()

        # Keep the try narrow: only the request itself counts as a failed
        # probe. An error raised while recording the result must not be
        # counted a second time as another failure.
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url) as resp:
                    status_code = resp.status
                    response_time = time.monotonic() - start_time
        except Exception as e:
            logger.debug(f"节点健康检查失败: {node_id} - {e}")
            await self._record_failure(node_id)
            return False

        if status_code != 200:
            logger.debug(f"节点健康检查返回非200状态: {node_id} - {status_code}")
            await self._record_failure(node_id)
            return False

        health = self.node_health[node_id]
        if health.status == "unhealthy":
            logger.info(f"节点恢复健康: {node_id}")
        health.status = "healthy"
        health.last_check = datetime.now()
        health.response_time = response_time
        health.consecutive_failures = 0
        return True

    async def _record_failure(self, node_id: str):
        """Count one failed probe for *node_id* and update its status.

        Below ``failure_threshold`` the node becomes "degraded". At or above
        it the node becomes "unhealthy"; if it is the active node a failover
        is attempted. While the node stays down the counter keeps growing, so
        a failover that found no backup is retried on every later failure.
        """
        health = self.node_health[node_id]
        health.consecutive_failures += 1
        health.last_check = datetime.now()

        if health.consecutive_failures >= self.failure_threshold:
            health.status = "unhealthy"
            logger.warning(f"节点标记为不健康: {node_id} (连续失败: {health.consecutive_failures})")

            # An unhealthy active node triggers a failover.
            if node_id == self.active_node:
                await self.trigger_failover(node_id)
        else:
            health.status = "degraded"

    async def trigger_failover(self, failed_node_id: str):
        """Switch the active node away from *failed_node_id*.

        Returns:
            True if a backup node was found and made active, else False
            (the active node is then left unchanged).

        Callbacks are called as ``callback(old_active, new_active)``. Sync
        callables, async functions and any callable returning an awaitable
        are supported. A failing callback is logged and does not stop the
        others.
        """
        logger.warning(f"触发故障转移: {failed_node_id}")

        backup_node = self._select_backup_node(failed_node_id)
        if not backup_node:
            logger.error("未找到可用的备用节点")
            return False

        old_active = self.active_node
        self.active_node = backup_node
        logger.info(f"故障转移完成: {old_active} -> {backup_node}")

        # Snapshot: a callback may register further callbacks.
        for callback in list(self.failover_callbacks):
            try:
                # Inspect the result rather than the callable, so async
                # callable objects and lambdas returning coroutines are awaited.
                result = callback(old_active, backup_node)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception(f"故障转移回调执行失败: {e}")

        return True

    def _select_backup_node(self, failed_node_id: str) -> Optional[str]:
        """Choose the node that should replace *failed_node_id*.

        Candidates are all other nodes whose status is "healthy", in
        registration order. Returns None when there is no candidate.

        - ACTIVE_PASSIVE / ACTIVE_ACTIVE: highest ``priority`` wins; ties go
          to the earliest-registered node.
        - ROUND_ROBIN: rotates through the candidates on successive calls.
        - LEAST_CONNECTIONS: simplified, because connection counts are not
          tracked yet; returns the first candidate.
        """
        candidates = [
            node_id
            for node_id in self.nodes
            if node_id != failed_node_id
            and self.node_health[node_id].status == "healthy"
        ]
        if not candidates:
            return None

        if self.strategy == FailoverStrategy.ROUND_ROBIN:
            choice = candidates[self._rr_index % len(candidates)]
            self._rr_index += 1
            return choice

        if self.strategy == FailoverStrategy.LEAST_CONNECTIONS:
            # TODO: obtain real connection counts from the nodes.
            return candidates[0]

        # Priority-based selection. max() returns the first maximal element,
        # which matches a stable descending sort by priority.
        return max(candidates, key=lambda nid: self.nodes[nid]['priority'])

    async def monitor_loop(self):
        """Run one health-check round every ``check_interval`` seconds until stopped."""
        while self._running:
            try:
                # Snapshot the ids: add_node() may run while probes are
                # awaiting, and mutating self.nodes during iteration would
                # raise RuntimeError.
                node_ids = list(self.nodes)
                # Probe concurrently so one unreachable node (up to
                # health_check_timeout) does not delay checks on the others.
                results = await asyncio.gather(
                    *(self.check_node_health(node_id) for node_id in node_ids),
                    return_exceptions=True,
                )
                for node_id, result in zip(node_ids, results):
                    if isinstance(result, Exception):
                        logger.error(f"监控循环错误: {node_id} - {result}")
            except Exception as e:
                logger.error(f"监控循环错误: {e}")
            await asyncio.sleep(self.check_interval)

    async def start_monitoring(self):
        """Start the background monitoring task; no-op if it is already running."""
        if self._monitor_task is not None and not self._monitor_task.done():
            return
        self._running = True
        self._monitor_task = asyncio.create_task(self.monitor_loop())
        logger.info("故障转移监控已启动")

    async def stop_monitoring(self):
        """Cancel the monitoring task and wait for it to exit; safe to call repeatedly."""
        self._running = False

        task, self._monitor_task = self._monitor_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def add_failover_callback(self, callback: Callable):
        """Register ``callback(old_active, new_active)`` to run after each failover."""
        self.failover_callbacks.append(callback)

    def get_active_node(self) -> Optional[str]:
        """Return the id of the current active node, or None if none is set."""
        return self.active_node

    def get_node_status(self) -> Dict:
        """Return a JSON-friendly snapshot of the strategy, active node and per-node health."""
        return {
            'active_node': self.active_node,
            'strategy': self.strategy.value,
            'nodes': {
                node_id: {
                    'status': health.status,
                    'last_check': health.last_check.isoformat(),
                    'consecutive_failures': health.consecutive_failures,
                    'response_time': health.response_time,
                }
                for node_id, health in self.node_health.items()
            }
        }

