"""
Prometheus指标导出
"""
from typing import Optional

from loguru import logger
from prometheus_client import (
    REGISTRY,
    CollectorRegistry,
    Counter,
    Gauge,
    Histogram,
    generate_latest,
)


class PrometheusMetrics:
    """Container for the NPS Prometheus metrics plus thin recording helpers.

    All metrics are created and registered in ``registry`` when the instance
    is constructed. With the default (``registry=None``) that is
    prometheus_client's process-wide ``REGISTRY``, so only one instance can
    use the default per process: constructing a second one re-registers the
    same metric names, which prometheus_client rejects. Pass a dedicated
    ``CollectorRegistry`` (e.g. in tests) to create independent instances.
    """

    def __init__(self, registry: Optional["CollectorRegistry"] = None):
        """Create every metric and register it in ``registry``.

        Args:
            registry: Registry the metrics are registered in and that
                ``get_metrics`` exports. ``None`` selects the global
                ``REGISTRY``, which is the original behaviour.
        """
        self.registry = REGISTRY if registry is None else registry
        reg = self.registry

        # --- Connection metrics ---
        self.connections_total = Counter(
            'nps_connections_total',
            'Total number of connections',
            ['status'],
            registry=reg,
        )

        self.connections_active = Gauge(
            'nps_connections_active',
            'Number of active connections',
            registry=reg,
        )

        # --- Tunnel metrics ---
        # Only the status='total' series is written by set_tunnels().
        self.tunnels_total = Gauge(
            'nps_tunnels_total',
            'Total number of tunnels',
            ['status'],
            registry=reg,
        )

        self.tunnels_active = Gauge(
            'nps_tunnels_active',
            'Number of active tunnels',
            registry=reg,
        )

        # --- Traffic metrics ---
        # NOTE(review): client_id/tunnel_id labels create one series per
        # distinct value; confirm the set of ids is bounded.
        self.traffic_bytes = Counter(
            'nps_traffic_bytes_total',
            'Total traffic in bytes',
            ['direction', 'client_id', 'tunnel_id'],
            registry=reg,
        )

        # --- HTTP request metrics ---
        self.http_requests_total = Counter(
            'nps_http_requests_total',
            'Total HTTP requests',
            ['method', 'endpoint', 'status'],
            registry=reg,
        )

        self.http_request_duration = Histogram(
            'nps_http_request_duration_seconds',
            'HTTP request duration',
            ['method', 'endpoint'],
            registry=reg,
        )

        # --- Error metrics ---
        self.errors_total = Counter(
            'nps_errors_total',
            'Total number of errors',
            ['type', 'severity'],
            registry=reg,
        )

        # --- Authentication metrics ---
        self.auth_attempts_total = Counter(
            'nps_auth_attempts_total',
            'Total authentication attempts',
            ['status'],
            registry=reg,
        )

        # --- System metrics ---
        self.system_cpu_percent = Gauge(
            'nps_system_cpu_percent',
            'System CPU usage percentage',
            registry=reg,
        )

        self.system_memory_bytes = Gauge(
            'nps_system_memory_bytes',
            'System memory usage in bytes',
            ['type'],  # used, total, available
            registry=reg,
        )

        self.system_disk_bytes = Gauge(
            'nps_system_disk_bytes',
            'System disk usage in bytes',
            ['type', 'mountpoint'],  # used, total, available
            registry=reg,
        )

        # Counter: callers must pass per-interval deltas, not cumulative
        # totals, or the exported value double-counts.
        self.system_network_bytes = Counter(
            'nps_system_network_bytes_total',
            'System network traffic in bytes',
            ['direction', 'interface'],  # sent, received
            registry=reg,
        )

        # --- Rate-limit metrics ---
        # NOTE(review): client_key as a label may be high-cardinality
        # (e.g. if it is a client IP); verify against callers.
        self.rate_limit_hits_total = Counter(
            'nps_rate_limit_hits_total',
            'Total rate limit hits',
            ['endpoint', 'client_key'],
            registry=reg,
        )

        # --- Circuit-breaker metrics ---
        self.circuit_breaker_state = Gauge(
            'nps_circuit_breaker_state',
            'Circuit breaker state',
            ['name'],  # 0=closed, 1=open, 2=half_open
            registry=reg,
        )

        self.circuit_breaker_failures_total = Counter(
            'nps_circuit_breaker_failures_total',
            'Total circuit breaker failures',
            ['name'],
            registry=reg,
        )

        # --- Performance metrics ---
        self.request_queue_size = Gauge(
            'nps_request_queue_size',
            'Request queue size',
            registry=reg,
        )

        # NOTE(review): these Histograms are fed already-computed percentile
        # values, so they export a distribution of p95/p99 samples rather
        # than the current p95/p99. A Gauge (or histogram_quantile over
        # http_request_duration) is probably what was intended. Left as
        # Histograms so the exported series and attribute types don't change.
        self.response_time_p95 = Histogram(
            'nps_response_time_p95_seconds',
            '95th percentile response time',
            ['endpoint'],
            registry=reg,
        )

        self.response_time_p99 = Histogram(
            'nps_response_time_p99_seconds',
            '99th percentile response time',
            ['endpoint'],
            registry=reg,
        )

        # --- Security metrics ---
        self.security_blocked_ips_total = Counter(
            'nps_security_blocked_ips_total',
            'Total blocked IPs',
            ['reason'],  # blacklist, rate_limit, anomaly
            registry=reg,
        )

        self.security_anomalies_total = Counter(
            'nps_security_anomalies_total',
            'Total security anomalies detected',
            ['type'],  # ddos, brute_force, suspicious_pattern
            registry=reg,
        )

        # --- Quota metrics ---
        self.quota_usage = Gauge(
            'nps_quota_usage',
            'Quota usage percentage',
            ['client_id', 'quota_type'],  # traffic, connections, time
            registry=reg,
        )

        logger.info("Prometheus指标初始化完成")

    def record_connection(self, status: str = "connected"):
        """Increment the connection counter for ``status``."""
        self.connections_total.labels(status=status).inc()

    def set_active_connections(self, count: int):
        """Set the current number of active connections."""
        self.connections_active.set(count)

    def set_tunnels(self, total: int, active: int):
        """Set the tunnel gauges.

        Args:
            total: Written to ``nps_tunnels_total{status="total"}``.
            active: Written to ``nps_tunnels_active``.
        """
        self.tunnels_total.labels(status='total').set(total)
        self.tunnels_active.set(active)

    def record_traffic(self, bytes_count: int, direction: str, client_id: str = "", tunnel_id: str = ""):
        """Add ``bytes_count`` to the traffic counter.

        Empty ``client_id`` / ``tunnel_id`` are exported as ``"unknown"``.
        prometheus_client Counters reject negative increments (ValueError).
        """
        self.traffic_bytes.labels(
            direction=direction,
            client_id=client_id or "unknown",
            tunnel_id=tunnel_id or "unknown",
        ).inc(bytes_count)

    def record_http_request(self, method: str, endpoint: str, status_code: int, duration: float):
        """Count one HTTP request and observe its duration (seconds)."""
        self.http_requests_total.labels(
            method=method,
            endpoint=endpoint,
            status=str(status_code),
        ).inc()

        self.http_request_duration.labels(
            method=method,
            endpoint=endpoint,
        ).observe(duration)

    def record_error(self, error_type: str, severity: str = "medium"):
        """Increment the error counter for ``error_type`` / ``severity``."""
        self.errors_total.labels(type=error_type, severity=severity).inc()

    def record_auth_attempt(self, status: str):
        """Increment the authentication-attempt counter for ``status``."""
        self.auth_attempts_total.labels(status=status).inc()

    def set_system_metrics(self, cpu_percent: float, memory_used: int, memory_total: int, memory_available: int):
        """Set the CPU gauge and the used/total/available memory gauges."""
        self.system_cpu_percent.set(cpu_percent)
        self.system_memory_bytes.labels(type='used').set(memory_used)
        self.system_memory_bytes.labels(type='total').set(memory_total)
        self.system_memory_bytes.labels(type='available').set(memory_available)

    def set_disk_metrics(self, used: int, total: int, available: int, mountpoint: str = "/"):
        """Set the used/total/available disk gauges for ``mountpoint``."""
        self.system_disk_bytes.labels(type='used', mountpoint=mountpoint).set(used)
        self.system_disk_bytes.labels(type='total', mountpoint=mountpoint).set(total)
        self.system_disk_bytes.labels(type='available', mountpoint=mountpoint).set(available)

    def record_network_traffic(self, bytes_count: int, direction: str, interface: str = "eth0"):
        """Add ``bytes_count`` (a delta) to the network counter for ``interface``."""
        self.system_network_bytes.labels(direction=direction, interface=interface).inc(bytes_count)

    def record_rate_limit_hit(self, endpoint: str, client_key: str):
        """Increment the rate-limit-hit counter for ``endpoint`` / ``client_key``."""
        self.rate_limit_hits_total.labels(endpoint=endpoint, client_key=client_key).inc()

    def set_circuit_breaker_state(self, name: str, state: int):
        """Set circuit breaker ``name`` state (0=closed, 1=open, 2=half_open)."""
        self.circuit_breaker_state.labels(name=name).set(state)

    def record_circuit_breaker_failure(self, name: str):
        """Increment the failure counter for circuit breaker ``name``."""
        self.circuit_breaker_failures_total.labels(name=name).inc()

    def set_request_queue_size(self, size: int):
        """Set the current request queue size."""
        self.request_queue_size.set(size)

    def record_response_time_percentile(self, endpoint: str, p95: float, p99: float):
        """Observe precomputed p95/p99 response times (seconds) for ``endpoint``."""
        self.response_time_p95.labels(endpoint=endpoint).observe(p95)
        self.response_time_p99.labels(endpoint=endpoint).observe(p99)

    def record_blocked_ip(self, reason: str):
        """Increment the blocked-IP counter for ``reason``."""
        self.security_blocked_ips_total.labels(reason=reason).inc()

    def record_security_anomaly(self, anomaly_type: str):
        """Increment the security-anomaly counter for ``anomaly_type``."""
        self.security_anomalies_total.labels(type=anomaly_type).inc()

    def set_quota_usage(self, client_id: str, quota_type: str, usage_percent: float):
        """Set quota usage (percent) for ``client_id`` / ``quota_type``."""
        self.quota_usage.labels(client_id=client_id, quota_type=quota_type).set(usage_percent)

    def get_metrics(self) -> bytes:
        """Return this instance's registry in the Prometheus text exposition format."""
        return generate_latest(self.registry)

