"""
异常检测器
"""
import time
from typing import Dict, List, Optional
from collections import defaultdict, deque
from dataclasses import dataclass
from datetime import datetime
from loguru import logger


@dataclass
class AnomalyEvent:
    """A single anomaly event held in the detector's event history."""
    type: str  # event category identifier, e.g. 'failed_auth'
    severity: str  # low, medium, high, critical
    message: str  # human-readable description of the event
    timestamp: float  # same scale as time.time(), which the detector compares it against
    metadata: Optional[Dict] = None  # optional extra context for the event


class AnomalyDetector:
    """In-memory anomaly detector for connection and authentication activity.

    Keeps a bounded history of recorded events plus per-client and per-IP
    connection timestamps, and checks them against per-minute thresholds.
    """

    # Maximum number of events kept in the history.
    _EVENT_MAXLEN = 10000
    # Maximum number of connection timestamps kept per client / per IP.
    _BEHAVIOR_MAXLEN = 1000
    # Length, in seconds, of the window used by the frequency checks.
    _RATE_WINDOW_SECONDS = 60
    # Length, in seconds, of the window used by the per-client statistics.
    _STATS_WINDOW_SECONDS = 3600

    def __init__(self):
        """Create an empty detector with the default thresholds."""
        # Only the most recent events are retained; older ones are dropped.
        self._events: deque = deque(maxlen=self._EVENT_MAXLEN)
        self._client_behaviors: Dict[str, deque] = defaultdict(
            lambda: deque(maxlen=self._BEHAVIOR_MAXLEN)
        )
        self._ip_behaviors: Dict[str, deque] = defaultdict(
            lambda: deque(maxlen=self._BEHAVIOR_MAXLEN)
        )
        self._thresholds = {
            'connection_frequency': 100,  # max connections per minute
            'failed_auth': 5,  # max failed authentications per minute
            'request_frequency': 1000,  # max requests per minute
        }

    @classmethod
    def _prune_expired(cls, store: Dict[str, deque], key: str, window_start: float) -> None:
        """Drop timestamps at or before ``window_start`` for ``key``, if tracked."""
        # ``.get`` avoids creating an empty entry in the defaultdict.
        timestamps = store.get(key)
        if timestamps is None:
            return
        store[key] = deque(
            (ts for ts in timestamps if ts > window_start),
            maxlen=cls._BEHAVIOR_MAXLEN,
        )

    def record_event(self, event_type: str, severity: str, message: str,
                     metadata: Optional[Dict] = None):
        """Append an event to the history.

        Args:
            event_type: Event category identifier.
            severity: One of ``low``, ``medium``, ``high``, ``critical``.
            message: Human-readable description.
            metadata: Optional extra context; a falsy value is stored as an
                empty dict.

        Events with severity ``high`` or ``critical`` are also logged as
        warnings.
        """
        event = AnomalyEvent(
            type=event_type,
            severity=severity,
            message=message,
            timestamp=time.time(),
            metadata=metadata or {}
        )
        self._events.append(event)

        if severity in ('high', 'critical'):
            logger.warning(f"异常事件 [{severity}]: {message}")

    def check_connection_frequency(self, client_id: str, ip: str) -> bool:
        """Check a new connection against the per-minute connection threshold.

        Args:
            client_id: Identifier of the connecting client.
            ip: Address the connection comes from.

        Returns:
            ``True`` if the connection is accepted (and its timestamp is
            recorded for both the client and the IP); ``False`` if the client
            or the IP is over the threshold, in which case a ``high`` severity
            event is recorded and the connection is not counted.
        """
        current_time = time.time()
        window_start = current_time - self._RATE_WINDOW_SECONDS
        limit = self._thresholds['connection_frequency']

        # Discard timestamps that fell out of the window.
        self._prune_expired(self._client_behaviors, client_id, window_start)
        self._prune_expired(self._ip_behaviors, ip, window_start)

        # NOTE(review): the comparison is strict and happens before the new
        # connection is counted, so ``limit + 1`` connections are admitted per
        # window. Kept as-is for backward compatibility.
        client_count = len(self._client_behaviors[client_id])
        if client_count > limit:
            self.record_event(
                'high_connection_frequency',
                'high',
                f"客户端 {client_id} 连接频率过高: {client_count}/分钟",
                {'client_id': client_id, 'count': client_count}
            )
            return False

        ip_count = len(self._ip_behaviors[ip])
        if ip_count > limit:
            self.record_event(
                'high_ip_connection_frequency',
                'high',
                f"IP {ip} 连接频率过高: {ip_count}/分钟",
                {'ip': ip, 'count': ip_count}
            )
            return False

        # Accepted: count this connection for both the client and the IP.
        self._client_behaviors[client_id].append(current_time)
        self._ip_behaviors[ip].append(current_time)

        return True

    def check_failed_auth(self, client_id: str, ip: str) -> bool:
        """Check recent failed authentications for a client or IP.

        Counts ``failed_auth`` events from the last minute whose metadata
        matches either ``client_id`` or ``ip``.

        Returns:
            ``True`` if the count does not exceed the ``failed_auth``
            threshold; otherwise records a ``critical`` event and returns
            ``False``.
        """
        window_start = time.time() - self._RATE_WINDOW_SECONDS

        failed_count = 0
        for event in self._events:
            if event.type != 'failed_auth' or event.timestamp <= window_start:
                continue
            # Metadata is optional on the event type; treat None as empty.
            meta = event.metadata or {}
            if meta.get('client_id') == client_id or meta.get('ip') == ip:
                failed_count += 1

        if failed_count > self._thresholds['failed_auth']:
            self.record_event(
                'too_many_failed_auth',
                'critical',
                f"失败认证次数过多: {failed_count}次/分钟",
                {'client_id': client_id, 'ip': ip, 'count': failed_count}
            )
            return False

        return True

    def record_failed_auth(self, client_id: str, ip: str, reason: str):
        """Record a ``medium`` severity ``failed_auth`` event for a client/IP."""
        self.record_event(
            'failed_auth',
            'medium',
            f"认证失败: {reason}",
            {'client_id': client_id, 'ip': ip, 'reason': reason}
        )

    def get_recent_anomalies(self, minutes: int = 60, severity: Optional[str] = None) -> List[Dict]:
        """Return recent events as dicts, newest first.

        Args:
            minutes: Size of the look-back window in minutes.
            severity: If given, only events with exactly this severity are
                returned.

        Returns:
            A list of dicts with keys ``type``, ``severity``, ``message``,
            ``timestamp`` (ISO-8601 string) and ``metadata``.
        """
        cutoff_time = time.time() - (minutes * 60)

        events = [
            event for event in self._events
            if event.timestamp > cutoff_time and
            (severity is None or event.severity == severity)
        ]
        events.sort(key=lambda item: item.timestamp, reverse=True)

        return [
            {
                'type': event.type,
                'severity': event.severity,
                'message': event.message,
                'timestamp': datetime.fromtimestamp(event.timestamp).isoformat(),
                'metadata': event.metadata or {},
            }
            for event in events
        ]

    def get_client_behavior_stats(self, client_id: str) -> Dict:
        """Return connection and failed-auth counts for a client.

        Both counts use a one-hour window. Note that
        ``check_connection_frequency`` prunes connection timestamps older
        than one minute whenever it runs for this client, so ``connections``
        may cover less than the full hour.

        Returns:
            A dict with integer ``connections`` and ``failed_auth`` counts.
            Failed authentications are counted even if the client has no
            recorded connections.
        """
        window_start = time.time() - self._STATS_WINDOW_SECONDS

        # ``.get`` keeps unknown clients from being added to the defaultdict.
        timestamps = self._client_behaviors.get(client_id, ())
        connections = sum(1 for ts in timestamps if ts > window_start)

        failed_auth = sum(
            1 for event in self._events
            if event.type == 'failed_auth' and
            event.timestamp > window_start and
            (event.metadata or {}).get('client_id') == client_id
        )

        return {
            'connections': connections,
            'failed_auth': failed_auth,
        }

