"""
服务器选择器 - 智能选择最优服务器
"""
import asyncio
import time
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
import aiohttp


@dataclass
class ServerInfo:
    """Server information: configuration plus the latest health-check state.

    ``url``, ``name``, ``region`` and ``priority`` come from the configured
    server string; ``latency``, ``last_check``, ``status`` and ``error_count``
    are updated in place by ``ServerSelector``'s health checks.
    """
    url: str  # Base URL; the health check requests "{url}/api/health"
    name: str = ""
    region: str = ""
    priority: int = 0  # Priority; a larger number means higher priority
    latency: float = 0.0  # Latency (milliseconds)
    last_check: float = 0.0  # Last check time (a time.time() timestamp)
    status: str = "unknown"  # One of: unknown, online, offline, error
    error_count: int = 0  # Number of errors; reset to 0 on a successful check


class ServerSelector:
    """Select the best server from a configured list using HTTP health checks.

    Each server is probed with ``GET {url}/api/health``. Among servers whose
    latest probe returned HTTP 200, the one with the highest ``priority`` wins,
    ties broken by the lowest measured latency. If no server is online, the
    server with the fewest accumulated errors is returned as a fallback. The
    chosen server is cached for ``_check_interval`` seconds.
    """

    # Latency recorded for a failed probe, so a stale value from an earlier
    # successful probe is never reported as current.
    _FAILED_LATENCY_MS = 9999.0
    # Total timeout (seconds) for one health-check request.
    _HEALTH_TIMEOUT_S = 5

    def __init__(self, servers: List[str]):
        """
        Initialize the server selector.

        Args:
            servers: Server specs, formatted as "http://server1:8080" or
                "http://server1:8080|name|region|priority". Fields are
                whitespace-stripped, blank entries are ignored, and a missing
                or non-integer priority becomes 0.

        Raises:
            ValueError: If no non-blank server spec is given.
        """
        self.servers: List[ServerInfo] = []
        self._current_server: Optional[ServerInfo] = None
        self._check_interval = 60  # Seconds between full re-checks
        # time.monotonic() value of the last selection; monotonic so that
        # wall-clock adjustments cannot stall or force re-checks.
        self._last_check_time = 0.0

        for server_str in servers:
            parts = [part.strip() for part in server_str.split('|')]
            url = parts[0]
            if not url:
                # Skip empty entries (e.g. produced by a trailing separator).
                continue
            name = parts[1] if len(parts) > 1 else ""
            region = parts[2] if len(parts) > 2 else ""
            priority = self._parse_priority(parts[3]) if len(parts) > 3 else 0

            self.servers.append(ServerInfo(
                url=url,
                name=name,
                region=region,
                priority=priority,
            ))

        if not self.servers:
            raise ValueError("至少需要一个服务器")

    @staticmethod
    def _parse_priority(text: str) -> int:
        """Parse a priority field, falling back to 0 for anything non-integer."""
        # int() with EAFP instead of str.isdigit(): isdigit() rejects valid
        # values like "-1" and accepts characters like "²" that int() rejects.
        try:
            return int(text)
        except ValueError:
            return 0

    @staticmethod
    def _health_url(base_url: str) -> str:
        """Build the health-check URL, tolerating a trailing slash on *base_url*."""
        return f"{base_url.rstrip('/')}/api/health"

    async def select_best_server(self) -> Optional[ServerInfo]:
        """Return the best server, re-probing all servers at most once per interval.

        Returns:
            The selected server, or None if there is no server to choose from.
        """
        now = time.monotonic()

        # Reuse the cached choice while it is still fresh.
        if self._current_server is not None and (now - self._last_check_time) < self._check_interval:
            return self._current_server

        await self._check_all_servers()
        best_server = self._find_best_server()

        if best_server is not None:
            self._current_server = best_server
            self._last_check_time = now
            if best_server.status == "online":
                logger.info(f"选择服务器: {best_server.url} (延迟: {best_server.latency:.2f}ms)")
            else:
                # Nothing answered the health check; make the fallback visible.
                logger.warning(f"没有在线服务器，使用备用服务器: {best_server.url} (状态: {best_server.status})")

        return best_server

    async def _check_all_servers(self):
        """Probe every configured server concurrently."""
        # _check_server handles its own errors; return_exceptions=True is a
        # safety net so one unexpected failure cannot abort the other probes.
        await asyncio.gather(
            *(self._check_server(server) for server in self.servers),
            return_exceptions=True,
        )

    async def _check_server(self, server: ServerInfo):
        """Probe one server and update its state in place.

        Outcomes:
            HTTP 200          -> status "online", latency measured, error_count reset.
            other HTTP status -> status "error".
            timeout           -> status "offline".
            any other failure -> status "error" (logged at debug level).
        Every failure sets ``latency`` to a large sentinel and increments
        ``error_count``; ``last_check`` is updated after every attempt.
        """
        try:
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self._HEALTH_TIMEOUT_S)
            ) as session:
                start = time.perf_counter()
                async with session.get(self._health_url(server.url)) as response:
                    status_code = response.status
                    latency = (time.perf_counter() - start) * 1000  # Milliseconds
        except asyncio.TimeoutError:
            self._mark_failed(server, "offline")
        except Exception as e:
            logger.debug(f"检查服务器失败 {server.url}: {e}")
            self._mark_failed(server, "error")
        else:
            if status_code == 200:
                server.latency = latency
                server.status = "online"
                server.error_count = 0
            else:
                self._mark_failed(server, "error")
        finally:
            server.last_check = time.time()

    def _mark_failed(self, server: ServerInfo, status: str) -> None:
        """Record a failed probe: set *status*, the sentinel latency, and bump error_count."""
        server.status = status
        server.latency = self._FAILED_LATENCY_MS
        server.error_count += 1

    def _find_best_server(self) -> Optional[ServerInfo]:
        """Choose the best server from the most recent probe results.

        Online servers are ranked by highest priority, then lowest latency.
        If none is online, the server with the fewest errors is returned
        (earliest in configuration order on ties). Returns None only when
        there are no servers at all.
        """
        online_servers = [s for s in self.servers if s.status == "online"]
        if online_servers:
            return min(online_servers, key=lambda s: (-s.priority, s.latency))
        return min(self.servers, key=lambda s: s.error_count, default=None)

    def get_current_server(self) -> Optional[ServerInfo]:
        """Return the most recently selected server, or None before any selection."""
        return self._current_server

    def get_all_servers_status(self) -> List[Dict]:
        """Return a snapshot of every server's configuration and check state as dicts."""
        return [
            {
                'url': s.url,
                'name': s.name,
                'region': s.region,
                'priority': s.priority,
                'latency': s.latency,
                'status': s.status,
                'error_count': s.error_count,
                'last_check': s.last_check,
            }
            for s in self.servers
        ]

    async def test_connectivity(self, server_url: str) -> Dict:
        """Probe *server_url* once without modifying any stored server state.

        Returns:
            A dict with keys ``url``, ``success`` (True on HTTP 200),
            ``latency`` (milliseconds; 0.0 unless successful) and ``error``
            (None on success, otherwise a short description).
        """
        result = {
            'url': server_url,
            'success': False,
            'latency': 0.0,
            'error': None,
        }

        try:
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self._HEALTH_TIMEOUT_S)
            ) as session:
                start = time.perf_counter()
                async with session.get(self._health_url(server_url)) as response:
                    if response.status == 200:
                        result['success'] = True
                        result['latency'] = (time.perf_counter() - start) * 1000
                    else:
                        result['error'] = f"HTTP {response.status}"
        except asyncio.TimeoutError:
            result['error'] = "连接超时"
        except Exception as e:
            result['error'] = str(e)

        return result

