"""
隧道管理器
"""
import asyncio
from typing import Dict, Optional, List
from datetime import datetime
from loguru import logger

from common.models.tunnel import Tunnel, TunnelType, TunnelStatus
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .forward_engine import ForwardEngine
from common.protocol.message import Message, MessageType, TunnelData
from .forward_engine import ForwardEngine


class TunnelManager:
    """In-memory registry of tunnels and the listening servers that expose them.

    Tracks tunnels by id and by owning client, hands out remote ports from a
    fixed pool (10000-59999), and starts/stops one listening server per active
    tunnel through the configured ``ForwardEngine``.
    """

    def __init__(self):
        self._tunnels: Dict[str, Tunnel] = {}
        self._client_tunnels: Dict[str, List[str]] = {}  # client_id -> tunnel_ids
        self._tunnel_connections: Dict[str, 'TunnelConnection'] = {}
        self._port_pool: set = set(range(10000, 60000))  # candidate ports for automatic allocation
        self._used_ports: set = set()
        self._forward_engine: Optional[ForwardEngine] = None
        self._listening_servers: Dict[str, asyncio.Server] = {}
        # Strong references to fire-and-forget cleanup tasks: the event loop only
        # keeps weak references, so an unreferenced task may vanish mid-run.
        self._background_tasks: set = set()

    def set_forward_engine(self, forward_engine: 'ForwardEngine'):
        """Set the engine that handles connections accepted on tunnel ports."""
        self._forward_engine = forward_engine

    def create_tunnel(
        self,
        client_id: str,
        name: str,
        tunnel_type: TunnelType,
        local_host: str,
        local_port: int,
        remote_port: Optional[int] = None,
        domain: Optional[str] = None,
    ) -> Tunnel:
        """Register a new, inactive tunnel for a client.

        Args:
            client_id: Owner of the tunnel.
            name: Display name.
            tunnel_type: Tunnel protocol type.
            local_host: Host on the client side that traffic is forwarded to.
            local_port: Port on the client side that traffic is forwarded to.
            remote_port: Port to expose on this server. When None, the lowest
                free port from the pool is allocated.
            domain: Optional domain associated with the tunnel.

        Returns:
            The created Tunnel, with status INACTIVE.

        Raises:
            ValueError: ``remote_port`` is already used by another tunnel.
            RuntimeError: No free port is left in the pool.
        """
        import uuid

        if remote_port is None:
            remote_port = self._allocate_port()

        if remote_port in self._used_ports:
            raise ValueError(f"端口 {remote_port} 已被使用")

        # One timestamp so created_at and updated_at are identical on creation.
        now = datetime.now()
        tunnel = Tunnel(
            id=str(uuid.uuid4()),
            client_id=client_id,
            name=name,
            tunnel_type=tunnel_type,
            local_host=local_host,
            local_port=local_port,
            remote_port=remote_port,
            domain=domain,
            status=TunnelStatus.INACTIVE,
            created_at=now,
            updated_at=now,
        )

        self._tunnels[tunnel.id] = tunnel
        self._used_ports.add(remote_port)
        self._client_tunnels.setdefault(client_id, []).append(tunnel.id)

        logger.info(f"创建隧道: {tunnel.id} ({tunnel_type.value}) {local_host}:{local_port} -> :{remote_port}")
        return tunnel

    def remove_tunnel(self, tunnel_id: str):
        """Remove a tunnel and release everything it holds. Unknown ids are ignored.

        This method is synchronous. The following happen immediately:
        - registry bookkeeping is updated,
        - the port is released,
        - the listening server stops accepting connections.

        The following are scheduled on the running event loop, and are skipped
        with a warning when no loop is running:
        - waiting for the server to finish closing,
        - stopping an active UDP proxy,
        - closing the tunnel connection.
        """
        tunnel = self._tunnels.pop(tunnel_id, None)
        if tunnel is None:
            return

        if tunnel.remote_port:
            self._used_ports.discard(tunnel.remote_port)

        # Close the listener now; otherwise the released port stays bound and
        # a new tunnel that gets the same port would fail to listen on it.
        server = self._listening_servers.pop(tunnel_id, None)
        if server is not None:
            server.close()
            self._run_in_background(server.wait_closed())

        # Mirrors deactivate_tunnel's cleanup for tunnels that are still active.
        if (
            tunnel.status == TunnelStatus.ACTIVE
            and tunnel.tunnel_type == TunnelType.UDP
            and self._forward_engine
        ):
            self._run_in_background(self._forward_engine._udp_proxy.stop_udp_proxy(tunnel_id))

        connection = self._tunnel_connections.pop(tunnel_id, None)
        if connection is not None:
            self._run_in_background(connection.close())

        client_tunnel_ids = self._client_tunnels.get(tunnel.client_id)
        if client_tunnel_ids is not None:
            if tunnel_id in client_tunnel_ids:
                client_tunnel_ids.remove(tunnel_id)
            if not client_tunnel_ids:
                # Drop empty lists so the map does not grow with departed clients.
                del self._client_tunnels[tunnel.client_id]

        logger.info(f"移除隧道: {tunnel_id}")

    def _run_in_background(self, aw) -> None:
        """Schedule awaitable ``aw`` on the running loop and keep it referenced until done."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            if asyncio.iscoroutine(aw):
                aw.close()  # avoid a "coroutine was never awaited" warning
            logger.warning("没有运行中的事件循环，跳过后台清理任务")
            return
        task = asyncio.ensure_future(aw, loop=loop)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task: asyncio.Future) -> None:
        """Release a finished cleanup task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"后台清理任务失败: {exc}")

    def get_tunnel(self, tunnel_id: str) -> Optional[Tunnel]:
        """Return the tunnel with ``tunnel_id``, or None if unknown."""
        return self._tunnels.get(tunnel_id)

    def get_client_tunnels(self, client_id: str) -> List[Tunnel]:
        """Return all registered tunnels owned by ``client_id`` (empty list if none)."""
        tunnel_ids = self._client_tunnels.get(client_id, [])
        return [self._tunnels[tid] for tid in tunnel_ids if tid in self._tunnels]

    def get_tunnel_by_port(self, port: int) -> Optional[Tunnel]:
        """Return the tunnel whose remote port is ``port``, or None (linear scan)."""
        return next((t for t in self._tunnels.values() if t.remote_port == port), None)

    async def activate_tunnel(self, tunnel_id: str):
        """Start listening on the tunnel's remote port and mark it ACTIVE.

        The status changes only after the listener is confirmed. If the forward
        engine is missing or binding fails, the tunnel keeps its previous
        status instead of being reported ACTIVE with nothing listening.
        Unknown ids are ignored.
        """
        tunnel = self._tunnels.get(tunnel_id)
        if tunnel is None:
            return

        if not await self._start_listening_server(tunnel):
            return

        tunnel.status = TunnelStatus.ACTIVE
        tunnel.updated_at = datetime.now()
        logger.info(f"激活隧道: {tunnel_id}")

    async def _start_listening_server(self, tunnel: Tunnel) -> bool:
        """Start a server on ``0.0.0.0:<remote_port>`` that hands connections to the forward engine.

        Returns:
            True if a listener exists for the tunnel afterwards (newly started
            or already running). False if the forward engine is not set or the
            server failed to start; the failure is logged.
        """
        # NOTE(review): asyncio.start_server is a TCP server and is used for every
        # tunnel type, including UDP (deactivate_tunnel separately stops a UDP
        # proxy on the forward engine). Confirm UDP tunnels are meant to get this listener.
        if not self._forward_engine:
            logger.warning("转发引擎未设置")
            return False

        if tunnel.id in self._listening_servers:
            logger.warning(f"隧道 {tunnel.id} 已在监听")
            return True

        try:
            server = await asyncio.start_server(
                lambda r, w: self._forward_engine.handle_tunnel_connection(tunnel.id, r, w),
                '0.0.0.0',
                tunnel.remote_port
            )
        except Exception as e:
            logger.error(f"启动监听服务器失败: {e}")
            return False

        self._listening_servers[tunnel.id] = server
        logger.info(f"隧道 {tunnel.id} 开始监听端口 {tunnel.remote_port}")
        return True

    async def deactivate_tunnel(self, tunnel_id: str):
        """Mark the tunnel INACTIVE and stop its listener (and UDP proxy for UDP tunnels).

        The tunnel stays registered and keeps its port. Unknown ids are ignored.
        """
        tunnel = self._tunnels.get(tunnel_id)
        if tunnel is None:
            return

        tunnel.status = TunnelStatus.INACTIVE
        tunnel.updated_at = datetime.now()

        # Pop before awaiting so a concurrent call cannot close the same server twice.
        server = self._listening_servers.pop(tunnel_id, None)
        if server is not None:
            server.close()
            await server.wait_closed()

        if tunnel.tunnel_type == TunnelType.UDP and self._forward_engine:
            await self._forward_engine._udp_proxy.stop_udp_proxy(tunnel_id)

        logger.info(f"停用隧道: {tunnel_id}")

    def _allocate_port(self) -> int:
        """Return the lowest pool port not currently in use.

        The port is not reserved here; the caller (create_tunnel) marks it used.

        Raises:
            RuntimeError: Every port in the pool is in use.
        """
        port = min((p for p in self._port_pool if p not in self._used_ports), default=None)
        if port is None:
            raise RuntimeError("没有可用端口")
        return port

    async def forward_data(self, tunnel_id: str, data: bytes):
        """Send ``data`` over the tunnel's registered connection.

        Unknown or inactive tunnels are logged and ignored. If the tunnel has no
        registered connection, the data is silently dropped.
        """
        tunnel = self._tunnels.get(tunnel_id)
        if tunnel is None:
            logger.warning(f"隧道不存在: {tunnel_id}")
            return

        if tunnel.status != TunnelStatus.ACTIVE:
            logger.warning(f"隧道未激活: {tunnel_id}")
            return

        # NOTE(review): nothing in this class populates _tunnel_connections;
        # presumably it is filled elsewhere — confirm.
        connection = self._tunnel_connections.get(tunnel_id)
        if connection is not None:
            await connection.send_data(data)


class TunnelConnection:
    """A single stream connection bound to a tunnel.

    Thin wrapper around an asyncio reader/writer pair. Ordinary exceptions
    raised by the underlying stream are logged and swallowed, so these
    methods do not raise them to callers.
    """

    # Bytes requested per read when the caller does not specify a size.
    DEFAULT_READ_SIZE = 4096

    def __init__(self, tunnel_id: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self.tunnel_id = tunnel_id
        self.reader = reader
        self.writer = writer

    def __repr__(self) -> str:
        return f"{type(self).__name__}(tunnel_id={self.tunnel_id!r})"

    async def send_data(self, data: bytes):
        """Write ``data`` and wait for the writer to drain; failures are logged, not raised."""
        try:
            self.writer.write(data)
            await self.writer.drain()
        except Exception as e:
            logger.error(f"发送隧道数据失败: {e}")

    async def receive_data(self, max_bytes: int = DEFAULT_READ_SIZE) -> Optional[bytes]:
        """Read up to ``max_bytes`` bytes from the connection.

        Args:
            max_bytes: Upper bound passed to ``reader.read``. Defaults to 4096,
                the previous hard-coded value.

        Returns:
            The bytes read. For an asyncio StreamReader this is ``b""`` at EOF.
            Returns None if the read raised; the error is logged.
        """
        try:
            return await self.reader.read(max_bytes)
        except Exception as e:
            logger.error(f"接收隧道数据失败: {e}")
            return None

    async def close(self):
        """Close the writer and wait for it to finish closing; failures are logged, not raised."""
        try:
            self.writer.close()
            await self.writer.wait_closed()
        except Exception as e:
            logger.error(f"关闭隧道连接失败: {e}")

