"""Tooling configuration models."""

import hashlib
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import Any, Dict, List, Mapping, Tuple

from entity.configs.base import (
    BaseConfig,
    ConfigError,
    ConfigFieldSpec,
    EnumOption,
    ChildKey,
    ensure_list,
    optional_bool,
    optional_str,
    require_mapping,
    require_str,
    extend_path,
)
from entity.enum_options import enum_options_from_values
from utils.registry import Registry, RegistryError
from utils.function_catalog import FunctionCatalog, get_function_catalog


# Registry of tooling type names; register_tooling_type stores each type's
# config class as the entry target.
tooling_type_registry = Registry("tooling_type")
# Suffix that turns a tool name into a "whole module" selector, i.e.
# "<module>:All" (handled by FunctionToolConfig._extract_module_from_all).
MODULE_ALL_SUFFIX = ":All"


def register_tooling_type(
    name: str,
    *,
    config_cls: type[BaseConfig],
    description: str | None = None,
) -> None:
    """Register ``config_cls`` in ``tooling_type_registry`` under ``name``.

    Args:
        name: Key of the tooling type (e.g. ``"function"``).
        config_cls: Config class stored as the registry target.
        description: Optional summary. When truthy it is stored as the
            ``"summary"`` metadata entry; otherwise no metadata is attached.
    """
    metadata = None
    if description:
        metadata = {"summary": description}
    tooling_type_registry.register(name, target=config_cls, metadata=metadata)


def get_tooling_type_config(name: str) -> type[BaseConfig]:
    """Return the config class registered for tooling type ``name``.

    Errors raised by the registry lookup or by loading the entry propagate
    unchanged.

    Raises:
        RegistryError: If the loaded entry is not a ``BaseConfig`` subclass.
    """
    loaded = tooling_type_registry.get(name).load()
    if isinstance(loaded, type) and issubclass(loaded, BaseConfig):
        return loaded
    raise RegistryError(f"Entry '{name}' is not a BaseConfig subclass")


def iter_tooling_type_registrations() -> Dict[str, type[BaseConfig]]:
    """Return a dict mapping each registered tooling type name to its loaded target."""
    registrations = {}
    for type_name, registry_entry in tooling_type_registry.items():
        registrations[type_name] = registry_entry.load()
    return registrations


def iter_tooling_type_metadata() -> Dict[str, Dict[str, Any]]:
    """Return a dict mapping each tooling type name to a copy of its metadata.

    Entries without metadata map to an empty dict.
    """
    collected = {}
    for type_name, registry_entry in tooling_type_registry.items():
        collected[type_name] = dict(registry_entry.metadata or {})
    return collected


@dataclass
class FunctionToolEntryConfig(BaseConfig):
    """Schema helper describing the options of a single function tool entry.

    Only ``name`` currently has an active field spec; ``field_specs`` enriches
    it with enum choices taken from the function catalog.
    """

    name: str | None = None
    description: str | None = None
    parameters: Dict[str, Any] | None = None
    auto_fill: bool = True

    FIELD_SPECS = {
        "name": ConfigFieldSpec(
            name="name",
            display_name="Function Name",
            type_hint="str",
            required=True,
            description="Function name from functions/function_calling directory",
        ),
        # The specs below are disabled (kept for reference).
        # "description": ConfigFieldSpec(
        #     name="description",
        #     display_name="Description",
        #     type_hint="str",
        #     required=False,
        #     description="Override auto-parsed function description, optional",
        #     advance=True,
        # ),
        # "parameters": ConfigFieldSpec(
        #     name="parameters",
        #     display_name="Parameter Schema",
        #     type_hint="object",
        #     required=False,
        #     description="Override JSON Schema generated from function signature, optional",
        #     advance=True,
        # ),
        # "auto_fill": ConfigFieldSpec(
        #     name="auto_fill",
        #     display_name="Auto Fill Description",
        #     type_hint="bool",
        #     required=False,
        #     default=True,
        #     description="Whether to auto-fill description/parameters based on Python function signature",
        #     advance=True,
        # ),
    }

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the field specs with ``name`` enriched from the function catalog.

        When the catalog loaded successfully and contains modules, the ``name``
        spec receives enum values/options: one ``"<module>:All"`` choice per
        module first, followed by one choice per function. Otherwise the enum
        is left as ``None`` and the reason is appended to the description.
        """
        specs = super().field_specs()
        catalog = get_function_catalog()
        modules = catalog.iter_modules()
        base_spec = specs.get("name")
        if base_spec is None:
            return specs

        summary = base_spec.description or "Function name"
        choices: List[EnumOption] | None = None
        choice_values: List[str] | None = None

        if catalog.load_error:
            summary = f"{summary} (loading failed: {catalog.load_error})"
        elif not modules:
            summary = f"{summary} (no functions found in directory)"
        else:
            choices, choice_values = [], []
            # First pass: the module-wide "<module>:All" shortcuts.
            for module_name, metas in modules:
                whole_module = f"{module_name}{MODULE_ALL_SUFFIX}"
                # Preview at most three function names in the option text.
                sample = ", ".join(meta.name for meta in metas[:3])
                hint = f"Load all {len(metas)} functions from {module_name}.py"
                if sample:
                    tail = "..." if len(metas) > 3 else ""
                    hint += f" ({sample}{tail})"
                choice_values.append(whole_module)
                choices.append(
                    EnumOption(value=whole_module, label=whole_module, description=hint)
                )
            # Second pass: one choice per individual function.
            for module_name, metas in modules:
                for meta in metas:
                    choice_values.append(meta.name)
                    choices.append(
                        EnumOption(
                            value=meta.name,
                            label=f"{module_name}:{meta.name}",
                            description=meta.description
                            or "This function does not provide a docstring",
                        )
                    )

        specs["name"] = replace(
            base_spec,
            enum=choice_values,
            enum_options=choices,
            description=summary,
        )
        return specs


@dataclass
class FunctionToolConfig(BaseConfig):
    """Configuration for tooling backed by local Python functions.

    ``tools`` holds the resolved per-function specs produced by ``from_dict``:
    ``"<module>:All"`` entries are expanded into one entry per function and,
    unless ``auto_fill`` is disabled for an entry, missing description and
    parameters are filled in from the function catalog.
    """

    tools: List[Dict[str, Any]]
    auto_load: bool = True
    timeout: float | None = None
    # schema_version: str | None = None

    FIELD_SPECS = {
        "tools": ConfigFieldSpec(
            name="tools",
            display_name="Function Tool List",
            type_hint="list[FunctionToolEntryConfig]",
            required=True,
            description="Function tool list, at least one item",
            child=FunctionToolEntryConfig,
        ),
        # "auto_load": ConfigFieldSpec(
        #     name="auto_load",
        #     display_name="Auto Load Directory",
        #     type_hint="bool",
        #     required=False,
        #     default=True,
        #     description="Auto-load functions directory on startup",
        #     advance=True
        # ),
        "timeout": ConfigFieldSpec(
            name="timeout",
            display_name="Execution Timeout",
            type_hint="float",
            required=False,
            description="Tool execution timeout (seconds)",
            advance=True
        ),
        # "schema_version": ConfigFieldSpec(
        #     name="schema_version",
        #     display_name="Schema Version",
        #     type_hint="str",
        #     required=False,
        #     description="Tool schema version",
        # ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "FunctionToolConfig":
        """Build a validated ``FunctionToolConfig`` from a raw mapping.

        Args:
            data: Raw config block; must contain a non-empty ``tools`` list.
            path: Config path of ``data``, used to locate errors.

        Returns:
            The config with every tool entry resolved against the function
            catalog.

        Raises:
            ConfigError: If ``tools`` is missing/empty, an entry is not a
                mapping or lacks a name, a function is unknown or declared
                more than once, ``auto_fill`` is not a bool, or ``timeout`` is
                not a number (booleans are rejected).
        """
        mapping = require_mapping(data, path)
        tools = ensure_list(mapping.get("tools"))
        if not tools:
            raise ConfigError("tools must be provided for function tooling", extend_path(path, "tools"))

        catalog = get_function_catalog()

        # Pass 1: normalize names and expand "<module>:All" entries. Each item
        # keeps the path of the entry it came from for error reporting.
        expanded_tools: List[Tuple[Dict[str, Any], str]] = []
        for idx, tool in enumerate(tools):
            tool_path = extend_path(path, f"tools[{idx}]")
            if not isinstance(tool, Mapping):
                raise ConfigError("tool entry must be a mapping", tool_path)
            normalized = dict(tool)
            raw_name = normalized.get("name")
            if not isinstance(raw_name, str) or not raw_name.strip():
                raise ConfigError("tool name is required", extend_path(tool_path, "name"))
            name = raw_name.strip()
            normalized["name"] = name
            module_name = cls._extract_module_from_all(name)
            if module_name:
                expanded_tools.extend(
                    cls._expand_module_all_entry(
                        module_name=module_name,
                        catalog=catalog,
                        path=tool_path,
                        original=normalized,
                    )
                )
                continue
            expanded_tools.append((normalized, tool_path))

        # Pass 2: resolve each function against the catalog, reject duplicates
        # and apply auto-fill.
        tool_specs: List[Dict[str, Any]] = []
        seen_functions: Dict[str, str] = {}
        for entry, entry_path in expanded_tools:
            normalized = dict(entry)
            name = normalized.get("name")
            # Defensive re-check: expanded names come from the catalog.
            if not isinstance(name, str) or not name.strip():
                raise ConfigError("tool name is required", extend_path(entry_path, "name"))
            metadata = catalog.get(name)
            if metadata is None:
                raise ConfigError(
                    f"function '{name}' not found under function directory",
                    extend_path(entry_path, "name"),
                )
            previous = seen_functions.get(name)
            if previous is not None:
                raise ConfigError(
                    f"function '{name}' is declared multiple times (also in {previous})",
                    extend_path(entry_path, "name"),
                )
            seen_functions[name] = entry_path

            auto_fill = normalized.get("auto_fill", True)
            if not isinstance(auto_fill, bool):
                raise ConfigError("auto_fill must be boolean", extend_path(entry_path, "auto_fill"))
            merged = dict(normalized)
            if auto_fill:
                if not merged.get("description") and metadata.description:
                    merged["description"] = metadata.description
                if not merged.get("parameters"):
                    # Copy so later edits to the spec cannot alter the catalog.
                    merged["parameters"] = deepcopy(metadata.parameters_schema)
            # auto_fill only steers parsing; it is not part of the tool spec.
            merged.pop("auto_fill", None)
            tool_specs.append(merged)

        auto_load = optional_bool(mapping, "auto_load", path, default=True)

        timeout_value = mapping.get("timeout")
        timeout: float | None = None
        if timeout_value is not None:
            # bool is a subclass of int; reject it explicitly so that
            # `timeout: true` is not silently read as 1.0 seconds.
            if isinstance(timeout_value, bool) or not isinstance(timeout_value, (int, float)):
                raise ConfigError("timeout must be numeric", extend_path(path, "timeout"))
            timeout = float(timeout_value)

        # schema_version = optional_str(mapping, "schema_version", path)

        return cls(
            tools=tool_specs,
            auto_load=bool(auto_load) if auto_load is not None else True,
            timeout=timeout,
            # schema_version=schema_version,
            path=path,
        )

    @staticmethod
    def _extract_module_from_all(value: str) -> str | None:
        """Return the module part of ``"<module>:All"``, or ``None``.

        ``None`` is returned when ``value`` lacks the suffix or when nothing
        but whitespace precedes it.
        """
        if not value.endswith(MODULE_ALL_SUFFIX):
            return None
        module = value[: -len(MODULE_ALL_SUFFIX)].strip()
        return module or None

    @staticmethod
    def _expand_module_all_entry(
        *,
        module_name: str,
        catalog: FunctionCatalog,
        path: str,
        original: Mapping[str, Any],
    ) -> List[Tuple[Dict[str, Any], str]]:
        """Expand a ``"<module>:All"`` entry into one entry per function.

        Args:
            module_name: Module whose functions should all be loaded.
            catalog: Function catalog used to list the module's functions.
            path: Config path of the originating tool entry.
            original: The originating entry; it must not carry per-function
                overrides, since they cannot apply to a whole module.

        Returns:
            ``({"name": fn_name}, path)`` pairs, one per function.

        Raises:
            ConfigError: If ``original`` overrides description, parameters or
                auto_fill, or if the module has no functions.
        """
        disallowed = [key for key in ("description", "parameters", "auto_fill") if key in original]
        if disallowed:
            fields = ", ".join(disallowed)
            raise ConfigError(
                f"{module_name}{MODULE_ALL_SUFFIX} does not support overriding {fields}",
                extend_path(path, "name"),
            )
        functions = catalog.functions_for_module(module_name)
        if not functions:
            raise ConfigError(
                f"module '{module_name}' has no functions under function directory",
                extend_path(path, "name"),
            )
        return [({"name": fn_name}, path) for fn_name in functions]


@dataclass
class McpRemoteConfig(BaseConfig):
    """Configuration for connecting to an MCP server over HTTP(S)."""

    server: str
    headers: Dict[str, str] = field(default_factory=dict)
    timeout: float | None = None
    cache_ttl: float = 0.0
    tool_sources: List[str] | None = None

    FIELD_SPECS = {
        "server": ConfigFieldSpec(
            name="server",
            display_name="MCP Server URL",
            type_hint="str",
            required=True,
            description="HTTP(S) endpoint of the MCP server",
        ),
        "headers": ConfigFieldSpec(
            name="headers",
            display_name="Custom Headers",
            type_hint="dict[str, str]",
            required=False,
            description="Additional request headers (e.g. Authorization)",
            advance=True,
        ),
        "timeout": ConfigFieldSpec(
            name="timeout",
            display_name="Client Timeout",
            type_hint="float",
            required=False,
            description="Per-request timeout in seconds",
            advance=True,
        ),
        "cache_ttl": ConfigFieldSpec(
            name="cache_ttl",
            display_name="Tool Cache TTL",
            type_hint="float",
            required=False,
            description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
            advance=True,
        ),
        "tool_sources": ConfigFieldSpec(
            name="tool_sources",
            display_name="Tool Sources Filter",
            type_hint="list[str]",
            required=False,
            description="Only include MCP tools whose meta.source is in this list; omit to default to ['mcp_tools'].",
            advance=True,
        ),
    }

    @staticmethod
    def _optional_number(
        mapping: Mapping[str, Any],
        key: str,
        path: str,
        default: float | None,
    ) -> float | None:
        """Read an optional numeric field as ``float``.

        Returns ``default`` when the key is absent or ``None``. Booleans are
        rejected even though ``bool`` subclasses ``int``, so ``true``/``false``
        are never silently read as ``1.0``/``0.0``.

        Raises:
            ConfigError: If the value is present but not an int or float.
        """
        value = mapping.get(key)
        if value is None:
            return default
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ConfigError(f"{key} must be numeric", extend_path(path, key))
        return float(value)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "McpRemoteConfig":
        """Build a validated ``McpRemoteConfig`` from a raw mapping.

        Args:
            data: Raw config block; ``server`` is required.
            path: Config path of ``data``, used to locate errors.

        Raises:
            ConfigError: If ``headers`` is not a mapping, ``timeout`` or
                ``cache_ttl`` is not numeric, or ``tool_sources`` contains a
                non-string entry. Errors from the ``server`` lookup propagate.
        """
        mapping = require_mapping(data, path)
        server = require_str(mapping, "server", path)

        headers_raw = mapping.get("headers")
        headers: Dict[str, str] = {}
        if headers_raw is not None:
            if not isinstance(headers_raw, Mapping):
                raise ConfigError("headers must be a mapping", extend_path(path, "headers"))
            # Keys and values are coerced to str.
            headers = {str(k): str(v) for k, v in headers_raw.items()}

        timeout = cls._optional_number(mapping, "timeout", path, None)
        cache_ttl = cls._optional_number(mapping, "cache_ttl", path, 0.0)

        tool_sources_raw = mapping.get("tool_sources")
        tool_sources: List[str]
        if tool_sources_raw is None:
            # Omitted filter falls back to the default source.
            tool_sources = ["mcp_tools"]
        else:
            tool_sources = []
            for idx, entry in enumerate(ensure_list(tool_sources_raw)):
                if not isinstance(entry, str):
                    raise ConfigError(
                        "tool_sources must be a list of strings",
                        extend_path(path, f"tool_sources[{idx}]"),
                    )
                value = entry.strip()
                # Blank entries are dropped rather than treated as errors.
                if value:
                    tool_sources.append(value)

        return cls(
            server=server,
            headers=headers,
            timeout=timeout,
            cache_ttl=cache_ttl,
            tool_sources=tool_sources,
            path=path,
        )

    def cache_key(self) -> str:
        """Return a SHA-1 hex digest identifying this connection setup.

        The key covers ``server``, the sorted ``headers`` and ``timeout``.
        ``cache_ttl`` and ``tool_sources`` are not part of it.
        NOTE(review): confirm that consumers apply the ``tool_sources`` filter
        after the cache lookup, otherwise configs differing only in the filter
        would share cached results.
        """
        payload = (
            self.server,
            tuple(sorted(self.headers.items())),
            self.timeout,
        )
        return hashlib.sha1(repr(payload).encode("utf-8")).hexdigest()


@dataclass
class McpLocalConfig(BaseConfig):
    """Configuration for launching a local stdio MCP server process."""

    command: str
    args: List[str] = field(default_factory=list)
    cwd: str | None = None
    env: Dict[str, str] = field(default_factory=dict)
    inherit_env: bool = True
    startup_timeout: float = 10.0
    wait_for_log: str | None = None
    cache_ttl: float = 0.0

    FIELD_SPECS = {
        "command": ConfigFieldSpec(
            name="command",
            display_name="Launch Command",
            type_hint="str",
            required=True,
            description="Executable used to start the MCP stdio server (e.g. uvx)",
        ),
        "args": ConfigFieldSpec(
            name="args",
            display_name="Arguments",
            type_hint="list[str]",
            required=False,
            description="Command arguments, defaults to empty list",
        ),
        "cwd": ConfigFieldSpec(
            name="cwd",
            display_name="Working Directory",
            type_hint="str",
            required=False,
            description="Optional working directory for the launch command",
            advance=True,
        ),
        "env": ConfigFieldSpec(
            name="env",
            display_name="Environment Variables",
            type_hint="dict[str, str]",
            required=False,
            description="Additional environment variables for the process",
            advance=True,
        ),
        "inherit_env": ConfigFieldSpec(
            name="inherit_env",
            display_name="Inherit Parent Env",
            type_hint="bool",
            required=False,
            default=True,
            description="Whether to start from parent env before applying overrides",
            advance=True,
        ),
        "startup_timeout": ConfigFieldSpec(
            name="startup_timeout",
            display_name="Startup Timeout",
            type_hint="float",
            required=False,
            default=10.0,
            description="Seconds to wait for readiness logs",
            advance=True,
        ),
        "wait_for_log": ConfigFieldSpec(
            name="wait_for_log",
            display_name="Ready Log Pattern",
            type_hint="str",
            required=False,
            description="Regex that marks readiness when matched against stdout",
            advance=True,
        ),
        "cache_ttl": ConfigFieldSpec(
            name="cache_ttl",
            display_name="Tool Cache TTL",
            type_hint="float",
            required=False,
            description="Seconds to cache MCP tool list; 0 disables cache for hot updates",
            advance=True,
        ),
    }

    @staticmethod
    def _optional_number(
        mapping: Mapping[str, Any],
        key: str,
        path: str,
        default: float,
    ) -> float:
        """Read an optional numeric field as ``float``.

        Returns ``default`` when the key is absent or ``None``. Booleans are
        rejected even though ``bool`` subclasses ``int``, so ``true``/``false``
        are never silently read as ``1.0``/``0.0``.

        Raises:
            ConfigError: If the value is present but not an int or float.
        """
        value = mapping.get(key)
        if value is None:
            return default
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ConfigError(f"{key} must be numeric", extend_path(path, key))
        return float(value)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "McpLocalConfig":
        """Build a validated ``McpLocalConfig`` from a raw mapping.

        Args:
            data: Raw config block; ``command`` is required.
            path: Config path of ``data``, used to locate errors.

        Raises:
            ConfigError: If an ``args`` entry is not a string, ``env`` is not
                a mapping, or ``startup_timeout``/``cache_ttl`` is not
                numeric. Errors from the string/bool helpers propagate.
        """
        mapping = require_mapping(data, path)
        command = require_str(mapping, "command", path)

        normalized_args: List[str] = []
        for idx, arg in enumerate(ensure_list(mapping.get("args"))):
            if not isinstance(arg, str):
                raise ConfigError("args entries must be strings", extend_path(path, f"args[{idx}]"))
            normalized_args.append(arg)

        cwd = optional_str(mapping, "cwd", path)
        inherit_env = optional_bool(mapping, "inherit_env", path, default=True)
        if inherit_env is None:
            inherit_env = True

        env_mapping = mapping.get("env")
        env: Dict[str, str] = {}
        if env_mapping is not None:
            if not isinstance(env_mapping, Mapping):
                raise ConfigError("env must be a mapping", extend_path(path, "env"))
            # Keys and values are coerced to str.
            env = {str(k): str(v) for k, v in env_mapping.items()}

        startup_timeout = cls._optional_number(mapping, "startup_timeout", path, 10.0)
        wait_for_log = optional_str(mapping, "wait_for_log", path)
        cache_ttl = cls._optional_number(mapping, "cache_ttl", path, 0.0)

        return cls(
            command=command,
            args=normalized_args,
            cwd=cwd,
            env=env,
            inherit_env=bool(inherit_env),
            startup_timeout=startup_timeout,
            wait_for_log=wait_for_log,
            cache_ttl=cache_ttl,
            path=path,
        )

    def cache_key(self) -> str:
        """Return a SHA-1 hex digest identifying this launch setup.

        The key covers command, args, cwd, sorted env, inherit_env,
        startup_timeout and wait_for_log; ``cache_ttl`` is not part of it.
        A missing ``cwd``/``wait_for_log`` hashes the same as an empty string.
        """
        payload = (
            self.command,
            tuple(self.args),
            self.cwd or "",
            tuple(sorted(self.env.items())),
            self.inherit_env,
            self.startup_timeout,
            self.wait_for_log or "",
        )
        return hashlib.sha1(repr(payload).encode("utf-8")).hexdigest()

# Register the built-in tooling types at import time.
# NOTE(review): ToolingConfig.field_specs uses the first registered name as
# the default type; presumably the registry preserves insertion order, which
# would make "function" the default — confirm before reordering these calls.
register_tooling_type(
    "function",
    config_cls=FunctionToolConfig,
    description="Use local Python functions",
)
register_tooling_type(
    "mcp_remote",
    config_cls=McpRemoteConfig,
    description="Connect to an HTTP-based MCP server",
)
register_tooling_type(
    "mcp_local",
    config_cls=McpLocalConfig,
    description="Launch and connect to a local stdio MCP server",
)


@dataclass
class ToolingConfig(BaseConfig):
    """Top-level tooling entry: a registered ``type`` plus its ``config`` block.

    The ``config`` payload is parsed by the config class registered for
    ``type`` in ``tooling_type_registry``.
    """

    type: str
    config: BaseConfig | None = None
    prefix: str | None = None

    FIELD_SPECS = {
        "type": ConfigFieldSpec(
            name="type",
            display_name="Tool Type",
            type_hint="str",
            required=True,
            description="Select a tooling adapter registered via tooling_type_registry (function, mcp_remote, mcp_local, etc.).",
        ),
        "prefix": ConfigFieldSpec(
            name="prefix",
            display_name="Tool Prefix",
            type_hint="str",
            required=False,
            description="Optional prefix for all tools from this source to prevent name collisions (e.g. 'mcp1').",
            advance=True,
        ),
        "config": ConfigFieldSpec(
            name="config",
            display_name="Tool Configuration",
            type_hint="object",
            required=True,
            description="Configuration block validated by the chosen tool type (Python function list, MCP server settings, local command MCP launch, etc.).",
        ),
    }

    @classmethod
    def child_routes(cls) -> Dict[ChildKey, type[BaseConfig]]:
        """Map each registered tooling type to the config class of its ``config`` field."""
        routes = {}
        for type_name, registered_cls in iter_tooling_type_registrations().items():
            routes[ChildKey(field="config", value=type_name)] = registered_cls
        return routes

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "ToolingConfig":
        """Build a ``ToolingConfig`` and its type-specific child config.

        Args:
            data: Raw tooling block with ``type``, ``config`` and optional
                ``prefix``.
            path: Config path of ``data``, used to locate errors.

        Raises:
            ConfigError: If ``type`` is not a registered tooling type or the
                ``config`` block is missing. Errors raised while parsing the
                child config propagate.
        """
        mapping = require_mapping(data, path)
        tooling_type = require_str(mapping, "type", path)

        try:
            config_cls = get_tooling_type_config(tooling_type)
        except RegistryError as exc:
            known_types = list(iter_tooling_type_registrations().keys())
            raise ConfigError(
                f"tooling.type must be one of {known_types}",
                extend_path(path, "type"),
            ) from exc

        payload = mapping.get("config")
        if payload is None:
            raise ConfigError("tooling requires config block", extend_path(path, "config"))

        # Delegate validation of the block to the type-specific config class.
        child_path = extend_path(path, "config")
        config_obj = config_cls.from_dict(payload, path=child_path)

        return cls(
            type=tooling_type,
            config=config_obj,
            prefix=optional_str(mapping, "prefix", path),
            path=path,
        )

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return the field specs with ``type`` enriched from the registry.

        The ``type`` spec receives the registered type names as enum values,
        the first registered name (if any) as default, and enum options built
        from each type's ``"summary"`` metadata.
        """
        specs = super().field_specs()
        type_spec = specs.get("type")
        if not type_spec:
            return specs

        type_names = list(iter_tooling_type_registrations().keys())
        metadata = iter_tooling_type_metadata()

        summaries = {}
        for type_name in type_names:
            summaries[type_name] = (metadata.get(type_name) or {}).get("summary")

        first_type = type_names[0] if type_names else None
        specs["type"] = replace(
            type_spec,
            enum=type_names,
            default=first_type,
            enum_options=enum_options_from_values(type_names, summaries),
        )
        return specs
