"""
连接管理器
"""
import asyncio
import time
from typing import Dict, Optional, Set
from datetime import datetime
from loguru import logger

from common.models.client import Client, ClientStatus
from common.models.session import Session, SessionStatus
from common.protocol.message import Message, MessageType, AuthRequest, AuthResponse


class ConnectionManager:
    """Connection manager - tracks every registered client connection.

    Keeps two maps keyed by ``client_id``: the active ``ClientConnection``
    objects and the ``Session`` record created when each one was added.
    """

    def __init__(self, max_connections: int = 10000):
        """Create an empty manager.

        Args:
            max_connections: maximum number of simultaneously registered
                connections (defaults to 10000).
        """
        self._connections: Dict[str, 'ClientConnection'] = {}
        self._client_sessions: Dict[str, Session] = {}
        self._max_connections: int = max_connections

    async def add_connection(self, connection: 'ClientConnection') -> bool:
        """Register *connection* under its ``client_id`` and create a session.

        If the same ``client_id`` is already registered, the old connection is
        closed and replaced. A replacement does not grow the pool, so the
        capacity limit is only enforced for ids that are not yet connected.

        Returns:
            True if the connection was registered, False if the pool is full.
        """
        client_id = connection.client_id
        is_reconnect = client_id in self._connections

        # A reconnect swaps one entry for another; refusing it at capacity
        # would leave the stale connection in place and lock the client out.
        if not is_reconnect and len(self._connections) >= self._max_connections:
            logger.warning(f"连接数已达上限: {self._max_connections}")
            return False

        if is_reconnect:
            logger.warning(f"客户端 {client_id} 已存在连接，关闭旧连接")
            await self.remove_connection(client_id)

        self._connections[client_id] = connection

        # One timestamp so created_at and updated_at start out identical.
        now = datetime.now()
        session = Session(
            id=f"session_{int(time.time())}_{client_id}",
            client_id=client_id,
            status=SessionStatus.CONNECTED,
            created_at=now,
            updated_at=now,
            remote_addr=connection.remote_addr,
            protocol=connection.protocol,
        )
        self._client_sessions[client_id] = session

        logger.info(f"客户端 {client_id} 连接成功，当前连接数: {len(self._connections)}")
        return True

    async def remove_connection(self, client_id: str):
        """Close and unregister *client_id*'s connection and drop its session.

        Both map entries are detached *before* awaiting ``close()``. This way a
        concurrent removal of the same id cannot raise ``KeyError``, and a
        connection re-added for that id while the close is pending is not
        deleted by mistake. The detached session is marked DISCONNECTED.
        Unknown ids are a no-op apart from the log line.
        """
        connection = self._connections.pop(client_id, None)
        session = self._client_sessions.pop(client_id, None)

        if connection is not None:
            await connection.close()

        if session is not None:
            session.status = SessionStatus.DISCONNECTED
            session.updated_at = datetime.now()

        logger.info(f"客户端 {client_id} 断开连接，当前连接数: {len(self._connections)}")

    def get_connection(self, client_id: str) -> Optional['ClientConnection']:
        """Return the connection registered for *client_id*, or None."""
        return self._connections.get(client_id)

    def is_connected(self, client_id: str) -> bool:
        """Return True if *client_id* currently has a registered connection."""
        return client_id in self._connections

    def get_all_connections(self) -> Dict[str, 'ClientConnection']:
        """Return a shallow copy of the ``client_id -> connection`` map."""
        return self._connections.copy()

    def get_connection_count(self) -> int:
        """Return the number of registered connections."""
        return len(self._connections)

    async def broadcast_message(self, message: Message, exclude_client_id: Optional[str] = None):
        """Send *message* to every connection except *exclude_client_id*.

        Sends run concurrently. Per-client failures are collected rather than
        raised (``return_exceptions=True``), so one broken client does not
        abort the broadcast; ``send_message`` logs its own errors.
        """
        tasks = [
            connection.send_message(message)
            for client_id, connection in self._connections.items()
            if not (exclude_client_id and client_id == exclude_client_id)
        ]

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def send_to_client(self, client_id: str, message: Message) -> bool:
        """Send *message* to one client.

        Returns:
            True if the client was found and the send was attempted, False if
            no connection is registered for *client_id*. Exceptions raised by
            ``send_message`` propagate to the caller.
        """
        connection = self.get_connection(client_id)
        if connection is None:
            return False
        await connection.send_message(message)
        return True


class ClientConnection:
    """One client's connection, wrapping a websocket-style transport object.

    Holds the peer address and protocol label, plus identity and
    authentication fields, which start empty here and are filled in by
    outside code. Also holds creation and heartbeat timestamps taken from
    ``time.time()``.
    """

    def __init__(self, websocket, remote_addr: str, protocol: str = "websocket"):
        """Wrap *websocket*; *protocol* is a free-form transport label."""
        self.websocket = websocket
        self.remote_addr = remote_addr
        self.protocol = protocol

        # Identity / auth state -- never assigned inside this class.
        self.client_id: Optional[str] = None
        self.client: Optional[Client] = None
        self.authenticated: bool = False

        # Epoch-second timestamps; is_alive() measures against last_heartbeat.
        self.last_heartbeat: float = time.time()
        self.created_at: float = time.time()

    async def send_message(self, message: Message):
        """Serialize *message* to JSON and send it over the transport.

        Serialization or transport errors are logged and then re-raised.
        """
        try:
            payload = message.to_json()
            await self.websocket.send(payload)
        except Exception as err:
            logger.error(f"发送消息失败: {err}")
            raise

    async def receive_message(self) -> Optional[Message]:
        """Await one frame and parse it into a ``Message``.

        Text frames are parsed directly; any other frame is decoded as UTF-8
        first. If receiving, decoding or parsing fails, the error is logged
        and None is returned, so callers cannot tell these cases apart.
        """
        try:
            frame = await self.websocket.recv()
            text = frame if isinstance(frame, str) else frame.decode('utf-8')
            return Message.from_json(text)
        except Exception as err:
            logger.error(f"接收消息失败: {err}")
            return None

    async def close(self):
        """Close the transport; failures are logged, never raised."""
        try:
            await self.websocket.close()
        except Exception as err:
            logger.error(f"关闭连接失败: {err}")

    def update_heartbeat(self):
        """Record that the peer was heard from just now."""
        now = time.time()
        self.last_heartbeat = now

    def is_alive(self, timeout: int = 60) -> bool:
        """Return True when the last heartbeat is under *timeout* seconds old."""
        elapsed = time.time() - self.last_heartbeat
        return elapsed < timeout

