"""
消息协议定义
"""
from enum import Enum
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any
import json


class MessageType(str, Enum):
    """Message type identifiers.

    Mixes in ``str``, so each member compares equal to its plain string
    value (e.g. ``MessageType.HEARTBEAT == "heartbeat"``). The value is the
    string written to the JSON ``type`` field by ``Message.to_json``.
    """
    # Authentication
    AUTH_REQUEST = "auth_request"
    AUTH_RESPONSE = "auth_response"
    
    # Heartbeat
    HEARTBEAT = "heartbeat"
    HEARTBEAT_RESPONSE = "heartbeat_response"
    
    # Tunnel lifecycle and payload
    TUNNEL_CREATE = "tunnel_create"
    TUNNEL_CREATE_RESPONSE = "tunnel_create_response"
    TUNNEL_CLOSE = "tunnel_close"
    TUNNEL_DATA = "tunnel_data"
    
    # Control messages
    CONTROL = "control"
    CONTROL_RESPONSE = "control_response"
    
    # Errors
    ERROR = "error"


@dataclass
class Message:
    """Base class for protocol messages.

    A message serializes to a JSON object with exactly the keys ``type``,
    ``id``, ``data`` and ``timestamp``.
    """
    type: MessageType
    id: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
    timestamp: Optional[float] = None

    def to_json(self) -> str:
        """Serialize this message to a JSON string.

        ``type`` may be a ``MessageType`` member or its plain string value
        (e.g. ``"heartbeat"``); both serialize identically. A ``None``
        ``data`` is written as ``{}``.

        Raises:
            ValueError: if ``type`` is not a known ``MessageType`` value.
        """
        # The dataclass does not coerce fields, so ``type`` may be a plain str.
        # MessageType(...) returns members unchanged and looks up raw values,
        # avoiding an AttributeError on ``str.value``.
        return json.dumps({
            'type': MessageType(self.type).value,
            'id': self.id,
            'data': self.data or {},
            'timestamp': self.timestamp,
        })

    @classmethod
    def from_json(cls, json_str: str) -> 'Message':
        """Parse a JSON object (as produced by ``to_json``) into a message.

        Missing ``id``, ``data`` and ``timestamp`` keys become ``None``.
        Note that ``data`` is ``None`` (not ``{}``) when absent or null.
        The instance is built with ``cls(type=..., id=..., data=...,
        timestamp=...)``, so this only works on classes that accept those
        keyword arguments.

        Raises:
            json.JSONDecodeError: if ``json_str`` is not valid JSON.
            TypeError: if the decoded JSON value is not an object.
            KeyError: if the ``type`` key is missing.
            ValueError: if ``type`` is not a known ``MessageType`` value.
        """
        obj = json.loads(json_str)
        # Reject lists/strings/null up front with a clear message instead of
        # an opaque indexing TypeError on ``obj['type']``.
        if not isinstance(obj, dict):
            raise TypeError(
                f"message JSON must be an object, got {obj.__class__.__name__}"
            )
        return cls(
            type=MessageType(obj['type']),
            id=obj.get('id'),
            data=obj.get('data'),
            timestamp=obj.get('timestamp'),
        )


@dataclass
class AuthRequest(Message):
    """Authentication request message.

    The credentials are stored in ``data`` under the keys ``token``,
    ``client_id`` and ``version``; the message type is
    ``MessageType.AUTH_REQUEST``.
    """
    def __init__(self, token: str, client_id: str, version: str = "1.0"):
        # Build the payload first; key order matches the wire layout.
        payload = dict(token=token, client_id=client_id, version=version)
        super().__init__(data=payload, type=MessageType.AUTH_REQUEST)


@dataclass
class AuthResponse(Message):
    """Authentication response message.

    The outcome is stored in ``data`` under the keys ``success``,
    ``message`` and ``client_id``; the message type is
    ``MessageType.AUTH_RESPONSE``.
    """
    def __init__(self, success: bool, message: str = "", client_id: str = ""):
        # Assemble the payload separately; key order matches the wire layout.
        body = dict(success=success, message=message, client_id=client_id)
        super().__init__(data=body, type=MessageType.AUTH_RESPONSE)


class TunnelData(Message):
    """Tunnel payload message.

    The raw bytes travel base64-encoded in ``data['data']`` so they survive
    JSON serialization; ``data['tunnel_id']`` identifies the tunnel.
    """
    def __init__(self, tunnel_id: str, data: bytes):
        import base64
        super().__init__(
            type=MessageType.TUNNEL_DATA,
            data={
                'tunnel_id': tunnel_id,
                'data': base64.b64encode(data).decode('utf-8'),
            }
        )
        self.tunnel_id = tunnel_id
        # Keep the unencoded bytes so get_data() can skip a base64 decode.
        self._raw_data = data

    def get_data(self) -> bytes:
        """Return the raw payload bytes.

        Prefers the cached bytes set by ``__init__``; otherwise decodes
        ``data['data']``. Returns ``b''`` when no payload is present.
        """
        import base64
        # _raw_data is absent only if the instance was created without
        # running __init__.
        if hasattr(self, '_raw_data'):
            return self._raw_data
        if self.data and 'data' in self.data:
            return base64.b64decode(self.data['data'])
        return b''

    @classmethod
    def from_message(cls, message: Message) -> 'TunnelData':
        """Build a TunnelData from a generic ``Message``.

        Reads ``tunnel_id`` and the base64 ``data`` field from
        ``message.data``, and carries over ``message.id`` and
        ``message.timestamp``. A message whose ``data`` is ``None`` or
        lacks these keys yields an empty tunnel id and an empty payload.
        """
        import base64
        # Message.from_json leaves ``data`` as None when the key is absent
        # or null; treat that like an empty payload instead of crashing.
        payload = message.data or {}
        tunnel_id = payload.get('tunnel_id', '')
        data_str = payload.get('data', '')
        raw = base64.b64decode(data_str) if data_str else b''
        result = cls(tunnel_id=tunnel_id, data=raw)
        # The constructor does not accept envelope metadata; copy it over so
        # the conversion does not silently drop it.
        result.id = message.id
        result.timestamp = message.timestamp
        return result

