"""
流量转发引擎
"""
import asyncio
import socket
from typing import Dict, Optional, Callable
from loguru import logger

from common.models.tunnel import Tunnel, TunnelType, TunnelStatus
from typing import TYPE_CHECKING
from .udp_proxy import UDPProxy

if TYPE_CHECKING:
    from .tunnel_manager import TunnelManager


class ForwardEngine:
    """Traffic forwarding engine.

    Relays bytes between an accepted tunnel connection and the tunnel's local
    service (``tunnel.local_host``:``tunnel.local_port``), dispatching on
    ``tunnel.tunnel_type``.
    """

    # Number of bytes requested per read in every relay loop.
    _CHUNK_SIZE = 4096
    # Upper bound on the bytes buffered while waiting for the end of the HTTP
    # request head. Once reached, the buffered bytes are forwarded as-is and
    # the rest of the request is streamed, so memory use stays bounded.
    _MAX_HEADER_BYTES = 64 * 1024

    def __init__(self, tunnel_manager: 'TunnelManager'):
        """Create an engine that looks tunnels up through *tunnel_manager*."""
        self.tunnel_manager = tunnel_manager
        self._active_connections: Dict[str, Dict[str, asyncio.Task]] = {}  # tunnel_id -> {conn_id: task}
        self._connection_counter = 0
        self._udp_proxy = UDPProxy()

    async def handle_tunnel_connection(self, tunnel_id: str, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter):
        """Dispatch an incoming tunnel connection to the matching forwarder.

        The client connection is closed when the tunnel is unknown, is not in
        the ACTIVE state, has an unsupported type, or when the forwarder
        raises.

        Args:
            tunnel_id: Identifier passed to ``tunnel_manager.get_tunnel``.
            client_reader: Stream carrying bytes sent by the remote client.
            client_writer: Stream used to send bytes back to the remote client.
        """
        tunnel = self.tunnel_manager.get_tunnel(tunnel_id)
        if not tunnel:
            logger.error(f"隧道不存在: {tunnel_id}")
            await self._close_writers(client_writer)
            return

        if tunnel.status != TunnelStatus.ACTIVE:
            logger.warning(f"隧道未激活: {tunnel_id}")
            await self._close_writers(client_writer)
            return

        # Connect to the local service using the handler for this tunnel type.
        try:
            if tunnel.tunnel_type == TunnelType.TCP:
                await self._handle_tcp_forward(tunnel, client_reader, client_writer)
            elif tunnel.tunnel_type == TunnelType.HTTP:
                await self._handle_http_forward(tunnel, client_reader, client_writer)
            elif tunnel.tunnel_type == TunnelType.UDP:
                await self._handle_udp_forward(tunnel, client_reader, client_writer)
            else:
                logger.error(f"不支持的隧道类型: {tunnel.tunnel_type}")
                await self._close_writers(client_writer)
        except Exception as e:
            logger.error(f"处理隧道连接失败: {e}")
            await self._close_writers(client_writer)

    @staticmethod
    async def _close_writers(*writers: Optional[asyncio.StreamWriter]) -> None:
        """Close every given writer and wait for the transports to shut down.

        ``None`` entries are skipped. Only ``Exception`` is suppressed, so task
        cancellation (``asyncio.CancelledError``) still propagates.
        """
        live = [w for w in writers if w is not None]
        # Request the close on every writer first so that an error (or a
        # cancellation) while waiting on one cannot leave another one open.
        for writer in live:
            try:
                writer.close()
            except Exception as e:
                logger.debug(f"关闭连接时出错: {e}")
        for writer in live:
            try:
                await writer.wait_closed()
            except Exception as e:
                logger.debug(f"关闭连接时出错: {e}")

    async def _pump(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, end_message: str, close_writer: bool) -> None:
        """Copy bytes from *reader* to *writer* until EOF or an error.

        Args:
            reader: Source stream.
            writer: Destination stream.
            end_message: Prefix of the debug log line emitted when an
                exception ends the copy.
            close_writer: When true, *writer* is closed once the copy stops,
                which also ends the opposite direction of the relay.
        """
        try:
            while True:
                data = await reader.read(self._CHUNK_SIZE)
                if not data:
                    break
                writer.write(data)
                await writer.drain()
        except Exception as e:
            logger.debug(f"{end_message}: {e}")
        finally:
            if close_writer:
                writer.close()

    async def _handle_tcp_forward(self, tunnel: Tunnel, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter):
        """Relay a TCP connection in both directions.

        Opens a connection to the tunnel's local service and copies data both
        ways until either side stops. Both connections are closed on exit;
        failures are logged rather than raised.
        """
        conn_id = f"{tunnel.id}_{self._connection_counter}"
        self._connection_counter += 1

        local_writer: Optional[asyncio.StreamWriter] = None
        try:
            # Connect to the local service.
            local_reader, local_writer = await asyncio.open_connection(
                tunnel.local_host,
                tunnel.local_port
            )

            logger.debug(f"TCP转发连接建立: {conn_id} -> {tunnel.local_host}:{tunnel.local_port}")

            # Run both directions concurrently. Each direction closes its
            # destination when it stops, which makes the other direction see
            # EOF and stop as well.
            await asyncio.gather(
                self._pump(client_reader, local_writer, "客户端到本地转发结束", close_writer=True),
                self._pump(local_reader, client_writer, "本地到客户端转发结束", close_writer=True),
                return_exceptions=True
            )

        except Exception as e:
            logger.error(f"TCP转发失败: {e}")
        finally:
            await self._close_writers(local_writer, client_writer)

    async def _handle_http_forward(self, tunnel: Tunnel, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter):
        """Forward an HTTP exchange to the local service (reverse proxy).

        Buffers the request until the end of the head (``\\r\\n\\r\\n``), client
        EOF, or ``_MAX_HEADER_BYTES``, sends it to the local service, then
        streams the response back until the local service closes. Any further
        client bytes (for example a request body that arrives after the head)
        are forwarded concurrently. Both connections are closed on exit;
        failures are logged rather than raised.
        """
        local_writer: Optional[asyncio.StreamWriter] = None
        request_pump: Optional[asyncio.Task] = None
        try:
            # Read the HTTP request.
            request_data = bytearray()
            while True:
                chunk = await client_reader.read(self._CHUNK_SIZE)
                if not chunk:
                    break
                request_data += chunk
                # Stop once the complete request head has been received, or
                # once the buffer limit is hit (the remainder is streamed).
                if b"\r\n\r\n" in request_data or len(request_data) >= self._MAX_HEADER_BYTES:
                    break

            if not request_data:
                # Nothing was sent; the finally clause closes the client.
                return

            # Connect to the local HTTP service.
            local_reader, local_writer = await asyncio.open_connection(
                tunnel.local_host,
                tunnel.local_port
            )

            # Forward the buffered request to the local service.
            local_writer.write(bytes(request_data))
            await local_writer.drain()

            # Keep forwarding whatever the client sends next. The local
            # connection is deliberately left open when the client reaches
            # EOF, so the response can still be read.
            request_pump = asyncio.create_task(
                self._pump(client_reader, local_writer, "客户端到本地转发结束", close_writer=False)
            )

            # Read the response and forward it to the client.
            while True:
                data = await local_reader.read(self._CHUNK_SIZE)
                if not data:
                    break
                client_writer.write(data)
                await client_writer.drain()

        except Exception as e:
            logger.error(f"HTTP转发失败: {e}")
        finally:
            try:
                if request_pump is not None:
                    request_pump.cancel()
                    # return_exceptions=True turns the pump's own cancellation
                    # into a result instead of raising it here.
                    await asyncio.gather(request_pump, return_exceptions=True)
            finally:
                await self._close_writers(local_writer, client_writer)

    async def _handle_udp_forward(self, tunnel: Tunnel, client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter):
        """Start the UDP proxy for *tunnel* when it has a remote port.

        The stream arguments are accepted for signature parity with the other
        handlers and are not used here.
        """
        # The UDP proxy has to be started separately; this only handles the
        # WebSocket connection. The actual UDP forwarding is done in UDPProxy.
        logger.info(f"UDP隧道连接: {tunnel.id}")
        # For UDP, the UDP proxy server has to be started.
        if tunnel.remote_port:
            await self._udp_proxy.start_udp_proxy(tunnel, '0.0.0.0', tunnel.remote_port)

