"""
存储适配器 - 统一接口，支持内存和数据库存储
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any
from datetime import datetime

from common.models.client import Client, ClientStatus
from common.models.tunnel import Tunnel, TunnelType, TunnelStatus
from common.models.domain import Domain, DomainStatus
from common.models.session import Session, SessionStatus


class StorageAdapter(ABC):
    """Abstract base class for storage adapters.

    Declares the asynchronous client and tunnel persistence operations
    that every concrete storage backend has to implement. All methods are
    abstract, so the class itself cannot be instantiated.
    """

    @abstractmethod
    async def create_client(self, name: str, token: str, **kwargs) -> Client:
        """Create a client.

        Args:
            name: Name of the new client.
            token: Token associated with the new client.
            **kwargs: Extra client fields forwarded to the backend.

        Returns:
            The newly created client.
        """

    @abstractmethod
    async def get_client(self, client_id: str) -> Optional[Client]:
        """Fetch a client by its id.

        Returns:
            The matching client, or ``None`` when there is none.
        """

    @abstractmethod
    async def get_client_by_token(self, token: str) -> Optional[Client]:
        """Fetch a client by its token.

        Returns:
            The matching client, or ``None`` when there is none.
        """

    @abstractmethod
    async def list_clients(self) -> List[Client]:
        """List all clients."""

    @abstractmethod
    async def update_client_status(
        self,
        client_id: str,
        status: ClientStatus,
        ip_address: Optional[str] = None,
    ) -> None:
        """Update the status of a client.

        Args:
            client_id: Id of the client to update.
            status: New status for the client.
            ip_address: Optional IP address to record for the client.
        """

    @abstractmethod
    async def create_tunnel(self, client_id: str, name: str, tunnel_type: TunnelType,
                           local_host: str, local_port: int, **kwargs) -> Tunnel:
        """Create a tunnel.

        Args:
            client_id: Id of the client the tunnel belongs to.
            name: Name of the tunnel.
            tunnel_type: Type of the tunnel.
            local_host: Local host of the tunnel.
            local_port: Local port of the tunnel.
            **kwargs: Extra tunnel fields forwarded to the backend.

        Returns:
            The newly created tunnel.
        """

    @abstractmethod
    async def get_tunnel(self, tunnel_id: str) -> Optional[Tunnel]:
        """Fetch a tunnel by its id.

        Returns:
            The matching tunnel, or ``None`` when there is none.
        """

    @abstractmethod
    async def list_client_tunnels(self, client_id: str) -> List[Tunnel]:
        """List all tunnels that belong to the given client."""

    @abstractmethod
    async def update_tunnel_status(self, tunnel_id: str, status: TunnelStatus) -> None:
        """Update the status of a tunnel."""

    @abstractmethod
    async def delete_tunnel(self, tunnel_id: str) -> None:
        """Delete a tunnel."""


class MemoryStorageAdapter(StorageAdapter):
    """In-memory storage adapter.

    Keeps clients and tunnels in dictionaries owned by the instance, plus
    a per-client index of tunnel ids. Nothing is persisted outside of
    these dictionaries.
    """

    def __init__(self):
        # client id -> Client
        self._clients: Dict[str, Client] = {}
        # tunnel id -> Tunnel
        self._tunnels: Dict[str, Tunnel] = {}
        # client id -> ids of that client's tunnels, in creation order
        self._client_tunnels: Dict[str, List[str]] = {}

    async def create_client(self, name: str, token: str, **kwargs) -> Client:
        """Create a client with a fresh UUID and store it.

        The client starts with ``ClientStatus.OFFLINE``; ``created_at``
        and ``updated_at`` are set to the same timestamp.

        Args:
            name: Name of the new client.
            token: Token of the new client.
            **kwargs: Extra keyword arguments passed to ``Client``.

        Returns:
            The stored client.
        """
        import uuid

        # One timestamp so a new record has created_at == updated_at.
        now = datetime.now()
        client = Client(
            id=str(uuid.uuid4()),
            name=name,
            token=token,
            status=ClientStatus.OFFLINE,
            created_at=now,
            updated_at=now,
            **kwargs
        )
        self._clients[client.id] = client
        return client

    async def get_client(self, client_id: str) -> Optional[Client]:
        """Return the client stored under ``client_id``, or ``None``."""
        return self._clients.get(client_id)

    async def get_client_by_token(self, token: str) -> Optional[Client]:
        """Return the first stored client whose token equals ``token``.

        Performs a linear scan over all clients; returns ``None`` when no
        client matches.
        """
        return next(
            (client for client in self._clients.values() if client.token == token),
            None,
        )

    async def list_clients(self) -> List[Client]:
        """Return a new list containing every stored client."""
        return list(self._clients.values())

    async def update_client_status(
        self,
        client_id: str,
        status: ClientStatus,
        ip_address: Optional[str] = None,
    ) -> None:
        """Set a client's status and refresh its timestamps.

        Does nothing when the client is unknown. ``last_seen`` and
        ``updated_at`` receive the same timestamp. The stored IP address
        is only replaced when ``ip_address`` is truthy.
        """
        client = self._clients.get(client_id)
        if not client:
            return
        now = datetime.now()
        client.status = status
        client.last_seen = now
        client.updated_at = now
        if ip_address:
            client.ip_address = ip_address

    async def create_tunnel(self, client_id: str, name: str, tunnel_type: TunnelType,
                           local_host: str, local_port: int, **kwargs) -> Tunnel:
        """Create a tunnel with a fresh UUID, store it and index it.

        The tunnel starts with ``TunnelStatus.INACTIVE``; ``created_at``
        and ``updated_at`` are set to the same timestamp. The existence
        of ``client_id`` is not checked.

        Args:
            client_id: Id of the owning client.
            name: Name of the tunnel.
            tunnel_type: Type of the tunnel.
            local_host: Local host of the tunnel.
            local_port: Local port of the tunnel.
            **kwargs: Extra keyword arguments passed to ``Tunnel``.

        Returns:
            The stored tunnel.
        """
        import uuid

        # One timestamp so a new record has created_at == updated_at.
        now = datetime.now()
        tunnel = Tunnel(
            id=str(uuid.uuid4()),
            client_id=client_id,
            name=name,
            tunnel_type=tunnel_type,
            local_host=local_host,
            local_port=local_port,
            status=TunnelStatus.INACTIVE,
            created_at=now,
            updated_at=now,
            **kwargs
        )
        self._tunnels[tunnel.id] = tunnel
        self._client_tunnels.setdefault(client_id, []).append(tunnel.id)
        return tunnel

    async def get_tunnel(self, tunnel_id: str) -> Optional[Tunnel]:
        """Return the tunnel stored under ``tunnel_id``, or ``None``."""
        return self._tunnels.get(tunnel_id)

    async def list_client_tunnels(self, client_id: str) -> List[Tunnel]:
        """Return the client's tunnels in creation order.

        Ids present in the index but missing from the tunnel store are
        skipped. An unknown client yields an empty list.
        """
        tunnel_ids = self._client_tunnels.get(client_id, [])
        return [self._tunnels[tid] for tid in tunnel_ids if tid in self._tunnels]

    async def update_tunnel_status(self, tunnel_id: str, status: TunnelStatus) -> None:
        """Set a tunnel's status and refresh ``updated_at``.

        Does nothing when the tunnel is unknown.
        """
        tunnel = self._tunnels.get(tunnel_id)
        if not tunnel:
            return
        tunnel.status = status
        tunnel.updated_at = datetime.now()

    async def delete_tunnel(self, tunnel_id: str) -> None:
        """Remove a tunnel from the store and from its client's index.

        Does nothing when the tunnel is unknown. An index that is already
        missing the id is tolerated, so the tunnel is always removed.
        """
        # Pop first so a stale index can never leave the tunnel behind.
        tunnel = self._tunnels.pop(tunnel_id, None)
        if tunnel is None:
            return
        owner_id = tunnel.client_id
        owned_ids = self._client_tunnels.get(owner_id)
        if owned_ids is None:
            return
        if tunnel_id in owned_ids:
            owned_ids.remove(tunnel_id)
        if not owned_ids:
            # Drop empty index entries instead of keeping them forever.
            del self._client_tunnels[owner_id]


class DatabaseStorageAdapter(StorageAdapter):
    """Database-backed storage adapter.

    Every operation is delegated to ``ClientRepository`` or
    ``TunnelRepository`` with the stored session as first argument, and
    returned records are converted into domain models.

    NOTE(review): the repository calls are made without ``await`` inside
    ``async`` methods; presumably they are synchronous -- confirm whether
    they can block the event loop.
    """

    def __init__(self, db_session):
        """Store the session and bind the repository classes.

        Args:
            db_session: Session object handed to every repository call
                (presumably a database session -- verify against caller).
        """
        # Imported inside the constructor, not at module level.
        from server.db.repositories import ClientRepository, TunnelRepository

        self.tunnel_repo = TunnelRepository
        self.client_repo = ClientRepository
        self.db = db_session

    async def create_client(self, name: str, token: str, **kwargs) -> Client:
        """Create a client through the repository and return it as a model."""
        record = self.client_repo.create(self.db, name=name, token=token, **kwargs)
        return self._db_client_to_model(record)

    async def get_client(self, client_id: str) -> Optional[Client]:
        """Return the client with the given id, or ``None`` if the lookup is falsy."""
        record = self.client_repo.get_by_id(self.db, client_id)
        if not record:
            return None
        return self._db_client_to_model(record)

    async def get_client_by_token(self, token: str) -> Optional[Client]:
        """Return the client with the given token, or ``None`` if the lookup is falsy."""
        record = self.client_repo.get_by_token(self.db, token)
        if not record:
            return None
        return self._db_client_to_model(record)

    async def list_clients(self) -> List[Client]:
        """Return every client from the repository as domain models."""
        records = self.client_repo.get_all(self.db)
        return list(map(self._db_client_to_model, records))

    async def update_client_status(self, client_id: str, status: ClientStatus, ip_address: str = None):
        """Forward a client status update to the repository."""
        self.client_repo.update_status(self.db, client_id, status, ip_address)

    async def create_tunnel(self, client_id: str, name: str, tunnel_type: TunnelType,
                           local_host: str, local_port: int, **kwargs) -> Tunnel:
        """Create a tunnel through the repository and return it as a model."""
        record = self.tunnel_repo.create(
            self.db,
            client_id,
            name,
            tunnel_type,
            local_host,
            local_port,
            **kwargs
        )
        return self._db_tunnel_to_model(record)

    async def get_tunnel(self, tunnel_id: str) -> Optional[Tunnel]:
        """Return the tunnel with the given id, or ``None`` if the lookup is falsy."""
        record = self.tunnel_repo.get_by_id(self.db, tunnel_id)
        if not record:
            return None
        return self._db_tunnel_to_model(record)

    async def list_client_tunnels(self, client_id: str) -> List[Tunnel]:
        """Return the given client's tunnels as domain models."""
        records = self.tunnel_repo.get_by_client(self.db, client_id)
        return list(map(self._db_tunnel_to_model, records))

    async def update_tunnel_status(self, tunnel_id: str, status: TunnelStatus):
        """Forward a tunnel status update to the repository."""
        self.tunnel_repo.update_status(self.db, tunnel_id, status)

    async def delete_tunnel(self, tunnel_id: str):
        """Forward a tunnel deletion to the repository."""
        self.tunnel_repo.delete(self.db, tunnel_id)

    def _db_client_to_model(self, db_client) -> Client:
        """Convert a database client record into a ``Client`` domain model."""
        # Each field is copied under the same name from record to model.
        copied_fields = (
            "id",
            "name",
            "token",
            "status",
            "created_at",
            "updated_at",
            "last_seen",
            "ip_address",
            "version",
            "description",
        )
        values = {field: getattr(db_client, field) for field in copied_fields}
        return Client(**values)

    def _db_tunnel_to_model(self, db_tunnel) -> Tunnel:
        """Convert a database tunnel record into a ``Tunnel`` domain model."""
        # Each field is copied under the same name from record to model.
        copied_fields = (
            "id",
            "client_id",
            "name",
            "tunnel_type",
            "local_host",
            "local_port",
            "remote_port",
            "domain",
            "status",
            "created_at",
            "updated_at",
            "description",
        )
        values = {field: getattr(db_tunnel, field) for field in copied_fields}
        return Tunnel(**values)

