"""
服务注册和发现 - 微服务架构支持
"""
import asyncio
import aiohttp
from typing import Dict, List, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from loguru import logger
import json


class ServiceStatus(str, Enum):
    """Lifecycle / health state of a registered service.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (e.g. ``ServiceStatus.UP == "up"``).
    """

    UP = "up"  # set after a successful health check
    DOWN = "down"  # set after a failed health check or a heartbeat timeout
    STARTING = "starting"
    STOPPING = "stopping"
    UNKNOWN = "unknown"  # default for a newly constructed Service


@dataclass
class Service:
    """A single registered service instance.

    ``metadata``, ``registered_at`` and ``last_heartbeat`` may be left as
    ``None``; ``__post_init__`` replaces them with an empty dict and the
    construction time respectively.
    """
    service_id: str
    name: str
    version: str
    address: str  # host name or IP literal; IPv6 may be given with or without brackets
    port: int
    status: ServiceStatus = ServiceStatus.UNKNOWN
    health_endpoint: str = "/health"  # path appended to ``url`` for health checks
    metadata: Optional[Dict] = None
    registered_at: Optional[datetime] = None
    last_heartbeat: Optional[datetime] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}
        # Read the clock once so both defaulted timestamps are identical.
        now = datetime.now()
        if self.registered_at is None:
            self.registered_at = now
        if self.last_heartbeat is None:
            self.last_heartbeat = now

    @property
    def url(self) -> str:
        """Return the base HTTP URL (``http://host:port``) of this service.

        IPv6 literal addresses are wrapped in brackets as required by
        RFC 3986; otherwise ``http://::1:8080`` would be unparseable.
        """
        host = self.address
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"http://{host}:{self.port}"


class ServiceRegistry:
    """In-memory service registry with periodic HTTP health checks.

    Services are stored by ``service_id``; ``service_instances`` maps each
    service name to the ids registered under it. While started, a background
    task polls every service's health endpoint every ``heartbeat_interval``
    seconds and marks a service DOWN when its check fails or when no
    successful check has been recorded for ``service_timeout`` seconds.
    """

    def __init__(self):
        self.services: Dict[str, Service] = {}
        self.service_instances: Dict[str, List[str]] = {}  # service_name -> [service_ids]
        self._running = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self.heartbeat_interval = 30  # seconds between health-check sweeps
        self.service_timeout = 90  # seconds without a successful check before marking DOWN

    def _remove_instance(self, service_name: str, service_id: str) -> None:
        """Drop ``service_id`` from ``service_name``'s instance list, pruning empty names."""
        instances = self.service_instances.get(service_name)
        if instances is None:
            return
        if service_id in instances:
            instances.remove(service_id)
        if not instances:
            # Avoid stale names lingering in service_instances / stats.
            del self.service_instances[service_name]

    def register_service(self, service: Service) -> bool:
        """Register (or re-register) a service instance.

        Re-registering an existing ``service_id`` replaces the stored
        ``Service``. If its name changed, the id is moved from the old name's
        instance list to the new one.

        Returns:
            Always ``True``.
        """
        previous = self.services.get(service.service_id)
        if previous is not None and previous.name != service.name:
            # Otherwise discover_service(old_name) would keep returning this id.
            self._remove_instance(previous.name, service.service_id)

        self.services[service.service_id] = service

        instances = self.service_instances.setdefault(service.name, [])
        if service.service_id not in instances:
            instances.append(service.service_id)

        logger.info(f"服务已注册: {service.name} ({service.service_id}) at {service.url}")
        return True

    def deregister_service(self, service_id: str) -> bool:
        """Remove a service instance.

        Returns:
            ``True`` if the service was registered and has been removed,
            ``False`` if ``service_id`` was unknown.
        """
        service = self.services.pop(service_id, None)
        if service is None:
            return False

        self._remove_instance(service.name, service_id)
        logger.info(f"服务已注销: {service_id}")
        return True

    def discover_service(self, service_name: str, version: Optional[str] = None) -> List[Service]:
        """Return the UP instances registered under ``service_name``.

        Args:
            service_name: Logical service name.
            version: If truthy, only instances with exactly this version are returned.
        """
        service_ids = self.service_instances.get(service_name)
        if not service_ids:
            return []

        services = [self.services[sid] for sid in service_ids if sid in self.services]

        if version:
            services = [s for s in services if s.version == version]

        # Only healthy instances are handed out.
        return [s for s in services if s.status == ServiceStatus.UP]

    def get_service(self, service_id: str) -> Optional[Service]:
        """Return the service with ``service_id``, or ``None`` if unknown."""
        return self.services.get(service_id)

    async def check_service_health(self, service_id: str) -> bool:
        """GET the service's health endpoint and update its status.

        A 200 response marks the service UP and refreshes ``last_heartbeat``.
        Any other status, or any error (including the 5 s timeout), marks it
        DOWN.

        Returns:
            ``True`` if healthy, ``False`` otherwise or if ``service_id`` is unknown.
        """
        service = self.services.get(service_id)
        if not service:
            return False

        try:
            url = f"{service.url}{service.health_endpoint}"
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5)) as session:
                async with session.get(url) as resp:
                    if resp.status == 200:
                        service.status = ServiceStatus.UP
                        service.last_heartbeat = datetime.now()
                        return True
                    service.status = ServiceStatus.DOWN
                    return False
        except Exception as e:
            logger.debug(f"服务健康检查失败: {service_id} - {e}")
            service.status = ServiceStatus.DOWN
            return False

    async def heartbeat_loop(self):
        """Run health-check sweeps every ``heartbeat_interval`` seconds while running."""
        while self._running:
            try:
                for service_id in list(self.services):
                    await self.check_service_health(service_id)

                    # The service may have been deregistered before or during the
                    # awaited check; skip it rather than abort the whole sweep.
                    service = self.services.get(service_id)
                    if service is None or not service.last_heartbeat:
                        continue

                    elapsed = (datetime.now() - service.last_heartbeat).total_seconds()
                    if elapsed > self.service_timeout:
                        logger.warning(f"服务超时: {service_id}")
                        service.status = ServiceStatus.DOWN
            except Exception as e:
                logger.error(f"心跳循环错误: {e}")
            await asyncio.sleep(self.heartbeat_interval)

    async def start(self):
        """Start the background heartbeat task; a no-op if it is already running."""
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            # Avoid spawning a second heartbeat task and orphaning the first.
            return
        self._running = True
        self._heartbeat_task = asyncio.create_task(self.heartbeat_loop())
        logger.info("服务注册中心已启动")

    async def stop(self):
        """Stop the heartbeat loop and wait for its task to finish cancelling."""
        self._running = False

        task, self._heartbeat_task = self._heartbeat_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    def get_all_services(self) -> List[Service]:
        """Return every registered service regardless of status."""
        return list(self.services.values())

    def get_service_stats(self) -> Dict:
        """Return totals, per-status counts and per-name instance counts."""
        return {
            'total_services': len(self.services),
            'by_status': {
                status.value: len([s for s in self.services.values() if s.status == status])
                for status in ServiceStatus
            },
            'by_name': {
                name: len(instances)
                for name, instances in self.service_instances.items()
            }
        }


class ServiceDiscovery:
    """Client-side discovery with simple load balancing over a ServiceRegistry."""

    def __init__(self, registry: ServiceRegistry):
        self.registry = registry
        self.load_balancer = "round_robin"  # round_robin, random, least_connections
        self._current_index: Dict[str, int] = {}

    def discover(self, service_name: str, version: Optional[str] = None) -> Optional[Service]:
        """Pick one healthy instance of ``service_name`` using ``self.load_balancer``.

        Returns:
            The chosen ``Service``, or ``None`` if no UP instance matches.
        """
        services = self.registry.discover_service(service_name, version)

        if not services:
            return None

        if self.load_balancer == "round_robin":
            # Per-name counter; the modulo maps it onto the current instance list.
            counter = self._current_index.get(service_name, 0)
            self._current_index[service_name] = counter + 1
            return services[counter % len(services)]

        if self.load_balancer == "random":
            import random
            return random.choice(services)

        # "least_connections" is not implemented (no connection counts are
        # tracked), and unknown strategies also fall back to the first instance.
        return services[0]

    async def call_service(self, service_name: str, endpoint: str,
                           method: str = "GET", data: Optional[Dict] = None,
                           version: Optional[str] = None,
                           timeout: Optional[float] = None) -> Optional[Dict]:
        """Call ``endpoint`` on a discovered instance and return its decoded JSON.

        Args:
            service_name: Logical service name to discover.
            endpoint: Path appended to the instance URL.
            method: GET, POST, PUT or DELETE (case-insensitive).
            data: JSON body; only sent for POST and PUT.
            version: Optional version filter passed to discovery.
            timeout: Optional total request timeout in seconds; aiohttp's
                default applies when ``None``.

        Returns:
            The decoded JSON response, or ``None`` if no instance was found,
            the method is unsupported, or the request/decoding failed.
        """
        verb = method.upper()
        if verb not in ("GET", "POST", "PUT", "DELETE"):
            # Previously this fell through silently and returned None.
            logger.error(f"不支持的HTTP方法: {method}")
            return None

        service = self.discover(service_name, version)

        if not service:
            logger.error(f"未找到服务: {service_name}")
            return None

        # Only POST/PUT carry a JSON body; json=None sends no body.
        body = data if verb in ("POST", "PUT") else None
        session_kwargs = {}
        if timeout is not None:
            session_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)

        try:
            url = f"{service.url}{endpoint}"
            async with aiohttp.ClientSession(**session_kwargs) as session:
                async with session.request(verb, url, json=body) as resp:
                    return await resp.json()
        except Exception as e:
            logger.error(f"调用服务失败: {service_name} - {e}")
            return None

