"""
分布式追踪 - OpenTelemetry集成
"""
import time
import uuid
from contextlib import contextmanager
from typing import Dict, List, Optional

from loguru import logger


class Span:
    """A single timed operation within a trace.

    Wall-clock timestamps (``start_time`` / ``end_time``, from ``time.time()``)
    are kept for export. ``duration()`` is measured with ``time.monotonic()``
    instead, so it cannot jump or go negative if the system clock is adjusted
    while the span is open.
    """

    def __init__(self, trace_id: str, span_id: str, name: str, parent_id: Optional[str] = None):
        self.trace_id = trace_id
        self.span_id = span_id
        self.name = name
        self.parent_id = parent_id  # None for a root span
        self.start_time = time.time()
        self.end_time: Optional[float] = None
        # Monotonic readings used only for duration measurement (not exported).
        self._start_monotonic = time.monotonic()
        self._end_monotonic: Optional[float] = None
        self.tags: Dict[str, str] = {}
        self.events: List[Dict] = []
        self.status = "ok"
        self.error: Optional[str] = None

    def set_tag(self, key: str, value: str):
        """Set (or overwrite) the tag ``key`` to ``value``."""
        self.tags[key] = value

    def add_event(self, name: str, attributes: Optional[Dict] = None):
        """Append a wall-clock-timestamped event; ``attributes`` defaults to ``{}``."""
        self.events.append({
            'name': name,
            'timestamp': time.time(),
            'attributes': attributes or {}
        })

    def set_status(self, status: str, error: Optional[str] = None):
        """Set the span status.

        ``error`` is recorded only when it is non-empty; otherwise any
        previously recorded error message is kept.
        """
        self.status = status
        if error:
            self.error = error

    def finish(self):
        """Mark the span as ended (records both wall-clock and monotonic end times)."""
        self.end_time = time.time()
        self._end_monotonic = time.monotonic()

    def duration(self) -> float:
        """Return elapsed seconds; for an unfinished span, the time elapsed so far."""
        end = self._end_monotonic if self._end_monotonic is not None else time.monotonic()
        return end - self._start_monotonic


class Tracer:
    """In-process tracer that creates spans and keeps them in memory for lookup/export.

    Spans are stored in insertion order. Once more than ``max_spans`` are held,
    the oldest are evicted so a long-running server does not grow without bound.
    """

    def __init__(self, service_name: str = "nps-server", max_spans: Optional[int] = 10000):
        """
        Args:
            service_name: Value written to every span's ``service.name`` tag.
            max_spans: Maximum number of spans retained (oldest dropped first).
                ``None`` disables the limit.

        Raises:
            ValueError: If ``max_spans`` is not ``None`` and is less than 1.
        """
        if max_spans is not None and max_spans < 1:
            raise ValueError("max_spans must be a positive integer or None")
        self.service_name = service_name
        self.max_spans = max_spans
        self.spans: Dict[str, Span] = {}
        self._enabled = True

    def enable(self):
        """Enable tracing: subsequent ``start_span`` calls create spans."""
        self._enabled = True

    def disable(self):
        """Disable tracing: subsequent ``start_span`` calls return None."""
        self._enabled = False

    def start_span(self, name: str, parent_span_id: Optional[str] = None,
                   trace_id: Optional[str] = None) -> Optional[Span]:
        """Create, register and return a new span, or return None while disabled.

        If ``trace_id`` is None, a fresh UUID4 trace id is generated, making the
        span the root of a new trace.
        """
        if not self._enabled:
            return None

        if trace_id is None:
            trace_id = str(uuid.uuid4())

        span = Span(
            trace_id=trace_id,
            span_id=str(uuid.uuid4()),
            name=name,
            parent_id=parent_span_id,
        )
        span.set_tag("service.name", self.service_name)

        self.spans[span.span_id] = span
        self._evict_oldest()
        return span

    def _evict_oldest(self):
        """Drop the oldest stored spans until at most ``max_spans`` remain."""
        if self.max_spans is None:
            return
        while len(self.spans) > self.max_spans:
            # dicts preserve insertion order (Python 3.7+), so the first key is the oldest span
            del self.spans[next(iter(self.spans))]

    @contextmanager
    def span(self, name: str, parent_span_id: Optional[str] = None,
             trace_id: Optional[str] = None, tags: Optional[Dict[str, str]] = None):
        """Context manager that starts a span, yields it, and always finishes it.

        Yields None while tracing is disabled, so callers must handle that case.
        An exception escaping the body marks the span ``"error"`` with the
        exception text and is then re-raised.
        """
        span = self.start_span(name, parent_span_id, trace_id)

        if span is not None and tags:
            for key, value in tags.items():
                span.set_tag(key, value)

        try:
            yield span
        except Exception as e:
            if span is not None:
                span.set_status("error", str(e))
            raise
        finally:
            if span is not None:
                span.finish()

    def get_span(self, span_id: str) -> Optional[Span]:
        """Return the stored span with ``span_id``, or None if unknown or evicted."""
        return self.spans.get(span_id)

    def get_trace(self, trace_id: str) -> List[Span]:
        """Return all stored spans belonging to ``trace_id``, in creation order."""
        return [span for span in self.spans.values() if span.trace_id == trace_id]

    def export_trace(self, trace_id: str) -> Dict:
        """Serialize every stored span of ``trace_id`` into a plain dict."""
        spans = self.get_trace(trace_id)

        return {
            'trace_id': trace_id,
            'service_name': self.service_name,
            'spans': [
                {
                    'span_id': span.span_id,
                    'name': span.name,
                    'parent_id': span.parent_id,
                    'start_time': span.start_time,
                    'end_time': span.end_time,
                    'duration': span.duration(),
                    'tags': span.tags,
                    'events': span.events,
                    'status': span.status,
                    'error': span.error,
                }
                for span in spans
            ]
        }


class TraceMiddleware:
    """HTTP middleware that wraps each request in a tracing span.

    Reuses the caller's trace id from the ``X-Trace-Id`` request header when it
    is present and non-empty. When a span was created, echoes its trace id back
    in the response's ``X-Trace-Id`` header. Requests still pass through
    unchanged while the tracer is disabled.
    """

    def __init__(self, tracer: Tracer):
        self.tracer = tracer

    async def __call__(self, request, call_next):
        """Trace one request and return the downstream response.

        Args:
            request: Incoming request; presumably a Starlette/FastAPI-style
                object exposing ``method``, ``url.path``, ``url.query`` and
                ``headers`` — confirm against the framework in use.
            call_next: Async callable that forwards the request downstream.

        Exceptions from ``call_next`` propagate. The tracer's ``span`` context
        manager marks the span as errored before re-raising.
        """
        # An empty header value is treated like a missing one, so a new trace id is generated.
        trace_id = request.headers.get("X-Trace-Id") or None

        with self.tracer.span(
            name=f"{request.method} {request.url.path}",
            trace_id=trace_id,
            tags={
                "http.method": request.method,
                "http.path": str(request.url.path),
                "http.query": str(request.url.query),
            }
        ) as span:
            start = time.monotonic()  # monotonic: immune to wall-clock adjustments
            response = await call_next(request)

            # span is None while the tracer is disabled; pass the response through untouched.
            if span is not None:
                duration = time.monotonic() - start
                span.set_tag("http.status_code", str(response.status_code))
                span.set_tag("http.duration", f"{duration:.3f}s")

                # Propagate the trace id to the client.
                response.headers["X-Trace-Id"] = span.trace_id

            return response

