"""
WebSocket处理器
"""
import asyncio
import json
from typing import Optional
from loguru import logger

from fastapi import WebSocket, WebSocketDisconnect
from common.protocol.message import Message, MessageType, AuthRequest, AuthResponse, TunnelData
from common.models.session import SessionStatus

from .connection_manager import ConnectionManager, ClientConnection
from ..auth.auth_manager import AuthManager
from ..tunnel.tunnel_manager import TunnelManager


class WebSocketHandler:
    """Handle the lifecycle of client WebSocket connections.

    Each connection goes through three phases:

    1. accept the socket and wait for an ``AUTH_REQUEST`` message;
    2. register the authenticated connection with the ``ConnectionManager``;
    3. run the message loop (heartbeats, tunnel data, control messages)
       until the peer disconnects or an error occurs.
    """

    # Default number of seconds to wait for the initial authentication message.
    DEFAULT_AUTH_TIMEOUT: float = 30.0

    def __init__(
        self,
        connection_manager: ConnectionManager,
        auth_manager: AuthManager,
        tunnel_manager: TunnelManager,
        *,
        auth_timeout: Optional[float] = DEFAULT_AUTH_TIMEOUT,
    ):
        """Store the collaborating managers.

        Args:
            connection_manager: registry that authenticated connections are
                added to and removed from.
            auth_manager: validates the ``(token, client_id)`` pair sent by
                the client.
            tunnel_manager: receives the payloads of ``TUNNEL_DATA`` messages.
            auth_timeout: seconds to wait for the first (authentication)
                message before rejecting the connection; ``None`` waits
                indefinitely.
        """
        self.connection_manager = connection_manager
        self.auth_manager = auth_manager
        self.tunnel_manager = tunnel_manager
        self.auth_timeout = auth_timeout

    async def handle_connection(self, websocket: WebSocket, client_host: str):
        """Serve one WebSocket connection from accept until it ends.

        Args:
            websocket: the incoming WebSocket; it is accepted here.
            client_host: remote address of the peer, recorded on the
                connection object and used in logs before authentication.
        """
        await websocket.accept()

        connection = ClientConnection(
            websocket=websocket,
            remote_addr=client_host,
            protocol="websocket"
        )

        try:
            # Wait for authentication; on failure close the socket and stop.
            authenticated = await self._handle_authentication(connection)
            if not authenticated:
                await websocket.close()
                return

            # Register the authenticated connection.
            await self.connection_manager.add_connection(connection)

            # Process messages until disconnect or error.
            await self._handle_message_loop(connection)

        except WebSocketDisconnect:
            # client_id is only set after successful authentication; fall back
            # to the peer address so early disconnects are still identifiable.
            logger.info(f"客户端 {connection.client_id or client_host} 断开连接")
        except Exception as e:
            # logger.exception keeps the traceback for unexpected failures.
            logger.exception(f"处理连接时出错: {e}")
        finally:
            # NOTE(review): removal is keyed by client_id only. If the same
            # client reconnects and add_connection replaces this entry, this
            # cleanup may remove the newer connection — confirm the
            # ConnectionManager semantics.
            if connection.client_id:
                await self.connection_manager.remove_connection(connection.client_id)

    async def _handle_authentication(self, connection: ClientConnection) -> bool:
        """Run the authentication handshake on a freshly accepted connection.

        The first message must be an ``AUTH_REQUEST`` whose ``data`` carries
        ``token`` and ``client_id``. On success the connection's
        ``client_id``, ``client`` and ``authenticated`` fields are set and an
        ``AuthResponse(success=True)`` is sent.

        Returns:
            True if the client authenticated; False if the first message was
            missing or of the wrong type, fields were missing, credentials
            were rejected, the wait timed out, or an unexpected error occurred.

        Raises:
            WebSocketDisconnect: re-raised so the caller treats it as a normal
                disconnect instead of trying to close an already-dead socket.
        """
        try:
            # Bound the wait so a client that connects but never authenticates
            # cannot keep the socket open indefinitely.
            message = await asyncio.wait_for(
                connection.receive_message(), timeout=self.auth_timeout
            )
            if not message or message.type != MessageType.AUTH_REQUEST:
                logger.warning("未收到认证请求")
                return False

            # Treat a missing payload as "incomplete" rather than crashing.
            data = message.data or {}
            token = data.get('token')
            client_id = data.get('client_id')

            if not token or not client_id:
                logger.warning("认证信息不完整")
                await connection.send_message(AuthResponse(success=False, message="认证信息不完整"))
                return False

            # Validate the credentials.
            client = await self.auth_manager.authenticate_client(token, client_id)
            if not client:
                logger.warning(f"客户端认证失败: {client_id}")
                await connection.send_message(AuthResponse(success=False, message="认证失败"))
                return False

            # Record identity on the connection.
            connection.client_id = client_id
            connection.client = client
            connection.authenticated = True

            # Acknowledge successful authentication.
            await connection.send_message(AuthResponse(
                success=True,
                message="认证成功",
                client_id=client_id
            ))

            logger.info(f"客户端 {client_id} 认证成功")
            return True

        except asyncio.TimeoutError:
            logger.warning(f"等待认证请求超时 ({self.auth_timeout}s)")
            return False
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.exception(f"认证过程出错: {e}")
            return False

    async def _handle_message_loop(self, connection: ClientConnection):
        """Dispatch incoming messages until the peer goes away or an error occurs.

        Heartbeats are answered with ``HEARTBEAT_RESPONSE`` echoing the
        request timestamp, ``TUNNEL_DATA`` is forwarded to the tunnel manager,
        and everything else goes to ``_handle_control_message``. A falsy
        message ends the loop; any unexpected error is logged and also ends it.

        Raises:
            WebSocketDisconnect: propagated so the caller logs it as a normal
                disconnect rather than as a message-processing error.
        """
        while True:
            try:
                message = await connection.receive_message()
                if not message:
                    break

                # Heartbeat: refresh liveness and echo the timestamp back.
                if message.type == MessageType.HEARTBEAT:
                    connection.update_heartbeat()
                    await connection.send_message(Message(
                        type=MessageType.HEARTBEAT_RESPONSE,
                        data={'timestamp': message.timestamp}
                    ))
                    continue

                # Tunnel payload.
                if message.type == MessageType.TUNNEL_DATA:
                    await self._handle_tunnel_data(connection, message)
                    continue

                # Any other control message.
                await self._handle_control_message(connection, message)

            except WebSocketDisconnect:
                # Not an error; WebSocketDisconnect subclasses Exception, so it
                # must be re-raised before the generic handler swallows it.
                raise
            except Exception as e:
                logger.exception(f"处理消息时出错: {e}")
                break

    async def _handle_tunnel_data(self, connection: ClientConnection, message: Message):
        """Forward the payload of a ``TUNNEL_DATA`` message to its tunnel.

        Messages without a ``tunnel_id`` in their data are dropped.
        """
        tunnel_id = (message.data or {}).get('tunnel_id')
        if not tunnel_id:
            logger.debug(f"客户端 {connection.client_id} 的隧道数据缺少 tunnel_id，已丢弃")
            return

        # NOTE(review): the tunnel's ownership is not checked against
        # connection.client_id here, so an authenticated client could write to
        # any tunnel_id it knows. Confirm whether TunnelManager enforces this.

        # Decode the payload.
        tunnel_data = TunnelData.from_message(message)
        data = tunnel_data.get_data()

        # Forward it.
        await self.tunnel_manager.forward_data(tunnel_id, data)

    async def _handle_control_message(self, connection: ClientConnection, message: Message):
        """Handle a non-heartbeat, non-tunnel message (currently only logged)."""
        # TODO: implement control message handling
        logger.debug(f"收到控制消息: {message.type}")

