"""
客户端连接管理器
"""
import asyncio
import time
from typing import Optional, Callable, List
from loguru import logger
import websockets
from websockets.exceptions import ConnectionClosed

from common.protocol.message import Message, MessageType, AuthRequest, AuthResponse, TunnelData
from common.utils.logger import setup_logger
from .server_selector import ServerSelector, ServerInfo


class ClientConnectionManager:
    """Client-side WebSocket connection manager.

    Picks a server through ``ServerSelector``, opens a WebSocket to
    ``<server>/api/ws``, authenticates with an ``AuthRequest`` and then pumps
    incoming messages to ``on_message_callback``. ``start()`` keeps
    reconnecting (re-selecting the best server on every attempt) until
    ``stop()`` is called or ``max_reconnect_attempts`` is exhausted.
    """

    def __init__(self, server_urls: List[str], client_id: str, token: str):
        """
        Initialize the connection manager.

        Args:
            server_urls: Candidate server base URLs (http/https); several may
                be given and the selector chooses among them.
            client_id: Client identifier sent in the auth request.
            token: Authentication token sent in the auth request.
        """
        self.server_selector = ServerSelector(server_urls)
        self.client_id = client_id
        self.token = token
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        # True only once the server has accepted authentication.
        self.connected: bool = False
        # Seconds to wait between reconnect attempts.
        self.reconnect_interval: int = 5
        # Maximum reconnect attempts; -1 (or any value <= 0) means unlimited.
        self.max_reconnect_attempts: int = -1
        # Reconnects since the last successful authentication.
        self.reconnect_attempts: int = 0
        # Seconds to wait for the auth response; None waits indefinitely.
        self.auth_timeout: Optional[float] = 10.0
        self.on_message_callback: Optional[Callable] = None
        self.on_connect_callback: Optional[Callable] = None
        self.on_disconnect_callback: Optional[Callable] = None
        self._running: bool = False
        # Not used by the code in this class; kept so external references keep working.
        self._task: Optional[asyncio.Task] = None
        self._current_server: Optional[ServerInfo] = None

    @staticmethod
    def _build_ws_url(server_url: str) -> str:
        """Map an http(s) server base URL to its ``/api/ws`` WebSocket endpoint."""
        if server_url.startswith("https://"):
            base = "wss://" + server_url[len("https://"):]
        elif server_url.startswith("http://"):
            base = "ws://" + server_url[len("http://"):]
        else:
            # Already ws:// / wss:// (or another scheme): leave it untouched.
            base = server_url
        # Strip trailing slashes so "http://host/" does not become "ws://host//api/ws".
        return f"{base.rstrip('/')}/api/ws"

    async def _close_websocket(self):
        """Close and forget the current WebSocket; close errors are logged, not raised."""
        ws, self.websocket = self.websocket, None
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as e:
            logger.error(f"关闭连接失败: {e}")

    async def _authenticate(self) -> bool:
        """Send an AuthRequest on the open socket and check the server's reply.

        Returns:
            True if the reply is an AUTH_RESPONSE whose data has a truthy
            ``success``; False on rejection, wrong message type or timeout.
        """
        auth_request = AuthRequest(
            token=self.token,
            client_id=self.client_id,
            version="1.0"
        )
        await self.websocket.send(auth_request.to_json())

        # Bound the wait so a silent server cannot stall reconnection forever.
        try:
            response_data = await asyncio.wait_for(
                self.websocket.recv(), timeout=self.auth_timeout
            )
        except asyncio.TimeoutError:
            logger.error(f"等待认证响应超时 ({self.auth_timeout} 秒)")
            return False

        response = Message.from_json(response_data)
        if response.type != MessageType.AUTH_RESPONSE:
            logger.error(f"认证响应类型错误: {response.type}")
            return False

        data = response.data
        if not data.get('success'):
            logger.error(f"认证失败: {data.get('message')}")
            return False
        return True

    async def connect(self) -> bool:
        """Select the best server, open a WebSocket to it and authenticate.

        Any socket left over from a previous session is closed first.

        Returns:
            True when authentication succeeded. ``connected`` is then True and
            ``on_connect_callback`` (if set) has been awaited. Otherwise False,
            and the socket opened by this call has been closed again.
        """
        # Never leak the socket of a previous session.
        self.connected = False
        await self._close_websocket()

        server = await self.server_selector.select_best_server()
        if not server:
            logger.error("没有可用的服务器")
            return False

        previous = self._current_server
        if previous is not None and server.url != previous.url:
            logger.info(f"切换到新服务器: {server.url}")
        self._current_server = server
        ws_url = self._build_ws_url(server.url)

        try:
            logger.info(f"连接到服务器: {ws_url}")
            self.websocket = await websockets.connect(ws_url)

            if await self._authenticate():
                logger.info("认证成功")
                self.reconnect_attempts = 0
                # Mark connected before the callback so it can already send messages.
                self.connected = True
                if self.on_connect_callback:
                    await self.on_connect_callback()
                return True
        except Exception as e:
            logger.error(f"连接失败: {e}")

        # Every failure path ends here: do not keep an unauthenticated socket open.
        self.connected = False
        await self._close_websocket()
        return False

    async def disconnect(self):
        """Stop the run loop, close the socket and await on_disconnect_callback.

        The callback is awaited whenever it is set, even if no session was open.
        """
        self._running = False
        self.connected = False
        await self._close_websocket()
        if self.on_disconnect_callback:
            await self.on_disconnect_callback()

    async def send_message(self, message: Message) -> bool:
        """Serialize and send a message.

        Returns:
            True on success. Returns False when not connected or when the send
            fails; a failed send also clears ``connected``.
        """
        if not self.connected or not self.websocket:
            logger.warning("未连接，无法发送消息")
            return False

        try:
            await self.websocket.send(message.to_json())
            return True
        except Exception as e:
            logger.error(f"发送消息失败: {e}")
            self.connected = False
            return False

    async def send_tunnel_data(self, tunnel_id: str, data: bytes) -> bool:
        """Wrap ``data`` in a TunnelData message for ``tunnel_id`` and send it."""
        tunnel_data = TunnelData(tunnel_id=tunnel_id, data=data)
        return await self.send_message(tunnel_data)

    async def _message_loop(self):
        """Receive messages until stopped or the connection fails.

        HEARTBEAT_RESPONSE messages are dropped. Every other message is awaited
        through on_message_callback. Any receive, parse or callback error ends
        the loop and clears ``connected``.
        """
        while self._running and self.connected:
            ws = self.websocket
            if ws is None:
                self.connected = False
                break
            try:
                data = await ws.recv()
                message = Message.from_json(data)

                # Heartbeat replies carry nothing for the application layer.
                if message.type == MessageType.HEARTBEAT_RESPONSE:
                    continue

                if self.on_message_callback:
                    await self.on_message_callback(message)

            except ConnectionClosed:
                logger.warning("连接已关闭")
                self.connected = False
                break
            except Exception as e:
                logger.error(f"接收消息失败: {e}")
                self.connected = False
                break

    async def start(self):
        """Run the connect / receive / reconnect loop until stopped.

        connect() re-selects the best server on every iteration. The loop ends
        in either of two cases:

        - stop()/disconnect() clears ``_running``;
        - ``max_reconnect_attempts`` (when > 0) reconnects have been made
          without a successful authentication in between.

        The socket is always closed when the loop exits, including on
        cancellation.

        NOTE(review): on_disconnect_callback is not fired when a session drops
        on its own; only disconnect()/stop() await it.
        """
        self._running = True
        try:
            while self._running:
                if await self.connect():
                    try:
                        await self._message_loop()
                    except Exception as e:
                        logger.error(f"消息循环出错: {e}")

                # The session is over (or never started). Release the socket
                # before waiting so it does not leak across reconnects.
                self.connected = False
                await self._close_websocket()

                if not self._running:
                    break
                if self.max_reconnect_attempts > 0 and self.reconnect_attempts >= self.max_reconnect_attempts:
                    logger.error("达到最大重连次数，停止重连")
                    break

                self.reconnect_attempts += 1
                logger.info(f"等待 {self.reconnect_interval} 秒后重连... (尝试 {self.reconnect_attempts})")
                await asyncio.sleep(self.reconnect_interval)
        finally:
            self._running = False
            self.connected = False
            await self._close_websocket()

    def get_current_server_info(self) -> Optional[dict]:
        """Return url/name/region/latency/status of the selected server, or None."""
        if self._current_server:
            return {
                'url': self._current_server.url,
                'name': self._current_server.name,
                'region': self._current_server.region,
                'latency': self._current_server.latency,
                'status': self._current_server.status,
            }
        return None

    def get_all_servers_status(self) -> List[dict]:
        """Return the status of every configured server, as reported by the selector."""
        return self.server_selector.get_all_servers_status()

    async def stop(self):
        """Stop the connection manager; equivalent to disconnect()."""
        await self.disconnect()

    def set_on_message(self, callback: Callable):
        """Set the async callback awaited with each non-heartbeat Message."""
        self.on_message_callback = callback

    def set_on_connect(self, callback: Callable):
        """Set the async callback awaited (no args) after each successful authentication."""
        self.on_connect_callback = callback

    def set_on_disconnect(self, callback: Callable):
        """Set the async callback awaited (no args) by disconnect()/stop()."""
        self.on_disconnect_callback = callback

