"""
任务队列 - 异步任务处理
"""
import asyncio
import inspect
import json
import uuid
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Callable

import redis.asyncio as aioredis
from loguru import logger


class TaskStatus(str, Enum):
    """Lifecycle state of a queued task.

    The ``str`` mix-in makes every member compare equal to its raw value,
    e.g. ``TaskStatus.PENDING == "pending"``.
    """

    # Freshly created / waiting on the pending list.
    PENDING = "pending"
    # A worker has picked the task up and is running its handler.
    PROCESSING = "processing"
    # The handler returned without raising.
    COMPLETED = "completed"
    # The handler raised and no attempts remain.
    FAILED = "failed"
    # The handler raised and the task is being re-enqueued for another attempt.
    RETRYING = "retrying"


@dataclass
class Task:
    """A unit of work carried through the task queue.

    ``created_at`` is stamped with ``datetime.now()`` (naive local time)
    when it is not supplied explicitly.
    """
    # Unique identifier of the task (TaskQueue.enqueue uses a uuid4 string).
    task_id: str
    # Key used to look up the registered handler.
    task_type: str
    # Argument passed as-is to the handler.
    payload: Dict
    # Current lifecycle state.
    status: TaskStatus = TaskStatus.PENDING
    # Creation time; defaulted in __post_init__ when left as None.
    created_at: Optional[datetime] = None
    # Set when a worker begins processing the task.
    started_at: Optional[datetime] = None
    # Set when the handler finishes successfully.
    completed_at: Optional[datetime] = None
    # Number of failed attempts so far.
    retry_count: int = 0
    # Total attempts allowed: once retry_count reaches this value the task is
    # marked FAILED instead of being re-enqueued.
    max_retries: int = 3
    # str() of the most recent processing failure, if any.
    error: Optional[str] = None

    def __post_init__(self):
        # Stamp the creation time unless the caller supplied one.
        if self.created_at is None:
            self.created_at = datetime.now()


class TaskQueue:
    """Redis-backed asynchronous task queue.

    Serialized tasks are pushed onto the ``{queue_name}:pending`` list;
    worker coroutines pop them, dispatch to the handler registered for the
    task's ``task_type`` and mirror each task's latest state under
    ``{queue_name}:task:{task_id}`` with a ``status_ttl``-second expiry.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379/1", queue_name: str = "nps_tasks",
                 status_ttl: int = 3600):
        """
        Args:
            redis_url: Connection URL passed to ``redis.asyncio.from_url``.
            queue_name: Prefix for every Redis key used by this queue.
            status_ttl: Expiry in seconds of the per-task status key
                (default 3600, i.e. one hour).
        """
        self.redis_url = redis_url
        self.queue_name = queue_name
        self.status_ttl = status_ttl
        self.redis: Optional[aioredis.Redis] = None
        self.workers: List[asyncio.Task] = []
        self.handlers: Dict[str, Callable] = {}
        self._running = False

    def _pending_key(self) -> str:
        """Redis list holding serialized tasks that await a worker."""
        return f"{self.queue_name}:pending"

    def _task_key(self, task_id: str) -> str:
        """Redis key mirroring the latest serialized state of one task."""
        return f"{self.queue_name}:task:{task_id}"

    @staticmethod
    def _serialize_task(task: Task) -> str:
        """Encode *task* as JSON, writing every timestamp in ISO-8601 form.

        ``default=str`` is kept so payload values json cannot encode natively
        are stringified rather than rejected.
        """
        def _iso(value: Optional[datetime]) -> Optional[str]:
            return value.isoformat() if value else None

        return json.dumps({
            **asdict(task),
            'created_at': _iso(task.created_at),
            'started_at': _iso(task.started_at),
            'completed_at': _iso(task.completed_at),
        }, default=str)

    async def connect(self):
        """Open the Redis connection and verify it with PING.

        Raises:
            Exception: whatever the Redis client raised; it is logged first.
        """
        try:
            self.redis = await aioredis.from_url(self.redis_url)
            await self.redis.ping()
            logger.info("任务队列Redis连接成功")
        except Exception as e:
            logger.error(f"任务队列Redis连接失败: {e}")
            raise

    async def disconnect(self):
        """Stop the workers, wait for them to exit, then close Redis.

        Workers check ``_running`` only between iterations, so this waits for
        any in-flight handler and retry back-off to finish.
        """
        self._running = False

        if self.workers:
            await asyncio.gather(*self.workers, return_exceptions=True)
            # Forget finished handles so stats and a later start_workers()
            # don't count dead workers.
            self.workers.clear()

        if self.redis:
            await self.redis.close()

    def register_handler(self, task_type: str, handler: Callable):
        """Register *handler* for *task_type* (replacing any previous one).

        The handler is called with the task payload; it may be a plain
        callable or return an awaitable, which is then awaited.
        """
        self.handlers[task_type] = handler
        logger.info(f"注册任务处理器: {task_type}")

    async def enqueue(self, task_type: str, payload: Dict,
                      max_retries: int = 3) -> str:
        """Create a task and publish it to the pending list.

        Args:
            task_type: Key of the handler that should process the task.
            payload: Argument handed to the handler.
            max_retries: Total attempts allowed before the task is FAILED.

        Returns:
            The generated task id (a uuid4 string).
        """
        task = Task(
            task_id=str(uuid.uuid4()),
            task_type=task_type,
            payload=payload,
            max_retries=max_retries,
        )
        task_data = self._serialize_task(task)

        # Write the status key BEFORE publishing: once the task is on the
        # list a worker may record PROCESSING/COMPLETED immediately, and a
        # later PENDING write would clobber that newer state.
        await self.redis.setex(self._task_key(task.task_id), self.status_ttl, task_data)
        await self.redis.lpush(self._pending_key(), task_data)

        logger.info(f"任务已入队: {task.task_id} ({task_type})")
        return task.task_id

    async def dequeue(self) -> Optional[Task]:
        """Pop the oldest pending task, waiting at most one second.

        Returns:
            The rebuilt ``Task``, or ``None`` if nothing arrived in time.
        """
        popped = await self.redis.brpop(
            self._pending_key(),
            timeout=1
        )
        if not popped:
            return None

        _, data = popped
        task_dict = json.loads(data)

        task = Task(
            task_id=task_dict['task_id'],
            task_type=task_dict['task_type'],
            payload=task_dict['payload'],
            status=TaskStatus(task_dict.get('status', 'pending')),
            retry_count=task_dict.get('retry_count', 0),
            max_retries=task_dict.get('max_retries', 3),
        )
        if task_dict.get('created_at'):
            task.created_at = datetime.fromisoformat(task_dict['created_at'])
        return task

    async def process_task(self, task: Task):
        """Run *task* through its handler and record the outcome.

        On success the task becomes COMPLETED. On failure ``retry_count`` is
        incremented; while it is still below ``max_retries`` the task is
        marked RETRYING, the worker sleeps ``2 ** retry_count`` seconds and
        the task is pushed back onto the pending list, otherwise it becomes
        FAILED. Handler exceptions are absorbed here; Redis errors propagate.
        """
        task.status = TaskStatus.PROCESSING
        task.started_at = datetime.now()
        await self._update_task_status(task)

        try:
            handler = self.handlers.get(task.task_type)
            if handler is None:
                raise ValueError(f"未找到任务处理器: {task.task_type}")
            # Call first, then await whatever comes back: this also covers
            # callable objects with an async __call__ and wrappers returning
            # coroutines, which a coroutine-function check can miss.
            # Synchronous handlers still run on (and block) the event loop.
            outcome = handler(task.payload)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as e:
            task.error = str(e)
            task.retry_count += 1

            if task.retry_count < task.max_retries:
                task.status = TaskStatus.RETRYING
                logger.warning(f"任务重试: {task.task_id} (尝试 {task.retry_count}/{task.max_retries})")
                # Persist RETRYING before re-enqueueing; writing it afterwards
                # could overwrite the state of whichever worker picks the
                # task up next.
                await self._update_task_status(task)
                # Exponential back-off. NOTE: this holds the current worker
                # for the whole delay.
                await asyncio.sleep(2 ** task.retry_count)
                await self.redis.lpush(self._pending_key(), self._serialize_task(task))
                return

            task.status = TaskStatus.FAILED
            logger.error(f"任务失败: {task.task_id} - {e}")
        else:
            task.status = TaskStatus.COMPLETED
            task.completed_at = datetime.now()
            logger.info(f"任务完成: {task.task_id}")

        await self._update_task_status(task)

    async def _update_task_status(self, task: Task):
        """Overwrite the task's status key with its current serialized state."""
        await self.redis.setex(
            self._task_key(task.task_id),
            self.status_ttl,
            self._serialize_task(task)
        )

    async def worker(self, worker_id: int):
        """Pop and process tasks until ``_running`` is cleared.

        Errors escaping one iteration (e.g. Redis failures) are logged and
        followed by a one-second pause so a persistent fault doesn't spin.
        """
        logger.info(f"Worker {worker_id} 启动")

        while self._running:
            try:
                task = await self.dequeue()
                if task is None:
                    # brpop already waited up to 1s; short pause before polling again.
                    await asyncio.sleep(0.1)
                    continue
                await self.process_task(task)
            except Exception as e:
                logger.error(f"Worker {worker_id} 错误: {e}")
                await asyncio.sleep(1)

        logger.info(f"Worker {worker_id} 停止")

    async def start_workers(self, num_workers: int = 5):
        """Spawn *num_workers* worker coroutines on the running event loop."""
        self._running = True

        for i in range(num_workers):
            self.workers.append(asyncio.create_task(self.worker(i)))

        logger.info(f"启动了 {num_workers} 个工作线程")

    async def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Return the stored state of *task_id*, or ``None`` if absent/expired."""
        task_data = await self.redis.get(self._task_key(task_id))
        if task_data:
            return json.loads(task_data)
        return None

    async def get_queue_stats(self) -> Dict:
        """Return the pending-list length, tracked worker count and handler keys."""
        pending_count = await self.redis.llen(self._pending_key())

        return {
            'pending': pending_count,
            'workers': len(self.workers),
            'handlers': list(self.handlers.keys()),
        }

