"""
性能分析器
"""
import cProfile
import inspect
import io
import math
import pstats
import time
from collections import deque
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Dict, List, Optional

from loguru import logger


class PerformanceProfiler:
    """Collect wall-clock timings and cProfile reports for code sections.

    Timings live in an in-memory ``label -> seconds`` mapping; a new
    measurement under an existing label overwrites the previous value.
    """

    def __init__(self):
        # Not read by any method of this class; kept for backward compatibility.
        self._profiler: Optional[cProfile.Profile] = None
        # Most recent measured duration (seconds) per label.
        self._stats: Dict[str, float] = {}

    @contextmanager
    def profile(self, name: str = "default"):
        """Profile the enclosed block with cProfile and record its duration.

        On exit (normal or via exception) the elapsed time is stored under
        ``name`` and the top 20 functions sorted by cumulative time are
        logged at DEBUG level.

        Args:
            name: Label under which the duration is stored and reported.
        """
        profiler = cProfile.Profile()
        profiler.enable()

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # backwards/forwards when the system clock is adjusted.
        start_time = time.perf_counter()

        try:
            yield
        finally:
            profiler.disable()
            duration = time.perf_counter() - start_time

            self._stats[name] = duration

            # Build a text report of the 20 most expensive functions.
            report = io.StringIO()
            ps = pstats.Stats(profiler, stream=report)
            ps.sort_stats('cumulative')
            ps.print_stats(20)

            logger.debug(f"性能分析 [{name}]:\n{report.getvalue()}")

    def profile_function(self, name: Optional[str] = None):
        """Decorator factory that times every call of a sync or async function.

        The duration of the most recent call is stored under
        ``"<name>_async"`` for coroutine functions or ``"<name>_sync"``
        otherwise, and logged at DEBUG level. Exceptions raised by the wrapped
        function still propagate; the timing is recorded either way.

        Args:
            name: Label to use; defaults to the wrapped function's ``__name__``.
        """
        def decorator(func: Callable):
            func_name = name or func.__name__

            # inspect.iscoroutinefunction supersedes asyncio.iscoroutinefunction,
            # which is deprecated since Python 3.14.
            if inspect.iscoroutinefunction(func):
                @wraps(func)
                async def async_wrapper(*args, **kwargs):
                    start_time = time.perf_counter()
                    try:
                        return await func(*args, **kwargs)
                    finally:
                        self._record_call_duration(f"{func_name}_async", func_name, start_time)

                return async_wrapper

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                start_time = time.perf_counter()
                try:
                    return func(*args, **kwargs)
                finally:
                    self._record_call_duration(f"{func_name}_sync", func_name, start_time)

            return sync_wrapper

        return decorator

    def _record_call_duration(self, key: str, func_name: str, start_time: float) -> None:
        """Store the time elapsed since ``start_time`` under ``key`` and log it."""
        duration = time.perf_counter() - start_time
        self._stats[key] = duration
        logger.debug(f"函数 {func_name} 执行时间: {duration:.4f}s")

    def get_stats(self) -> Dict[str, float]:
        """Return a shallow copy of the ``label -> seconds`` mapping."""
        return self._stats.copy()

    def get_slowest_functions(self, limit: int = 10) -> List[tuple]:
        """Return up to ``limit`` ``(label, seconds)`` pairs, slowest first."""
        return sorted(self._stats.items(), key=lambda item: item[1], reverse=True)[:limit]

    def reset_stats(self):
        """Discard all recorded timings."""
        self._stats.clear()


class RequestProfiler:
    """Track per-endpoint request latencies over a sliding window.

    Only the most recent ``max_samples`` durations are retained per endpoint.
    """

    def __init__(self, max_samples: int = 1000):
        """
        Args:
            max_samples: Number of most recent durations kept per endpoint
                (default 1000).

        Raises:
            ValueError: If ``max_samples`` is less than 1.
        """
        if max_samples < 1:
            raise ValueError(f"max_samples must be >= 1, got {max_samples}")
        self._max_samples = max_samples
        # endpoint -> bounded deque of durations. The deque evicts the oldest
        # sample in O(1) instead of re-slicing a list on every request.
        self._request_times: Dict[str, deque] = {}

    def record_request(self, endpoint: str, duration: float):
        """Append ``duration`` to ``endpoint``'s window, evicting the oldest sample if full."""
        times = self._request_times.get(endpoint)
        if times is None:
            times = self._request_times[endpoint] = deque(maxlen=self._max_samples)
        times.append(duration)

    def get_endpoint_stats(self, endpoint: str) -> Optional[Dict]:
        """Summarize the retained durations for ``endpoint``.

        Returns:
            A dict with ``endpoint``, ``count``, ``avg``, ``min``, ``max``,
            ``p95`` and ``p99`` (nearest-rank), or ``None`` if the endpoint
            has no recorded samples.
        """
        times = self._request_times.get(endpoint)
        if not times:
            return None

        # Sort once and reuse it for min/max and both percentiles.
        ordered = sorted(times)
        count = len(ordered)

        return {
            'endpoint': endpoint,
            'count': count,
            'avg': sum(times) / count,
            'min': ordered[0],
            'max': ordered[-1],
            'p95': self._nearest_rank(ordered, 95),
            'p99': self._nearest_rank(ordered, 99),
        }

    def get_all_stats(self) -> Dict[str, Dict]:
        """Return ``get_endpoint_stats`` for every endpoint recorded so far."""
        return {endpoint: self.get_endpoint_stats(endpoint) for endpoint in self._request_times}

    def _percentile(self, data: list, percentile: int) -> float:
        """Return the nearest-rank ``percentile`` of ``data`` (need not be sorted).

        Raises:
            IndexError: If ``data`` is empty.
        """
        return self._nearest_rank(sorted(data), percentile)

    @staticmethod
    def _nearest_rank(sorted_data: list, percentile: float) -> float:
        """Nearest-rank percentile of already-sorted ``sorted_data``.

        The p-th percentile is the smallest sample with at least p% of the
        samples at or below it, i.e. the 1-based rank ``ceil(p/100 * n)``,
        clamped to ``[1, n]``. A plain ``int(n * p / 100)`` index would land
        one slot too high whenever ``n * p / 100`` is an integer (p95 of
        1..100 would give 96 instead of 95).

        Raises:
            IndexError: If ``sorted_data`` is empty.
        """
        n = len(sorted_data)
        rank = math.ceil(n * percentile / 100)
        return sorted_data[min(max(rank, 1), n) - 1]

