"""
日志审计 - 安全日志审计功能
"""
import os
import json
import hashlib
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from enum import Enum
from loguru import logger
from pathlib import Path


class AuditEventType(str, Enum):
    """Category of an audited event.

    Members mix in ``str``, so each member compares equal to its lowercase
    string value; ``.value`` is the string persisted with each audit record.
    """
    # Recorded by AuditLogger.log_authentication
    AUTHENTICATION = "authentication"
    # Recorded by AuditLogger.log_authorization
    AUTHORIZATION = "authorization"
    # Recorded by AuditLogger.log_data_access
    DATA_ACCESS = "data_access"
    # Recorded by AuditLogger.log_config_change
    CONFIG_CHANGE = "config_change"
    # Recorded by AuditLogger.log_security_event
    SECURITY_EVENT = "security_event"
    # No dedicated helper; pass to AuditLogger.log_event directly
    SYSTEM_EVENT = "system_event"


@dataclass
class AuditLogEntry:
    """A single audit log record.

    ``details`` holds free-form extra data for the event. It defaults to a
    fresh empty dict per instance, and an explicit ``details=None`` is also
    normalised to ``{}``.
    """
    event_id: str
    event_type: AuditEventType
    timestamp: datetime
    user_id: Optional[str] = None
    client_ip: Optional[str] = None
    action: str = ""
    resource: str = ""
    # Usually "success" / "failure"; AuditLogger.log_security_event stores the
    # severity here instead.
    status: str = ""
    # Annotated Optional because the default is None (replaced in __post_init__).
    details: Optional[Dict] = None

    def __post_init__(self):
        # Using None as the default (instead of `= {}`) avoids one dict being
        # shared by every instance; substitute a fresh dict here.
        if self.details is None:
            self.details = {}


class AuditLogger:
    """Security audit logger backed by daily JSON-lines files.

    Each event is appended as one JSON object per line to
    ``<log_dir>/audit_YYYYMMDD.log``, where the date is taken from the event's
    own timestamp. :meth:`query_logs` and :meth:`get_statistics` read those
    files back. A one-line human-readable summary of every event is also
    emitted through loguru and mirrored into ``<log_dir>/audit_summary.log``.
    """

    # loguru handlers are process-wide. Remember the handler id registered for
    # each audit directory so that constructing several AuditLogger instances
    # does not add duplicate sinks (and therefore duplicate summary lines).
    _summary_sink_ids: Dict[str, int] = {}

    def __init__(self, log_dir: str = "logs/audit"):
        """Create the audit directory if needed and register the summary sink.

        Args:
            log_dir: Directory that holds the audit files.
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # JSON store for today; _write_log_entry keeps this pointing at the
        # file of the most recently written event.
        self.log_file = self._log_file_for(datetime.now())
        self._setup_logger()

    def _log_file_for(self, day) -> Path:
        """Return the JSON-lines file for *day* (a ``date`` or ``datetime``)."""
        return self.log_dir / f"audit_{day.strftime('%Y%m%d')}.log"

    def _setup_logger(self):
        """Register a loguru file sink for audit summary lines.

        The sink deliberately does NOT share the JSON store. loguru's rotation
        renames and zip-compresses the file it owns, which would hide
        already-written events from query_logs. Its plain-text lines would also
        be interleaved with the JSON records. The filter admits only records
        bound with ``audit=True``, so ordinary application logs stay out of the
        audit directory.
        """
        from loguru import logger as loguru_logger

        key = str(self.log_dir.resolve())
        if key in AuditLogger._summary_sink_ids:
            return

        AuditLogger._summary_sink_ids[key] = loguru_logger.add(
            str(self.log_dir / "audit_summary.log"),
            rotation="1 day",
            retention="90 days",
            compression="zip",
            format="{time} | {level} | {message}",
            level="INFO",
            filter=lambda record: bool(record["extra"].get("audit")),
        )

    def _generate_event_id(self) -> str:
        """Return a 16-hex-character event id.

        The id is the first 16 hex digits of the SHA-256 of the current
        timestamp concatenated with 8 random bytes from ``os.urandom``.
        """
        timestamp = datetime.now().isoformat()
        random_str = os.urandom(8).hex()
        return hashlib.sha256(f"{timestamp}{random_str}".encode()).hexdigest()[:16]

    def log_event(self, event_type: AuditEventType, action: str,
                  user_id: Optional[str] = None, client_ip: Optional[str] = None,
                  resource: str = "", status: str = "success",
                  details: Optional[Dict] = None) -> str:
        """Record one audit event.

        The event is appended to the JSON store and a summary line is emitted
        through loguru.

        Args:
            event_type: Category of the event.
            action: What was done (e.g. "login").
            user_id: Acting user, if known.
            client_ip: Originating address, if known.
            resource: Target of the action.
            status: Outcome; defaults to "success".
            details: Extra data stored with the record (copied as-is).

        Returns:
            The generated event id.
        """
        event_id = self._generate_event_id()

        entry = AuditLogEntry(
            event_id=event_id,
            event_type=event_type,
            timestamp=datetime.now(),
            user_id=user_id,
            client_ip=client_ip,
            action=action,
            resource=resource,
            status=status,
            details=details or {}
        )

        # Persist to the JSON store (read back by query_logs).
        self._write_log_entry(entry)

        # bind(audit=True) lets this line pass the audit summary sink's filter.
        logger.bind(audit=True).info(
            f"[AUDIT] {event_type.value} | {action} | {status} | {user_id} | {client_ip}"
        )

        return event_id

    def _write_log_entry(self, entry: AuditLogEntry):
        """Append *entry* as one JSON line to the file for the entry's date.

        The file is chosen from ``entry.timestamp``, so events after midnight
        land in the next day's file even in a long-running process. Failures
        are logged rather than raised, so auditing never breaks the caller.
        """
        try:
            log_data = {
                'event_id': entry.event_id,
                'event_type': entry.event_type.value,
                'timestamp': entry.timestamp.isoformat(),
                'user_id': entry.user_id,
                'client_ip': entry.client_ip,
                'action': entry.action,
                'resource': entry.resource,
                'status': entry.status,
                'details': entry.details
            }
            # default=str keeps the record when details contain values JSON
            # cannot encode natively (datetime, Path, ...) instead of dropping
            # the whole audit event.
            line = json.dumps(log_data, ensure_ascii=False, default=str)

            log_file = self._log_file_for(entry.timestamp)
            with open(log_file, 'a', encoding='utf-8') as f:
                f.write(line + '\n')
            self.log_file = log_file
        except Exception as e:
            logger.bind(audit=True).error(f"写入审计日志失败: {e}")

    def log_authentication(self, user_id: str, client_ip: str, status: str,
                           details: Optional[Dict] = None):
        """Record an authentication attempt (action "login"); return the event id."""
        return self.log_event(
            AuditEventType.AUTHENTICATION,
            "login",
            user_id=user_id,
            client_ip=client_ip,
            status=status,
            details=details
        )

    def log_authorization(self, user_id: str, action: str, resource: str,
                          status: str, client_ip: Optional[str] = None,
                          details: Optional[Dict] = None):
        """Record an authorization decision for *action* on *resource*; return the event id."""
        return self.log_event(
            AuditEventType.AUTHORIZATION,
            action,
            user_id=user_id,
            client_ip=client_ip,
            resource=resource,
            status=status,
            details=details
        )

    def log_data_access(self, user_id: str, action: str, resource: str,
                        client_ip: Optional[str] = None, details: Optional[Dict] = None):
        """Record a data access with status "success"; return the event id."""
        return self.log_event(
            AuditEventType.DATA_ACCESS,
            action,
            user_id=user_id,
            client_ip=client_ip,
            resource=resource,
            status="success",
            details=details
        )

    def log_config_change(self, user_id: str, action: str, resource: str,
                          client_ip: Optional[str] = None, details: Optional[Dict] = None):
        """Record a configuration change with status "success"; return the event id."""
        return self.log_event(
            AuditEventType.CONFIG_CHANGE,
            action,
            user_id=user_id,
            client_ip=client_ip,
            resource=resource,
            status="success",
            details=details
        )

    def log_security_event(self, event_type: str, severity: str,
                           client_ip: Optional[str] = None, details: Optional[Dict] = None):
        """Record a security event; return the event id.

        *event_type* is stored as the record's ``action`` and *severity* as its
        ``status``. No user id is recorded.
        """
        return self.log_event(
            AuditEventType.SECURITY_EVENT,
            event_type,
            client_ip=client_ip,
            status=severity,
            details=details
        )

    def query_logs(self, start_time: Optional[datetime] = None,
                   end_time: Optional[datetime] = None,
                   event_type: Optional[AuditEventType] = None,
                   user_id: Optional[str] = None,
                   client_ip: Optional[str] = None,
                   limit: int = 100) -> List[Dict]:
        """Return stored audit records matching all given filters.

        Daily files are scanned in date order and then line order, so the
        oldest matching records come first. The scan runs from the date of
        ``start_time`` (default: 7 days ago) through the date of ``end_time``
        (default: today).

        Args:
            start_time: Inclusive lower bound on the record timestamp.
            end_time: Inclusive upper bound on the record timestamp.
            event_type: Exact-match filter. A falsy value disables it.
            user_id: Exact-match filter. A falsy value disables it.
            client_ip: Exact-match filter. A falsy value disables it.
            limit: Maximum number of records to return.

        Returns:
            Up to ``limit`` record dicts as parsed from the JSON lines. Lines
            that are not well-formed audit records are skipped. On an
            unexpected error, the error is logged and the records collected so
            far are returned.
        """
        results: List[Dict] = []
        if limit <= 0:
            return results

        today = datetime.now().date()
        start_date = start_time.date() if start_time else today - timedelta(days=7)
        end_date = end_time.date() if end_time else today

        try:
            current_date = start_date
            while current_date <= end_date:
                log_file = self._log_file_for(current_date)

                if log_file.exists():
                    with open(log_file, 'r', encoding='utf-8') as f:
                        for line in f:
                            try:
                                entry = json.loads(line)
                                entry_time = datetime.fromisoformat(entry['timestamp'])
                            except (ValueError, KeyError, TypeError):
                                # Blank line, foreign text, truncated write or a
                                # record without a usable timestamp: skip it
                                # instead of aborting the whole query.
                                continue

                            if start_time and entry_time < start_time:
                                continue
                            if end_time and entry_time > end_time:
                                continue
                            if event_type and entry.get('event_type') != event_type.value:
                                continue
                            if user_id and entry.get('user_id') != user_id:
                                continue
                            if client_ip and entry.get('client_ip') != client_ip:
                                continue

                            results.append(entry)
                            if len(results) >= limit:
                                return results

                current_date += timedelta(days=1)

        except Exception as e:
            logger.error(f"查询审计日志失败: {e}")

        return results

    def get_statistics(self, days: int = 7, max_events: int = 10000) -> Dict:
        """Summarise audit events from the last *days* days.

        Args:
            days: Length of the look-back window ending now.
            max_events: Maximum number of events read (oldest first). Counts
                cover only the events actually read.

        Returns:
            A dict with the following keys:
            - ``total_events``
            - ``by_type``, ``by_status`` and ``by_user``: count dicts.
            - ``failed_authentications``: authentication records whose status
              is "failure".
            - ``security_events``: number of security-event records.
        """
        end_time = datetime.now()
        start_time = end_time - timedelta(days=days)

        logs = self.query_logs(start_time=start_time, end_time=end_time, limit=max_events)

        stats = {
            'total_events': len(logs),
            'by_type': {},
            'by_status': {},
            'by_user': {},
            'failed_authentications': 0,
            'security_events': 0,
        }

        for log in logs:
            event_type = log.get('event_type', 'unknown')
            stats['by_type'][event_type] = stats['by_type'].get(event_type, 0) + 1

            status = log.get('status', 'unknown')
            stats['by_status'][status] = stats['by_status'].get(status, 0) + 1

            user_id = log.get('user_id')
            if user_id:
                stats['by_user'][user_id] = stats['by_user'].get(user_id, 0) + 1

            if event_type == 'authentication' and status == 'failure':
                stats['failed_authentications'] += 1

            if event_type == 'security_event':
                stats['security_events'] += 1

        return stats

