"""
压力测试
"""
import asyncio
import time
import random
from typing import List, Dict
from loguru import logger
import aiohttp
import websockets


class StressTest:
    """Load generator for an HTTP/WebSocket service.

    Drives fixed-rate HTTP traffic and concurrent WebSocket connections
    against ``base_url``. API request counters from every run accumulate in
    ``self.stats``; both API and WebSocket failures are appended to
    ``self.stats['errors']``.
    """

    def __init__(self, base_url: str = "http://localhost:8080"):
        """Create a tester targeting ``base_url``.

        Args:
            base_url: HTTP(S) root of the service, e.g. ``http://host:port``.
                The WebSocket root is derived by swapping only the scheme
                prefix (``http`` -> ``ws``, ``https`` -> ``wss``).
        """
        self.base_url = base_url
        # Rewrite only the leading scheme; a global str.replace would also
        # mangle any "http" substring later in the host or path.
        if base_url.startswith("http"):
            self.ws_url = "ws" + base_url[len("http"):]
        else:
            self.ws_url = base_url
        self.stats = {
            'total_requests': 0,    # API requests issued, summed over runs
            'success_requests': 0,  # API requests answered with HTTP 200
            'failed_requests': 0,   # non-200 responses plus raised exceptions
            'total_time': 0,        # wall-clock seconds spent in API runs
            'errors': [],           # error descriptions (API and WebSocket)
        }

    @staticmethod
    def _pct(part: float, whole: float) -> float:
        """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
        return part / whole * 100 if whole else 0.0

    async def stress_test_api(self, endpoint: str, method: str = "GET",
                             duration: int = 60, rate: int = 100):
        """Send requests to ``endpoint`` at roughly ``rate`` per second.

        A new request is launched every ``1 / rate`` seconds for ``duration``
        seconds, without waiting for earlier ones to finish. The method then
        waits for all in-flight requests, logs the per-run results and adds
        them to ``self.stats``.

        Args:
            endpoint: Path appended to ``base_url`` (e.g. ``"/api/health"``).
            method: HTTP method. ``POST`` sends an empty JSON object body;
                other methods send no body.
            duration: How long to keep issuing new requests, in seconds.
            rate: Target requests per second; must be positive.

        Raises:
            ValueError: If ``rate`` is not positive (the pacing loop would
                otherwise spin without ever yielding to the event loop).
        """
        if rate <= 0:
            raise ValueError(f"rate must be positive, got {rate}")

        logger.info(f"开始API压力测试: {method} {endpoint}")
        logger.info(f"持续时间: {duration}秒, 目标速率: {rate} QPS")

        url = f"{self.base_url}{endpoint}"
        request_kwargs = {"json": {}} if method == "POST" else {}
        interval = 1.0 / rate

        start_time = time.monotonic()
        request_count = 0
        success_count = 0
        error_count = 0

        async def make_request(session):
            """Issue one request and record its outcome in the counters."""
            nonlocal request_count, success_count, error_count
            request_count += 1
            try:
                async with session.request(method, url, **request_kwargs) as resp:
                    await resp.read()
                    status = resp.status
            except Exception as e:
                error_count += 1
                self.stats['errors'].append(f"{method} {endpoint}: {str(e)}")
                return
            if status == 200:
                success_count += 1
            else:
                error_count += 1
                self.stats['errors'].append(f"{method} {endpoint}: {status}")

        in_flight = set()
        async with aiohttp.ClientSession() as session:
            try:
                next_send = time.monotonic()
                while time.monotonic() - start_time < duration:
                    # create_task starts the request immediately; a bare
                    # coroutine would not run until awaited, collapsing the
                    # pacing into one burst per batch.
                    task = asyncio.create_task(make_request(session))
                    in_flight.add(task)
                    task.add_done_callback(in_flight.discard)
                    # Sleep until the next absolute slot so per-iteration
                    # overhead does not accumulate as drift.
                    next_send += interval
                    await asyncio.sleep(max(0.0, next_send - time.monotonic()))
                if in_flight:
                    await asyncio.gather(*in_flight, return_exceptions=True)
            finally:
                # On cancellation (e.g. the mixed-workload timeout), do not
                # leave requests running against a session about to close.
                leftovers = [t for t in in_flight if not t.done()]
                for t in leftovers:
                    t.cancel()
                if leftovers:
                    await asyncio.gather(*leftovers, return_exceptions=True)

        elapsed = time.monotonic() - start_time
        self.stats['total_requests'] += request_count
        self.stats['success_requests'] += success_count
        self.stats['failed_requests'] += error_count
        self.stats['total_time'] += elapsed

        qps = request_count / elapsed if elapsed > 0 else 0.0
        logger.info(f"压力测试结果:")
        logger.info(f"  总请求数: {request_count}")
        logger.info(f"  成功: {success_count}")
        logger.info(f"  失败: {error_count}")
        logger.info(f"  成功率: {self._pct(success_count, request_count):.2f}%")
        logger.info(f"  实际QPS: {qps:.2f}")

    async def stress_test_websocket(self, duration: int = 60,
                                    connections: int = 100):
        """Open ``connections`` concurrent WebSocket clients and hold them.

        Each client connects to ``{ws_url}/ws/client`` and keeps the
        connection open for ``duration`` seconds. Results are logged, and
        failures are appended to ``self.stats['errors']``. The request
        counters in ``self.stats`` are not updated by this method.

        Args:
            duration: Seconds each connection is held open.
            connections: Number of concurrent connections to open.
        """
        logger.info(f"开始WebSocket压力测试")
        logger.info(f"持续时间: {duration}秒, 并发连接数: {connections}")

        url = f"{self.ws_url}/ws/client"
        connection_count = 0
        success_count = 0
        error_count = 0

        async def connect_websocket():
            """Open one connection, hold it, and record the outcome."""
            nonlocal connection_count, success_count, error_count
            connection_count += 1
            try:
                async with websockets.connect(url):
                    success_count += 1
                    # Hold the connection; leaving the context manager closes it.
                    await asyncio.sleep(duration)
            except Exception as e:
                error_count += 1
                self.stats['errors'].append(f"WebSocket连接失败: {str(e)}")

        await asyncio.gather(*(connect_websocket() for _ in range(connections)),
                             return_exceptions=True)

        logger.info(f"WebSocket压力测试结果:")
        logger.info(f"  总连接数: {connection_count}")
        logger.info(f"  成功: {success_count}")
        logger.info(f"  失败: {error_count}")
        logger.info(f"  成功率: {self._pct(success_count, connection_count):.2f}%")

    async def stress_test_mixed_workload(self, duration: int = 300):
        """Run API and WebSocket load loops concurrently for ``duration`` seconds.

        The API worker repeatedly runs a 10 s / 50 QPS health-check test with
        5 s pauses. The WebSocket worker repeatedly holds 20 connections for
        30 s with 10 s pauses. Both loop forever and are cancelled by
        ``asyncio.wait_for`` once ``duration`` elapses. The summary logs the
        cumulative API counters from ``self.stats``; an API run interrupted
        by the timeout does not contribute to them.
        """
        logger.info(f"开始混合工作负载压力测试 (持续时间: {duration}秒)")

        async def api_worker():
            while True:
                await self.stress_test_api("/api/health", "GET", duration=10, rate=50)
                await asyncio.sleep(5)

        async def websocket_worker():
            while True:
                await self.stress_test_websocket(duration=30, connections=20)
                await asyncio.sleep(10)

        start_time = time.monotonic()
        try:
            await asyncio.wait_for(
                asyncio.gather(api_worker(), websocket_worker(),
                               return_exceptions=True),
                timeout=duration,
            )
        except asyncio.TimeoutError:
            # Expected: the workers never finish on their own.
            pass

        elapsed = time.monotonic() - start_time
        total = self.stats['total_requests']
        success = self.stats['success_requests']
        logger.info(f"混合工作负载测试完成 (实际运行: {elapsed:.2f}秒)")
        logger.info(f"总统计:")
        logger.info(f"  总请求数: {total}")
        logger.info(f"  成功: {success}")
        logger.info(f"  失败: {self.stats['failed_requests']}")
        logger.info(f"  成功率: {self._pct(success, total):.2f}%")

    def get_stats(self) -> Dict:
        """Return a snapshot of the accumulated statistics.

        The ``errors`` list is copied as well, so callers cannot mutate the
        tester's internal state through the returned dict.
        """
        snapshot = self.stats.copy()
        snapshot['errors'] = list(self.stats['errors'])
        return snapshot


async def run_stress_test(base_url: str = "http://localhost:8080", *,
                          include_mixed: bool = False):
    """Run the default stress-test suite against ``base_url``.

    Runs an API test (GET /api/health at 100 QPS for 60 s), then a WebSocket
    test (50 connections held for 60 s). The 300 s mixed-workload test runs
    afterwards only when ``include_mixed`` is true.

    Args:
        base_url: HTTP root of the service under test.
        include_mixed: Also run the mixed API/WebSocket workload.
    """
    stress_test = StressTest(base_url)

    # API stress test
    await stress_test.stress_test_api("/api/health", "GET", duration=60, rate=100)

    # WebSocket stress test
    await stress_test.stress_test_websocket(duration=60, connections=50)

    # Mixed workload (opt-in)
    if include_mixed:
        await stress_test.stress_test_mixed_workload(duration=300)


if __name__ == "__main__":
    # Script entry point: run the default suite on a fresh event loop.
    suite = run_stress_test()
    asyncio.run(suite)

