"""
隧道管理器测试
"""
import pytest
from server.core.tunnel.tunnel_manager import TunnelManager
from common.models.tunnel import TunnelType, TunnelStatus


@pytest.fixture
def tunnel_manager():
    """Provide a brand-new TunnelManager to each test that requests it."""
    manager = TunnelManager()
    return manager


def test_create_tunnel(tunnel_manager):
    """Creating a tunnel stores every field that was passed in.

    A freshly created tunnel must also have a remote port assigned and
    start out INACTIVE (activation is a separate, explicit step).
    """
    tunnel = tunnel_manager.create_tunnel(
        client_id="test_client",
        name="Test Tunnel",
        tunnel_type=TunnelType.TCP,
        local_host="127.0.0.1",
        local_port=22,
    )

    assert tunnel is not None
    assert tunnel.client_id == "test_client"
    assert tunnel.name == "Test Tunnel"
    assert tunnel.tunnel_type == TunnelType.TCP
    # local_host was passed in but previously never verified.
    assert tunnel.local_host == "127.0.0.1"
    assert tunnel.local_port == 22
    # The manager picks the remote port itself; only its presence is pinned here.
    assert tunnel.remote_port is not None
    assert tunnel.status == TunnelStatus.INACTIVE


def test_get_tunnel(tunnel_manager):
    """Looking up a just-created tunnel by its id yields that same tunnel."""
    creation_args = dict(
        client_id="test_client",
        name="Test Tunnel",
        tunnel_type=TunnelType.TCP,
        local_host="127.0.0.1",
        local_port=22,
    )
    created = tunnel_manager.create_tunnel(**creation_args)

    found = tunnel_manager.get_tunnel(created.id)

    assert found is not None
    assert found.id == created.id


def test_get_client_tunnels(tunnel_manager):
    """get_client_tunnels returns exactly the tunnels owned by that client."""
    # Two tunnels owned by the client under test.
    tunnel1 = tunnel_manager.create_tunnel(
        client_id="test_client",
        name="Tunnel 1",
        tunnel_type=TunnelType.TCP,
        local_host="127.0.0.1",
        local_port=22,
    )

    tunnel2 = tunnel_manager.create_tunnel(
        client_id="test_client",
        name="Tunnel 2",
        tunnel_type=TunnelType.HTTP,
        local_host="127.0.0.1",
        local_port=80,
    )

    # A tunnel for a different client; it must not leak into the result.
    other_tunnel = tunnel_manager.create_tunnel(
        client_id="other_client",
        name="Tunnel 3",
        tunnel_type=TunnelType.TCP,
        local_host="127.0.0.1",
        local_port=3306,
    )

    client_tunnels = tunnel_manager.get_client_tunnels("test_client")
    # Build the id set once instead of re-listing ids for every membership check.
    client_tunnel_ids = {t.id for t in client_tunnels}

    assert len(client_tunnels) == 2
    assert client_tunnel_ids == {tunnel1.id, tunnel2.id}
    assert other_tunnel.id not in client_tunnel_ids


def test_remove_tunnel(tunnel_manager):
    """Removing a tunnel unregisters it and releases its remote port."""
    doomed = tunnel_manager.create_tunnel(
        client_id="test_client",
        name="Test Tunnel",
        tunnel_type=TunnelType.TCP,
        local_host="127.0.0.1",
        local_port=22,
    )
    # Remember the allocated port so we can check it was handed back.
    freed_port = doomed.remote_port

    tunnel_manager.remove_tunnel(doomed.id)

    lookup_after_removal = tunnel_manager.get_tunnel(doomed.id)
    assert lookup_after_removal is None
    # Inspects the private `_used_ports` container; no public accessor is visible here.
    assert freed_port not in tunnel_manager._used_ports


@pytest.mark.asyncio
async def test_activate_tunnel(tunnel_manager):
    """Activating a tunnel changes its stored status to ACTIVE."""
    tunnel = tunnel_manager.create_tunnel(
        client_id="test_client",
        name="Test Tunnel",
        tunnel_type=TunnelType.TCP,
        local_host="127.0.0.1",
        local_port=22,
    )

    await tunnel_manager.activate_tunnel(tunnel.id)
    # NOTE(review): if activation opens a listener on the remote port, nothing
    # here releases it afterwards -- confirm whether a deactivate/teardown step
    # is needed to avoid leaking resources between tests.

    updated = tunnel_manager.get_tunnel(tunnel.id)
    # Guard before dereferencing so a missing tunnel fails with a clear
    # assertion instead of an AttributeError on None (as test_get_tunnel does).
    assert updated is not None
    assert updated.status == TunnelStatus.ACTIVE

