"""
IP过滤（白名单/黑名单）
"""
from ipaddress import IPv4Network, IPv6Network, ip_address, ip_network
from typing import Optional, Set, Union

from loguru import logger


class IPFilter:
    """Allow or deny client IP addresses using a whitelist and a blacklist.

    Entries are stored as network objects, so a single address and a CIDR
    range are handled the same way, for both IPv4 and IPv6. The blacklist is
    enabled by default and is always evaluated before the whitelist; the
    whitelist is disabled by default.
    """

    def __init__(self):
        self._whitelist: Set[Union[IPv4Network, IPv6Network]] = set()
        self._blacklist: Set[Union[IPv4Network, IPv6Network]] = set()
        self._whitelist_enabled = False
        self._blacklist_enabled = True  # the blacklist is on by default

    @staticmethod
    def _parse_network(ip_or_cidr: str) -> Union[IPv4Network, IPv6Network]:
        """Convert an address or CIDR string into a network object.

        A bare address becomes a single-host network: /32 for IPv4 and /128
        for IPv6. Appending a hard-coded "/32" would turn one IPv6 address
        into an entire /32 range, so the prefix is left to ``ip_network``.
        Host bits in a CIDR string are masked off (``strict=False``).

        Raises:
            ValueError: if the string is not a valid address or network.
        """
        return ip_network(ip_or_cidr, strict=False)

    def _parse_or_log(
        self, ip_or_cidr: str
    ) -> Optional[Union[IPv4Network, IPv6Network]]:
        """Parse an entry; log an error and return None if it is invalid."""
        try:
            return self._parse_network(ip_or_cidr)
        except ValueError as e:
            logger.error(f"无效的IP地址或CIDR: {ip_or_cidr}, {e}")
            return None

    def add_whitelist(self, ip_or_cidr: str):
        """Add an address or CIDR range to the whitelist.

        Invalid input is logged and otherwise ignored.
        """
        network = self._parse_or_log(ip_or_cidr)
        if network is None:
            return
        self._whitelist.add(network)
        logger.info(f"添加白名单: {ip_or_cidr}")

    def remove_whitelist(self, ip_or_cidr: str):
        """Remove an address or CIDR range from the whitelist.

        Removing an entry that is not present is a no-op. Invalid input is
        logged and otherwise ignored.
        """
        network = self._parse_or_log(ip_or_cidr)
        if network is None:
            return
        self._whitelist.discard(network)
        logger.info(f"移除白名单: {ip_or_cidr}")

    def add_blacklist(self, ip_or_cidr: str):
        """Add an address or CIDR range to the blacklist.

        Invalid input is logged and otherwise ignored.
        """
        network = self._parse_or_log(ip_or_cidr)
        if network is None:
            return
        self._blacklist.add(network)
        logger.info(f"添加黑名单: {ip_or_cidr}")

    def remove_blacklist(self, ip_or_cidr: str):
        """Remove an address or CIDR range from the blacklist.

        Removing an entry that is not present is a no-op. Invalid input is
        logged and otherwise ignored.
        """
        network = self._parse_or_log(ip_or_cidr)
        if network is None:
            return
        self._blacklist.discard(network)
        logger.info(f"移除黑名单: {ip_or_cidr}")

    def is_allowed(self, ip: str) -> bool:
        """Return True if ``ip`` may access the service.

        Evaluation order:
          1. An unparsable address is rejected.
          2. If the blacklist is enabled and contains the address, reject.
          3. If the whitelist is disabled, allow.
          4. Otherwise allow only addresses contained in the whitelist; an
             enabled but empty whitelist rejects everything.
        """
        try:
            ip_obj = ip_address(ip)
        except ValueError:
            logger.warning(f"无效的IP地址: {ip}")
            return False

        # The blacklist takes precedence over the whitelist.
        if self._blacklist_enabled and any(
            ip_obj in network for network in self._blacklist
        ):
            logger.warning(f"IP {ip} 在黑名单中")
            return False

        if not self._whitelist_enabled:
            return True

        # Whitelist enabled but empty: reject everything (nothing is logged).
        if not self._whitelist:
            return False

        if any(ip_obj in network for network in self._whitelist):
            return True

        logger.warning(f"IP {ip} 不在白名单中")
        return False

    def get_whitelist(self) -> list:
        """Return the whitelist entries as CIDR strings (order unspecified)."""
        return [str(network) for network in self._whitelist]

    def get_blacklist(self) -> list:
        """Return the blacklist entries as CIDR strings (order unspecified)."""
        return [str(network) for network in self._blacklist]

    def enable_whitelist(self, enabled: bool = True):
        """Enable or disable whitelist enforcement."""
        self._whitelist_enabled = enabled
        logger.info(f"白名单{'启用' if enabled else '禁用'}")

    def enable_blacklist(self, enabled: bool = True):
        """Enable or disable blacklist enforcement."""
        self._blacklist_enabled = enabled
        logger.info(f"黑名单{'启用' if enabled else '禁用'}")

