import ast
import hashlib
import re
import string
from collections import OrderedDict

from src.errors import SecurityViolationError
from src.import_validation import validate_module_import
from src.config.security_config import SecurityConfig
from src.constants import (
    MAX_VALIDATION_CACHE_SIZE,
    ERROR_RELATIVE_IMPORT,
    ERROR_DANGEROUS_NAME,
    ERROR_DANGEROUS_ATTRIBUTE,
    ERROR_NAME_MANGLED_ATTRIBUTE,
    ERROR_DYNAMIC_IMPORT,
    ERROR_DANGEROUS_STRING_PATTERN,
    ERROR_MATCH_PATTERN_ATTRIBUTE,
    BLOCKED_ATTRIBUTES,
    BLOCKED_NAMES,
)

# Key under which a validation result is cached: the SHA-256 hex digest of the
# code paired with the allowlists tuple of the analyzer that validated it.
CacheKey = tuple[str, tuple]  # (code_hash, allowlists_tuple)
# Violation messages recorded for one validated piece of code (empty when clean).
CachedViolations = list[str]
# Ordered mapping from cache key to violations; the ordering drives eviction.
ValidationCache = OrderedDict[CacheKey, CachedViolations]

# Matches a brace-delimited span such as "{0.attr}"; group 1 captures the text
# between the opening brace and the first closing brace that follows it.
FORMAT_FIELD_PATTERN = re.compile(r"\{([^}]*)\}")


class SecurityValidator(ast.NodeVisitor):
    """AST visitor that enforces import allowlists and blocks dangerous attribute access.

    Findings are collected in ``self.violations`` as strings of the form
    ``"Line <lineno>: <message>"``. ``self.checked_modules`` remembers the
    top-level module names that have already been validated by this instance.
    """

    # Lookups that may appear inside a format replacement field:
    # ``.name`` attribute access and ``[name]`` / ``['name']`` / ``["name"]`` subscripts.
    _FORMAT_ATTRIBUTE_PATTERN = re.compile(r"\.(\w+)")
    _FORMAT_SUBSCRIPT_PATTERN = re.compile(r"\[(['\"]?)(\w+)\1\]")

    def __init__(self, security_config: SecurityConfig):
        """Create a validator that checks imports against ``security_config``."""
        self.checked_modules: set[str] = set()
        self.violations: list[str] = []
        self.security_config = security_config

    # ========== Detection ==========

    def visit_Import(self, node: ast.Import) -> None:
        """Detect bare import statements (e.g., import os), including aliased (e.g., import numpy as np)."""

        for alias in node.names:
            module_name = alias.name
            self._validate_import(module_name, node.lineno)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Detect from import statements (e.g., from os import path).

        Any relative import (``level > 0``) is a violation; otherwise the
        source module is validated like a bare import.
        """

        if node.level > 0:
            self._add_violation(node.lineno, ERROR_RELATIVE_IMPORT)
        elif node.module:
            self._validate_import(node.module, node.lineno)

        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Detect references to identifiers listed in ``BLOCKED_NAMES``."""

        if node.id in BLOCKED_NAMES:
            self._add_violation(node.lineno, ERROR_DANGEROUS_NAME.format(name=node.id))

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Detect access to unsafe attributes that could bypass security restrictions.

        Two independent checks are applied: membership in
        ``BLOCKED_ATTRIBUTES``, and a name-mangled shape such as
        ``_Class__attr``.
        """

        if node.attr in BLOCKED_ATTRIBUTES:
            self._add_violation(
                node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
            )

        if node.attr.startswith("_") and "__" in node.attr:
            # Split at the first "__". Dunder names such as "__init__" yield an
            # empty prefix and are not flagged here; "_Class__attr" yields the
            # prefix "_Class" and is flagged.
            prefix = node.attr.split("__", 1)[0]
            if prefix.startswith("_"):
                self._add_violation(node.lineno, ERROR_NAME_MANGLED_ATTRIBUTE)

        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Detect calls to __import__() that could bypass security restrictions.

        A call whose first positional argument is a string literal is
        validated like an import of that module; any other argument shape is
        reported as a dynamic import.
        """

        is_import_call = (
            # __import__()
            (isinstance(node.func, ast.Name) and node.func.id == "__import__")
            or
            # builtins.__import__() or __builtins__.__import__()
            (
                isinstance(node.func, ast.Attribute)
                and node.func.attr == "__import__"
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id in {"builtins", "__builtins__"}
            )
        )

        if is_import_call:
            if (
                node.args
                and isinstance(node.args[0], ast.Constant)
                and isinstance(node.args[0].value, str)
            ):
                module_name = node.args[0].value
                self._validate_import(module_name, node.lineno)
            else:
                self._add_violation(node.lineno, ERROR_DYNAMIC_IMPORT)

        self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript) -> None:
        """Detect dict access to blocked attributes, e.g. __builtins__['__spec__']"""

        is_builtins_access = (
            # __builtins__['__spec__']
            (
                isinstance(node.value, ast.Name)
                and node.value.id in {"__builtins__", "builtins"}
            )
            # obj.__builtins__['__spec__']
            or (
                isinstance(node.value, ast.Attribute)
                and node.value.attr in {"__builtins__", "builtins"}
            )
        )

        # Only string-literal keys can be checked statically.
        if (
            is_builtins_access
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            key = node.slice.value
            if key in BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=key)
                )

        self.generic_visit(node)

    def visit_Constant(self, node: ast.Constant) -> None:
        """Detect string constants containing dangerous format patterns."""

        if isinstance(node.value, str):
            self._check_format_string(node.value, node.lineno)

        self.generic_visit(node)

    def visit_MatchClass(self, node: ast.MatchClass) -> None:
        """Detect match patterns that extract blocked attributes, e.g. `case AttributeError(obj=x)`"""

        for attr in node.kwd_attrs:
            if attr in BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_MATCH_PATTERN_ATTRIBUTE.format(attr=attr)
                )

        self.generic_visit(node)

    def _check_format_string(self, s: str, lineno: int) -> None:
        """Check if a string contains format patterns that access blocked attributes.

        Every ``.name`` and ``[name]`` lookup found inside a replacement field
        (including fields nested in a format spec, e.g.
        ``"{0:{1.__class__}}"``) is compared against ``BLOCKED_ATTRIBUTES``
        and ``BLOCKED_NAMES``; each hit records one violation at ``lineno``.
        """

        for field in self._format_fields(s):
            # attribute access
            for attr_match in self._FORMAT_ATTRIBUTE_PATTERN.finditer(field):
                attr = attr_match.group(1)
                if attr in BLOCKED_ATTRIBUTES or attr in BLOCKED_NAMES:
                    self._add_violation(
                        lineno, ERROR_DANGEROUS_STRING_PATTERN.format(attr=attr)
                    )

            # subscript access
            for subscript_match in self._FORMAT_SUBSCRIPT_PATTERN.finditer(field):
                key = subscript_match.group(2)
                if key in BLOCKED_ATTRIBUTES or key in BLOCKED_NAMES:
                    self._add_violation(
                        lineno, ERROR_DANGEROUS_STRING_PATTERN.format(attr=key)
                    )

    @staticmethod
    def _format_fields(s: str) -> list[str]:
        """Return the text of each replacement field found in ``s``.

        The string is split with ``string.Formatter().parse`` so that escaped
        braces (``{{`` / ``}}``) are treated as literals while a ``}}`` that
        closes a nested field is not mistaken for an escape. Each returned
        item is ``"<field_name>:<format_spec>"`` so lookups hidden in a nested
        format spec are scanned as well.

        Strings the parser rejects fall back to a conservative regex scan of
        both the raw text and the text with doubled braces removed.
        """

        try:
            parsed = list(string.Formatter().parse(s))
        except ValueError:
            # Not a well-formed format string; over-approximate instead of
            # giving up, since leading fields may still be evaluated.
            stripped = s.replace("{{", "").replace("}}", "")
            candidates = FORMAT_FIELD_PATTERN.findall(s)
            candidates += FORMAT_FIELD_PATTERN.findall(stripped)
            return list(dict.fromkeys(candidates))  # de-duplicate, keep order

        return [
            f"{field_name}:{format_spec or ''}"
            for _, field_name, format_spec, _ in parsed
            if field_name is not None  # None marks a literal-only segment
        ]

    # ========== Validation ==========

    def _validate_import(self, module_path: str, lineno: int) -> None:
        """Validate that a module import is allowed based on allowlists. Also disallow relative imports.

        Each top-level module name is validated at most once per validator
        instance; later imports sharing that top-level name are skipped.
        """

        if module_path.startswith("."):
            self._add_violation(lineno, ERROR_RELATIVE_IMPORT)
            return

        module_name = module_path.split(".")[0]  # e.g., os.path -> os

        if module_name in self.checked_modules:
            return

        self.checked_modules.add(module_name)

        is_allowed, error_msg = validate_module_import(
            module_path, self.security_config
        )

        if not is_allowed:
            assert error_msg is not None
            self._add_violation(lineno, error_msg)

    def _add_violation(self, lineno: int, message: str) -> None:
        """Record ``message`` as a violation, prefixed with its line number."""
        self.violations.append(f"Line {lineno}: {message}")


class TaskAnalyzer:
    """Validate code against a security configuration and cache the outcome.

    Results are stored in a class-level ordered cache keyed by the SHA-256
    digest of the code together with the analyzer's allowlists.
    """

    # Class attribute: the same mapping is used by every analyzer instance.
    _cache: ValidationCache = OrderedDict()

    def __init__(self, security_config: SecurityConfig):
        """Remember the config and precompute the allowlist-derived values."""
        self._security_config = security_config

        stdlib_sorted = tuple(sorted(security_config.stdlib_allow))
        external_sorted = tuple(sorted(security_config.external_allow))
        self._allowlists = (stdlib_sorted, external_sorted)

        # Validation is skipped entirely when both allowlists hold the "*" wildcard.
        self._allow_all = (
            "*" in security_config.stdlib_allow
            and "*" in security_config.external_allow
        )

    def validate(self, code: str) -> None:
        """Raise ``SecurityViolationError`` if ``code`` has security violations.

        Returns ``None`` when no violation is found, or immediately when both
        allowlists contain the wildcard. A cached result for the same code
        and allowlists is reused without re-parsing.
        """
        if self._allow_all:
            return

        key = self._to_cache_key(code)
        known = self._cache.get(key)

        if known is not None:
            # Cache hit: refresh the entry's position, then replay the result.
            self._cache.move_to_end(key)
            if known:
                self._raise_security_error(known)
            return

        tree = ast.parse(code)

        visitor = SecurityValidator(self._security_config)
        visitor.visit(tree)
        found = visitor.violations

        self._set_in_cache(key, found)

        if found:
            self._raise_security_error(found)

    def _raise_security_error(self, violations: CachedViolations) -> None:
        """Raise ``SecurityViolationError`` listing one violation per line."""
        description = "\n".join(violations)
        raise SecurityViolationError(
            message="Security violations detected", description=description
        )

    def _to_cache_key(self, code: str) -> CacheKey:
        """Build the cache key from the code's SHA-256 digest and the allowlists."""
        digest = hashlib.sha256(code.encode()).hexdigest()
        return (digest, self._allowlists)

    def _set_in_cache(self, cache_key: CacheKey, violations: CachedViolations) -> None:
        """Store a copy of ``violations`` under ``cache_key``, evicting if full."""
        cache = self._cache

        if len(cache) >= MAX_VALIDATION_CACHE_SIZE:
            # Drop the entry at the front of the ordering, i.e. the one that
            # was least recently stored or hit.
            cache.popitem(last=False)

        cache[cache_key] = list(violations)
        cache.move_to_end(cache_key)
