"""
密钥轮换 - 自动密钥轮换机制
"""
import hashlib
import json
import os
import secrets
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional

from loguru import logger


@dataclass
class KeyInfo:
    """A stored key together with its lifecycle metadata.

    ``key_value`` is excluded from the generated ``repr`` so that printing or
    logging a ``KeyInfo`` (or a list of them) does not leak the secret. It
    still participates in equality comparisons.
    """
    key_id: str  # identifier the key is stored under
    key_value: str = field(repr=False)  # the secret itself; hidden from repr
    created_at: datetime  # when the key was generated
    expires_at: Optional[datetime] = None  # None means the key has no expiry
    rotated_at: Optional[datetime] = None  # set when the key is rotated out
    active: bool = True  # False once the key has been rotated out or revoked


class KeyRotationManager:
    """Key rotation manager backed by a JSON file.

    Keys live in memory in ``self.keys`` (storage id -> ``KeyInfo``). The whole
    mapping is written back to ``key_store_file`` after every change. When a key
    is rotated, its previous value is kept under an archive id of the form
    ``<key_id>@<timestamp>`` so it remains reachable via ``get_key_history``.
    """

    # Separator between a logical key id and the timestamp suffix of its
    # archived (rotated-out) versions.
    ARCHIVE_SEPARATOR = "@"

    # Validity period used by generate_key when no explicit value is given.
    DEFAULT_VALIDITY_DAYS = 90

    def __init__(self, key_store_file: str = "keys/rotation_store.json"):
        self.key_store_file = key_store_file
        directory = os.path.dirname(key_store_file)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.keys: Dict[str, KeyInfo] = {}
        self._load_keys()

    @staticmethod
    def _key_info_from_dict(key_data: dict) -> KeyInfo:
        """Build a ``KeyInfo`` from one JSON record written by ``_key_info_to_dict``."""
        expires_at = key_data.get('expires_at')
        rotated_at = key_data.get('rotated_at')
        return KeyInfo(
            key_id=key_data['key_id'],
            key_value=key_data['key_value'],
            created_at=datetime.fromisoformat(key_data['created_at']),
            expires_at=datetime.fromisoformat(expires_at) if expires_at else None,
            rotated_at=datetime.fromisoformat(rotated_at) if rotated_at else None,
            active=key_data.get('active', True),
        )

    @staticmethod
    def _key_info_to_dict(key_info: KeyInfo) -> dict:
        """Serialize a ``KeyInfo`` into a JSON-compatible dict (ISO-8601 datetimes)."""
        return {
            'key_id': key_info.key_id,
            'key_value': key_info.key_value,
            'created_at': key_info.created_at.isoformat(),
            'expires_at': key_info.expires_at.isoformat() if key_info.expires_at else None,
            'rotated_at': key_info.rotated_at.isoformat() if key_info.rotated_at else None,
            'active': key_info.active,
        }

    def _load_keys(self):
        """Load keys from ``key_store_file`` into ``self.keys``.

        A missing file is not an error. A record that cannot be parsed is
        skipped with a warning, so one bad entry does not prevent the rest
        from loading. Any other failure (unreadable file, invalid JSON) is
        logged and loading stops.
        """
        try:
            if not os.path.exists(self.key_store_file):
                return
            with open(self.key_store_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for key_id, key_data in data.items():
                try:
                    self.keys[key_id] = self._key_info_from_dict(key_data)
                except (KeyError, TypeError, ValueError) as e:
                    logger.warning(f"跳过无法解析的密钥记录 {key_id}: {e}")
        except Exception as e:
            logger.error(f"加载密钥失败: {e}")

    def _save_keys(self):
        """Persist ``self.keys`` to ``key_store_file``.

        The data is first written to a temporary file in the same directory and
        then moved over the store with ``os.replace``. A crash mid-write
        therefore cannot leave a truncated store behind. ``tempfile.mkstemp``
        creates that file readable/writable only by the owner, which matters
        because key values are stored in plain text. Failures are logged, not
        raised.
        """
        tmp_path = None
        try:
            data = {
                key_id: self._key_info_to_dict(key_info)
                for key_id, key_info in self.keys.items()
            }
            directory = os.path.dirname(self.key_store_file) or '.'
            fd, tmp_path = tempfile.mkstemp(
                dir=directory,
                prefix=f".{os.path.basename(self.key_store_file)}.",
                suffix='.tmp',
            )
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.key_store_file)
            tmp_path = None  # ownership passed to key_store_file
        except Exception as e:
            logger.error(f"保存密钥失败: {e}")
        finally:
            # Clean up the temp file if anything failed before os.replace.
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def generate_key(self, key_id: str, length: int = 64,
                     validity_days: Optional[int] = None) -> str:
        """Create a new active key under ``key_id`` and persist it.

        Args:
            key_id: Id to store the key under; an existing entry with the same
                id is replaced.
            length: Number of random bytes handed to ``secrets.token_urlsafe``.
                The returned string is longer because it is base64url-encoded.
            validity_days: Days until the key expires. Defaults to
                ``DEFAULT_VALIDITY_DAYS`` (90).

        Returns:
            The new key value.
        """
        if validity_days is None:
            validity_days = self.DEFAULT_VALIDITY_DAYS
        now = datetime.now()  # single timestamp so created/expires are consistent
        key_value = secrets.token_urlsafe(length)

        self.keys[key_id] = KeyInfo(
            key_id=key_id,
            key_value=key_value,
            created_at=now,
            expires_at=now + timedelta(days=validity_days),
            active=True,
        )
        self._save_keys()

        logger.info(f"生成新密钥: {key_id}")
        return key_value

    def _archive_id(self, key_id: str, rotated_at: datetime) -> str:
        """Return an unused storage id for an archived version of ``key_id``."""
        base = f"{key_id}{self.ARCHIVE_SEPARATOR}{rotated_at.strftime('%Y%m%dT%H%M%S%f')}"
        candidate, counter = base, 1
        # Guard against two rotations landing on the same microsecond.
        while candidate in self.keys:
            candidate = f"{base}-{counter}"
            counter += 1
        return candidate

    def rotate_key(self, key_id: str, new_length: int = 64,
                   validity_days: Optional[int] = None) -> str:
        """Replace ``key_id`` with a freshly generated key.

        If a key already exists under ``key_id``, it is marked inactive and
        stamped with ``rotated_at``. It is then moved to an archive id
        (``<key_id>@<timestamp>``), so the old value is retained rather than
        overwritten by the new key.

        Returns:
            The new key value.
        """
        old_info = self.keys.pop(key_id, None)
        if old_info is not None:
            now = datetime.now()
            old_info.active = False
            old_info.rotated_at = now
            archive_id = self._archive_id(key_id, now)
            old_info.key_id = archive_id
            self.keys[archive_id] = old_info

        # generate_key also persists the archived entry added above.
        new_key = self.generate_key(key_id, new_length, validity_days)

        logger.info(f"密钥已轮换: {key_id}")
        return new_key

    def get_key(self, key_id: str) -> Optional[str]:
        """Return the value of ``key_id`` if it exists and is active (expiry is not checked)."""
        key_info = self.keys.get(key_id)
        if key_info is not None and key_info.active:
            return key_info.key_value
        return None

    def get_active_key(self, key_id: str) -> Optional[str]:
        """Return the value of ``key_id`` if it exists, is active and has not expired."""
        key_info = self.keys.get(key_id)
        if key_info is None or not key_info.active:
            return None
        if key_info.expires_at and datetime.now() > key_info.expires_at:
            logger.warning(f"密钥已过期: {key_id}")
            return None
        return key_info.key_value

    def check_expiring_keys(self, days_before: int = 7) -> List[str]:
        """Return ids of active keys expiring within ``days_before`` days (or already expired)."""
        threshold = datetime.now() + timedelta(days=days_before)
        return [
            key_id
            for key_id, key_info in self.keys.items()
            if key_info.active and key_info.expires_at and key_info.expires_at <= threshold
        ]

    def auto_rotate_expired_keys(self) -> List[str]:
        """Rotate every active key whose expiry time has passed.

        Returns:
            The ids of the keys that were rotated.
        """
        now = datetime.now()
        rotated = []

        # Iterate over a snapshot: rotate_key adds archive entries to self.keys.
        for key_id, key_info in list(self.keys.items()):
            if key_info.active and key_info.expires_at and now > key_info.expires_at:
                self.rotate_key(key_id)
                rotated.append(key_id)

        if rotated:
            logger.info(f"自动轮换了{len(rotated)}个过期密钥: {rotated}")

        return rotated

    def get_key_history(self, key_id: str) -> List[KeyInfo]:
        """Return the current entry for ``key_id`` plus its archived versions, newest first.

        Only ``key_id`` itself and ids of the form ``<key_id>@<timestamp>``
        match. A plain prefix match would also pull in unrelated keys such as
        ``jwt_secret`` when asking for ``jwt``.
        """
        archive_prefix = key_id + self.ARCHIVE_SEPARATOR
        history = [
            key_info
            for k_id, key_info in self.keys.items()
            if k_id == key_id or k_id.startswith(archive_prefix)
        ]
        return sorted(history, key=lambda x: x.created_at, reverse=True)

    def revoke_key(self, key_id: str):
        """Mark ``key_id`` inactive and persist the change; unknown ids are ignored."""
        key_info = self.keys.get(key_id)
        if key_info is not None:
            key_info.active = False
            self._save_keys()
            logger.info(f"密钥已撤销: {key_id}")


class SecretKeyRotator:
    """Rotator for a single named secret (e.g. a JWT signing key).

    Thin wrapper around a ``KeyRotationManager`` that always operates on
    ``current_key_id``.
    """

    def __init__(self, rotation_manager: KeyRotationManager, key_id: str = "jwt_secret"):
        """
        Args:
            rotation_manager: Manager that stores and rotates the key.
            key_id: Id of the managed secret; defaults to ``"jwt_secret"``.
        """
        self.rotation_manager = rotation_manager
        self.current_key_id = key_id

    def get_current_key(self) -> str:
        """Return the active, unexpired key, creating or replacing it when needed.

        If no key is stored yet, a new one is generated. If a key is stored but
        is expired or inactive, it is replaced through ``rotate_key`` rather
        than ``generate_key``. That way the manager marks the old entry
        inactive and stamps ``rotated_at`` instead of silently overwriting it.
        """
        key = self.rotation_manager.get_active_key(self.current_key_id)
        if key:
            return key
        if self.current_key_id in self.rotation_manager.keys:
            return self.rotation_manager.rotate_key(self.current_key_id)
        return self.rotation_manager.generate_key(self.current_key_id)

    def rotate_secret_key(self) -> str:
        """Rotate the managed key unconditionally and return the new value."""
        return self.rotation_manager.rotate_key(self.current_key_id)

    def check_and_rotate_if_needed(self, days_before_expiry: int = 7) -> bool:
        """Rotate the managed key if it expires within ``days_before_expiry`` days.

        Returns:
            True if a rotation was performed, False otherwise.
        """
        expiring = self.rotation_manager.check_expiring_keys(days_before_expiry)

        if self.current_key_id in expiring:
            logger.info(f"密钥即将过期，开始轮换: {self.current_key_id}")
            self.rotate_secret_key()
            return True

        return False

