"""
异步任务调度器 - 增强异步处理能力
"""
import asyncio
import heapq
import inspect
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Callable, Dict, List, Optional

from loguru import logger


class TaskPriority(int, Enum):
    """Urgency level of a scheduled task: a smaller value means more urgent.

    Because the enum mixes in ``int``, members compare with each other and
    with plain integers directly.
    """
    # Declared from least to most urgent; iteration follows this order.
    LOW = 3       # least urgent
    NORMAL = 2
    HIGH = 1
    CRITICAL = 0  # most urgent


@dataclass
class ScheduledTask:
    """One unit of work held in the scheduler's heap.

    Instances are ordered by ``scheduled_time`` first and then by ``priority``,
    where a lower priority value wins. ``heapq`` therefore pops the earliest-due
    task, and among tasks due at the same instant it pops the most urgent one.
    """
    task_id: str
    func: Callable
    args: tuple = ()
    # default_factory gives each task its own dict. A bare ``{}`` default is
    # rejected by @dataclass with a ValueError when the class is created.
    kwargs: Dict = field(default_factory=dict)
    priority: TaskPriority = TaskPriority.NORMAL
    # None means "as soon as possible" and is replaced in __post_init__.
    scheduled_time: Optional[datetime] = None
    max_retries: int = 3
    retry_delay: float = 1.0

    def __post_init__(self):
        # An unset time would make __lt__ compare None with a datetime,
        # which raises TypeError inside heapq operations.
        if self.scheduled_time is None:
            self.scheduled_time = datetime.now()

    def __lt__(self, other):
        """Order by due time, then by priority value (smaller is more urgent)."""
        if not isinstance(other, ScheduledTask):
            return NotImplemented
        if self.scheduled_time != other.scheduled_time:
            return self.scheduled_time < other.scheduled_time
        return int(self.priority) < int(other.priority)


class AsyncScheduler:
    """Async task scheduler that orders work by due time and priority.

    One-shot tasks are pushed onto a heap (``task_queue``) and consumed by
    ``max_workers`` worker coroutines started in :meth:`start`. Periodic tasks
    run as independent asyncio tasks tracked in ``running_tasks``. No locks are
    used, so all calls are expected to come from the same event loop.
    """

    def __init__(self, max_workers: int = 10):
        self.max_workers = max_workers
        # Heap of pending one-shot tasks, maintained with heapq.heappush/heappop.
        self.task_queue: List[ScheduledTask] = []
        # Live periodic tasks keyed by task id; entries are removed when they finish.
        self.running_tasks: Dict[str, asyncio.Task] = {}
        # Outcome records appended by execute_task. This list is never trimmed.
        self.completed_tasks: List[Dict] = []
        self._running = False
        self._workers: List[asyncio.Task] = []
        # Not used by this class; kept so existing attribute access keeps working.
        self._scheduler_task: Optional[asyncio.Task] = None

    async def schedule_task(self, func: Callable, args: tuple = (), kwargs: Optional[Dict] = None,
                            priority: TaskPriority = TaskPriority.NORMAL,
                            delay: float = 0, max_retries: int = 3,
                            retry_delay: float = 1.0) -> str:
        """Queue ``func`` to run once, ``delay`` seconds from now.

        Args:
            func: Sync callable, coroutine function, or callable returning an awaitable.
            args: Positional arguments passed to ``func``.
            kwargs: Keyword arguments passed to ``func``; ``None`` means no keywords.
            priority: Tie-breaker among tasks that become due at the same time.
            delay: Seconds until the task becomes eligible to run.
            max_retries: Extra attempts allowed after the first failure.
            retry_delay: Base delay in seconds of the linear retry back-off.

        Returns:
            The generated task id, a UUID4 string.
        """
        task_id = str(uuid.uuid4())

        task = ScheduledTask(
            task_id=task_id,
            func=func,
            args=args,
            # A fresh dict per call instead of a shared mutable default.
            kwargs=kwargs if kwargs is not None else {},
            priority=priority,
            scheduled_time=datetime.now() + timedelta(seconds=delay),
            max_retries=max_retries,
            retry_delay=retry_delay,
        )

        heapq.heappush(self.task_queue, task)
        logger.info(f"任务已调度: {task_id} (优先级: {priority.name}, 延迟: {delay}秒)")

        return task_id

    async def schedule_periodic(self, func: Callable, interval: float,
                                args: tuple = (), kwargs: Optional[Dict] = None,
                                priority: TaskPriority = TaskPriority.NORMAL) -> str:
        """Call ``func`` every ``interval`` seconds while the scheduler is running.

        The first call happens as soon as the new asyncio task gets to run. The
        loop ends when ``stop`` clears the running flag or cancels the task.
        Exceptions raised by ``func`` are logged and do not end the loop.
        ``priority`` is accepted for symmetry with :meth:`schedule_task` but is
        not used.

        Returns:
            The generated task id, a UUID4 string.
        """
        task_id = str(uuid.uuid4())
        call_kwargs = kwargs if kwargs is not None else {}

        if not self._running:
            # The loop checks `_running` before its first call, so a task
            # scheduled before start() exits without ever running.
            logger.warning(f"调度器未运行，周期性任务不会执行: {task_id}")

        async def periodic_wrapper():
            while self._running:
                try:
                    outcome = func(*args, **call_kwargs)
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception as e:
                    logger.error(f"周期性任务执行失败: {task_id} - {e}")

                await asyncio.sleep(interval)

        task = asyncio.create_task(periodic_wrapper())
        self.running_tasks[task_id] = task
        # Remove the entry once the loop ends so get_stats() counts only live tasks.
        task.add_done_callback(lambda _t: self.running_tasks.pop(task_id, None))

        logger.info(f"周期性任务已启动: {task_id} (间隔: {interval}秒)")
        return task_id

    async def execute_task(self, task: ScheduledTask):
        """Run ``task`` with retries and record the outcome in ``completed_tasks``.

        ``task.func`` is attempted up to ``task.max_retries + 1`` times. A
        negative ``max_retries`` is treated as 0. After the k-th failed attempt
        the method sleeps ``task.retry_delay * k`` seconds, so the back-off
        grows linearly.

        Returns:
            The value produced by ``task.func``, awaited first if it is awaitable.

        Raises:
            Exception: whatever ``task.func`` raised on the final attempt.
        """
        attempts = max(task.max_retries, 0) + 1

        for attempt in range(1, attempts + 1):
            try:
                result = task.func(*task.args, **task.kwargs)
                # Await any awaitable result. This covers `async def` functions
                # as well as partials or lambdas that return a coroutine.
                if inspect.isawaitable(result):
                    result = await result
            except Exception as e:
                if attempt < attempts:
                    logger.warning(f"任务执行失败，重试中: {task.task_id} ({attempt}/{task.max_retries})")
                    await asyncio.sleep(task.retry_delay * attempt)  # linear back-off
                    continue

                self.completed_tasks.append({
                    'task_id': task.task_id,
                    'status': 'failed',
                    'completed_at': datetime.now(),
                    'error': str(e),
                })
                logger.error(f"任务执行失败: {task.task_id} - {e}")
                raise

            self.completed_tasks.append({
                'task_id': task.task_id,
                'status': 'completed',
                'completed_at': datetime.now(),
                'result': str(result)[:100],  # keep the stored summary short
            })
            logger.info(f"任务执行成功: {task.task_id}")
            return result

    async def worker(self, worker_id: int):
        """Worker loop: pop due tasks from the heap and execute them one at a time.

        The worker polls every 0.1 s while the queue is empty or the earliest
        task is not yet due. It runs until ``_running`` is cleared or the
        coroutine is cancelled.
        """
        logger.info(f"Worker {worker_id} 启动")

        while self._running:
            try:
                if not self.task_queue or self.task_queue[0].scheduled_time > datetime.now():
                    await asyncio.sleep(0.1)
                    continue

                # There is no await between the check and the pop, so no other
                # worker can take this task in between.
                task = heapq.heappop(self.task_queue)
                try:
                    await self.execute_task(task)
                except Exception:
                    # execute_task already logged and recorded the failure. Move
                    # on without the 1 s pause reserved for worker errors below.
                    continue

            except Exception as e:
                logger.error(f"Worker {worker_id} 错误: {e}")
                await asyncio.sleep(1)

        logger.info(f"Worker {worker_id} 停止")

    async def start(self):
        """Start ``max_workers`` worker coroutines. Calling it while already running does nothing."""
        if self._running:
            return
        self._running = True

        for i in range(self.max_workers):
            worker_task = asyncio.create_task(self.worker(i))
            self._workers.append(worker_task)

        logger.info(f"异步调度器已启动 ({self.max_workers}个工作线程)")

    async def stop(self):
        """Cancel all workers and periodic tasks and wait for them to finish.

        Tasks still waiting in ``task_queue`` are left there. A task that a
        worker is currently awaiting is cancelled together with that worker.
        """
        self._running = False

        if self._workers:
            for worker_task in self._workers:
                worker_task.cancel()
            await asyncio.gather(*self._workers, return_exceptions=True)
            self._workers.clear()

        periodic = list(self.running_tasks.values())
        for task in periodic:
            task.cancel()
        if periodic:
            # Wait for the cancellations to complete instead of leaving tasks pending.
            await asyncio.gather(*periodic, return_exceptions=True)
        self.running_tasks.clear()

        logger.info("异步调度器已停止")

    def get_stats(self) -> Dict:
        """Return counts of queued tasks, live periodic tasks, recorded outcomes and workers."""
        return {
            'queued_tasks': len(self.task_queue),
            'running_tasks': len(self.running_tasks),
            'completed_tasks': len(self.completed_tasks),
            'workers': len(self._workers),
        }


class BatchProcessor:
    """Collects items and hands them to a processor callable in batches.

    A batch is flushed when it reaches ``batch_size`` items, every
    ``flush_interval`` seconds while started, and once more on :meth:`stop`.
    """

    def __init__(self, batch_size: int = 100, flush_interval: float = 5.0):
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.batch: List[Dict] = []
        self.processor: Optional[Callable] = None
        self._running = False
        self._flush_task: Optional[asyncio.Task] = None
        # Created in start() inside the running loop. stop() sets it to wake the
        # auto-flush loop instead of cancelling a flush that is in progress.
        self._stop_event: Optional[asyncio.Event] = None

    def set_processor(self, processor: Callable):
        """Register the callable, sync or async, that receives each batch list."""
        self.processor = processor

    async def add_item(self, item: Dict):
        """Append ``item`` and flush right away once the batch reaches ``batch_size``."""
        self.batch.append(item)

        if len(self.batch) >= self.batch_size:
            await self.flush()

    async def flush(self):
        """Hand the current batch to the processor.

        Does nothing when the batch is empty or no processor is set; in that
        case the items are kept. The batch is detached before the processor
        runs, so items added while it awaits go into the next batch. If the
        processor raises, the error is logged and the detached batch is dropped.
        """
        if not self.batch or not self.processor:
            return

        batch_to_process = self.batch.copy()
        self.batch.clear()

        try:
            outcome = self.processor(batch_to_process)
            # Await any awaitable result, which also covers partials or lambdas
            # that return a coroutine.
            if inspect.isawaitable(outcome):
                await outcome

            logger.info(f"批次处理完成: {len(batch_to_process)}个项目")
        except Exception as e:
            logger.error(f"批次处理失败: {e}")

    async def _auto_flush_loop(self):
        """Flush every ``flush_interval`` seconds until stop() sets the stop event."""
        if self._stop_event is None:
            self._stop_event = asyncio.Event()
        stop_event = self._stop_event

        while self._running:
            try:
                # Returns early when stop() sets the event. Otherwise it times
                # out after flush_interval seconds.
                await asyncio.wait_for(stop_event.wait(), timeout=self.flush_interval)
            except asyncio.TimeoutError:
                pass
            await self.flush()

    async def start(self):
        """Start the periodic flush loop. Calling it while already running does nothing."""
        if self._running:
            return
        self._running = True
        self._stop_event = asyncio.Event()
        self._flush_task = asyncio.create_task(self._auto_flush_loop())
        logger.info("批处理器已启动")

    async def stop(self):
        """Stop the flush loop without interrupting a flush in progress, then flush what remains."""
        self._running = False

        if self._stop_event is not None:
            self._stop_event.set()

        if self._flush_task:
            # Waiting instead of cancelling keeps an in-flight processor call
            # from being aborted after its items were already detached. gather
            # still re-raises if stop() itself is cancelled.
            await asyncio.gather(self._flush_task, return_exceptions=True)
            self._flush_task = None

        # Process the remaining items.
        await self.flush()
        logger.info("批处理器已停止")

