"""
节点管理器 - 集群节点发现和管理
"""
import asyncio
import json
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Set

import aiohttp
from loguru import logger


@dataclass
class NodeInfo:
    """Record describing one cluster node as tracked by ``NodeManager``."""
    node_id: str
    address: str
    port: int
    status: str  # "online", "offline" or "unknown"
    last_seen: float  # time.time() value of the most recent contact
    role: str = "worker"  # "master" or "worker"
    load: float = 0.0  # load, in the range 0-1
    connections: int = 0
    # Free-form per-node extras. default_factory gives every instance its own
    # dict (instead of None), matching the Dict annotation.
    metadata: Dict = field(default_factory=dict)


class NodeManager:
    """Tracks cluster membership and exchanges heartbeats with peer nodes.

    Peers are kept in an in-memory dict keyed by ``node_id``. While cluster
    mode is enabled, ``start_heartbeat_loop`` periodically POSTs this node's
    status to every known peer and marks peers that have been silent for
    longer than ``_node_timeout`` seconds as offline.
    """

    def __init__(self, node_id: str, listen_address: str = "0.0.0.0", listen_port: int = 8080):
        self.node_id = node_id
        # NOTE(review): listen_address is also what we advertise to peers in
        # heartbeats; a wildcard bind address such as 0.0.0.0 is not reachable
        # by them — confirm deployments set a routable address.
        self.listen_address = listen_address
        self.listen_port = listen_port
        self._nodes: Dict[str, NodeInfo] = {}
        self._master_node: Optional[NodeInfo] = None
        self._is_master = False
        self._cluster_enabled = False
        self._heartbeat_interval = 30  # seconds between heartbeat rounds
        self._node_timeout = 90  # seconds of silence before a peer is considered offline
        self._heartbeat_request_timeout = 5  # total seconds allowed per heartbeat HTTP request

    def enable_cluster(self, enabled: bool = True):
        """Enable or disable cluster mode (heartbeats are only sent while enabled)."""
        self._cluster_enabled = enabled
        logger.info(f"集群模式{'启用' if enabled else '禁用'}")

    def add_node(self, node_id: str, address: str, port: int, role: str = "worker"):
        """Register a peer node, replacing any existing entry with the same id.

        The node starts in ``"unknown"`` status with ``last_seen`` set to now.
        A ``"master"`` role makes it the current master. If the replaced entry was
        the master and the new role is not, the master reference is cleared so
        it never points at an object that is no longer in the registry.
        """
        node = NodeInfo(
            node_id=node_id,
            address=address,
            port=port,
            status="unknown",
            last_seen=time.time(),
            role=role,
            metadata={},
        )
        self._nodes[node_id] = node

        if role == "master":
            self._master_node = node
        elif self._master_node is not None and self._master_node.node_id == node_id:
            # The old master entry was just overwritten by a non-master one.
            self._master_node = None

        logger.info(f"添加节点: {node_id} ({address}:{port})")

    def remove_node(self, node_id: str):
        """Remove a peer node; clears the master reference if it was the master."""
        if self._nodes.pop(node_id, None) is None:
            return
        if self._master_node and self._master_node.node_id == node_id:
            self._master_node = None
        logger.info(f"移除节点: {node_id}")

    def get_node(self, node_id: str) -> Optional[NodeInfo]:
        """Return the node registered under ``node_id``, or None."""
        return self._nodes.get(node_id)

    def get_all_nodes(self) -> List[NodeInfo]:
        """Return every registered node regardless of status."""
        return list(self._nodes.values())

    def get_online_nodes(self) -> List[NodeInfo]:
        """Return nodes marked online whose last contact is within the timeout."""
        current_time = time.time()
        return [
            node for node in self._nodes.values()
            if node.status == "online" and (current_time - node.last_seen) < self._node_timeout
        ]

    def get_master_node(self) -> Optional[NodeInfo]:
        """Return the current master node, or None if none is known."""
        return self._master_node

    def _build_heartbeat_payload(self, load: float, connections: int) -> Dict:
        """Build the JSON body this node sends with every heartbeat."""
        return {
            'node_id': self.node_id,
            'address': self.listen_address,
            'port': self.listen_port,
            'status': 'online',
            'timestamp': time.time(),
            'load': load,
            'connections': connections,
            # Advertise our role so receivers don't default us to "worker".
            'role': 'master' if self._is_master else 'worker',
        }

    async def send_heartbeat(self, node_id: str, load: float = 0.0, connections: int = 0):
        """Broadcast this node's heartbeat to every other known node.

        Does nothing while cluster mode is disabled or when there are no peers.

        Args:
            node_id: Unused; kept for backward compatibility. The payload always
                identifies this manager's own ``self.node_id``.
            load: Load value to advertise.
            connections: Connection count to advertise.
        """
        if not self._cluster_enabled:
            return

        peers = [node for node in self._nodes.values() if node.node_id != self.node_id]
        if not peers:
            return

        payload = self._build_heartbeat_payload(load, connections)
        # One session (and connection pool) for the whole round instead of a
        # fresh session per peer; the timeout applies to each request.
        timeout = aiohttp.ClientTimeout(total=self._heartbeat_request_timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            await asyncio.gather(
                *(self._send_heartbeat_to_node(node, payload, session) for node in peers),
                return_exceptions=True,
            )

    async def _send_heartbeat_to_node(
        self,
        node: NodeInfo,
        node_info: Dict,
        session: Optional[aiohttp.ClientSession] = None,
    ):
        """POST one heartbeat to ``node`` and update its status accordingly.

        On HTTP 200 the node is marked ``"online"`` and ``last_seen`` is
        refreshed. Any exception (connection error, timeout, ...) marks it
        ``"offline"``. Other status codes leave the node unchanged.

        Args:
            node: Target peer.
            node_info: JSON payload to send.
            session: Session to reuse. When omitted, a short-lived session is
                created just for this request.
        """
        url = f"http://{node.address}:{node.port}/api/cluster/heartbeat"
        try:
            if session is None:
                timeout = aiohttp.ClientTimeout(total=self._heartbeat_request_timeout)
                async with aiohttp.ClientSession(timeout=timeout) as own_session:
                    accepted = await self._post_heartbeat(own_session, url, node_info)
            else:
                accepted = await self._post_heartbeat(session, url, node_info)
        except Exception as e:
            logger.debug(f"发送心跳到节点 {node.node_id} 失败: {e}")
            node.status = "offline"
            return

        if accepted:
            node.status = "online"
            node.last_seen = time.time()

    @staticmethod
    async def _post_heartbeat(session, url: str, node_info: Dict) -> bool:
        """POST ``node_info`` as JSON to ``url``; return True iff the reply is 200."""
        async with session.post(url, json=node_info) as response:
            return response.status == 200

    async def receive_heartbeat(self, node_info: Dict):
        """Record a heartbeat received from another node.

        The following heartbeats are ignored:
        - ones without a ``node_id``;
        - ones from this node itself;
        - ones from an unknown sender that does not supply both ``address``
          and ``port``.

        Other unknown senders are registered via ``add_node``.
        """
        node_id = node_info.get('node_id')
        if not node_id or node_id == self.node_id:
            return

        node = self._nodes.get(node_id)
        if node is None:
            address = node_info.get('address')
            port = node_info.get('port')
            if not address or port is None:
                # Registering without an endpoint would only produce requests
                # to "http://None:None/...".
                logger.warning(f"忽略来自节点 {node_id} 的心跳: 缺少地址或端口")
                return
            self.add_node(
                node_id=node_id,
                address=address,
                port=port,
                role=node_info.get('role', 'worker'),
            )
            node = self._nodes[node_id]

        node.status = "online"
        # Use the local clock: last_seen is compared with local time.time() in
        # get_online_nodes/_check_node_timeout, so trusting the sender's
        # 'timestamp' would let clock skew mark a live peer as timed out.
        node.last_seen = time.time()
        node.load = node_info.get('load', 0.0)
        node.connections = node_info.get('connections', 0)

    async def start_heartbeat_loop(self):
        """Run heartbeat rounds until cluster mode is disabled.

        Each round broadcasts this node's heartbeat, marks silent peers offline,
        then sleeps ``_heartbeat_interval`` seconds. Errors raised during a
        round are logged and the loop keeps going.
        """
        while self._cluster_enabled:
            try:
                # TODO: take the load from the system monitor and the connection
                # count from the connection manager.
                load = 0.0
                connections = 0

                await self.send_heartbeat(self.node_id, load, connections)
                await self._check_node_timeout()
            except Exception as e:
                logger.error(f"心跳循环错误: {e}")
            await asyncio.sleep(self._heartbeat_interval)

    async def _check_node_timeout(self):
        """Mark every peer silent for longer than ``_node_timeout`` as offline."""
        current_time = time.time()
        for node_id, node in self._nodes.items():
            if node_id == self.node_id:
                continue
            if (current_time - node.last_seen) > self._node_timeout:
                node.status = "offline"
                logger.warning(f"节点 {node_id} 超时")

    def select_best_node(self, exclude_node_id: Optional[str] = None) -> Optional[NodeInfo]:
        """Return the online node with the lowest load, or None if there is none.

        Args:
            exclude_node_id: Node to leave out of the selection (e.g. the caller).
        """
        online_nodes = [
            node for node in self.get_online_nodes()
            if node.node_id != exclude_node_id
        ]
        if not online_nodes:
            return None
        return min(online_nodes, key=lambda n: n.load)

    def get_cluster_stats(self) -> Dict:
        """Summarize the cluster: node counts, totals, and a per-online-node list."""
        online_nodes = self.get_online_nodes()
        total_load = sum(node.load for node in online_nodes)
        total_connections = sum(node.connections for node in online_nodes)

        return {
            'node_id': self.node_id,
            'total_nodes': len(self._nodes),
            'online_nodes': len(online_nodes),
            'master_node': self._master_node.node_id if self._master_node else None,
            'is_master': self._is_master,
            'total_load': total_load,
            'total_connections': total_connections,
            'nodes': [
                {
                    'node_id': node.node_id,
                    'address': f"{node.address}:{node.port}",
                    'status': node.status,
                    'load': node.load,
                    'connections': node.connections,
                }
                for node in online_nodes
            ],
        }

