"""Agent-specific configuration dataclasses."""

import math
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Iterable, List, Mapping, Sequence

try:  # pragma: no cover - Python < 3.11 lacks BaseExceptionGroup
    from builtins import BaseExceptionGroup as _BASE_EXCEPTION_GROUP_TYPE  # type: ignore[attr-defined]
except ImportError:  # pragma: no cover
    _BASE_EXCEPTION_GROUP_TYPE = None  # type: ignore[assignment]

from entity.enums import AgentInputMode
from schema_registry import iter_model_provider_schemas
from utils.strs import titleize

from entity.configs.base import (
    BaseConfig,
    ConfigError,
    ConfigFieldSpec,
    EnumOption,
    optional_bool,
    optional_dict,
    optional_str,
    require_mapping,
    require_str,
    extend_path,
)
from .memory import MemoryAttachmentConfig
from .skills import AgentSkillsConfig
from .thinking import ThinkingConfig
from entity.configs.node.tooling import ToolingConfig


# HTTP status codes AgentRetryConfig retries when ``retry_on_status_codes`` is
# not configured (AgentRetryConfig copies this list per instance).
DEFAULT_RETRYABLE_STATUS_CODES = [408, 409, 425, 429, 500, 502, 503, 504]
# Exception class names retried by default. AgentRetryConfig lower-cases them
# and matches them against every class in the raised exception's MRO (bare or
# module-qualified name), so a base-class name also matches all subclasses.
# NOTE(review): the exception-type check runs before the status-code check, so
# in SDKs whose client-error (4xx) classes derive from "APIError" those errors
# would be retried regardless of DEFAULT_RETRYABLE_STATUS_CODES — confirm this
# is intended.
DEFAULT_RETRYABLE_EXCEPTION_TYPES = [
    "RateLimitError",
    "APITimeoutError",
    "APIError",
    "APIConnectionError",
    "ServiceUnavailableError",
    "TimeoutError",
    "InternalServerError",
    "RemoteProtocolError",
    "TransportError",
    "ConnectError",
    "ConnectTimeout",
    "ReadError",
    "ReadTimeout",
]
# Lower-case substrings: an exception whose lower-cased str() contains any of
# these is retried by default.
DEFAULT_RETRYABLE_MESSAGE_SUBSTRINGS = [
    "rate limit",
    "temporarily unavailable",
    "timeout",
    "server disconnected",
    "connection reset",
]


def _coerce_float(value: Any, *, field_path: str, minimum: float = 0.0) -> float:
    if isinstance(value, (int, float)):
        coerced = float(value)
    else:
        raise ConfigError("expected number", field_path)
    if coerced < minimum:
        raise ConfigError(f"value must be >= {minimum}", field_path)
    return coerced


def _coerce_positive_int(value: Any, *, field_path: str, minimum: int = 1) -> int:
    if isinstance(value, bool):
        raise ConfigError("expected integer", field_path)
    if isinstance(value, int):
        coerced = value
    else:
        raise ConfigError("expected integer", field_path)
    if coerced < minimum:
        raise ConfigError(f"value must be >= {minimum}", field_path)
    return coerced


def _coerce_str_list(value: Any, *, field_path: str) -> List[str]:
    if value is None:
        return []
    if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
        raise ConfigError("expected list of strings", field_path)
    result: List[str] = []
    for idx, item in enumerate(value):
        if not isinstance(item, str):
            raise ConfigError("expected list of strings", f"{field_path}[{idx}]")
        result.append(item.strip())
    return result


def _coerce_int_list(value: Any, *, field_path: str) -> List[int]:
    if value is None:
        return []
    if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
        raise ConfigError("expected list of integers", field_path)
    ints: List[int] = []
    for idx, item in enumerate(value):
        if isinstance(item, bool) or not isinstance(item, int):
            raise ConfigError("expected list of integers", f"{field_path}[{idx}]")
        ints.append(item)
    return ints


@dataclass
class AgentRetryConfig(BaseConfig):
    """Retry policy applied to model-provider calls.

    ``should_retry`` inspects the raised exception together with its
    ``__cause__``/``__context__`` links and exception-group members. Precedence:
    a match in ``non_retry_exception_types`` anywhere in that chain vetoes the
    retry; otherwise a match on ``retry_on_exception_types``, then
    ``retry_on_status_codes``, then ``retry_on_error_substrings`` allows it.

    Exception-type names are compared against every class in the exception's
    MRO (bare or module-qualified, lower-cased), so a base-class name also
    matches all of its subclasses — and that check runs before the status-code
    check. ``from_dict`` stores type names and substrings lower-cased.
    """

    enabled: bool = True
    max_attempts: int = 5  # total attempts: the initial call plus retries
    min_wait_seconds: float = 1.0
    max_wait_seconds: float = 6.0
    retry_on_status_codes: List[int] = field(default_factory=lambda: list(DEFAULT_RETRYABLE_STATUS_CODES))
    retry_on_exception_types: List[str] = field(default_factory=lambda: [name.lower() for name in DEFAULT_RETRYABLE_EXCEPTION_TYPES])
    non_retry_exception_types: List[str] = field(default_factory=list)
    retry_on_error_substrings: List[str] = field(default_factory=lambda: list(DEFAULT_RETRYABLE_MESSAGE_SUBSTRINGS))

    FIELD_SPECS = {
        "enabled": ConfigFieldSpec(
            name="enabled",
            display_name="Enable Retry",
            type_hint="bool",
            required=False,
            default=True,
            description="Toggle automatic retry for provider calls",
        ),
        "max_attempts": ConfigFieldSpec(
            name="max_attempts",
            display_name="Max Attempts",
            type_hint="int",
            required=False,
            default=5,
            description="Maximum number of total attempts (initial call + retries)",
        ),
        "min_wait_seconds": ConfigFieldSpec(
            name="min_wait_seconds",
            display_name="Min Wait Seconds",
            type_hint="float",
            required=False,
            default=1.0,
            description="Minimum backoff wait before retry",
            advance=True,
        ),
        "max_wait_seconds": ConfigFieldSpec(
            name="max_wait_seconds",
            display_name="Max Wait Seconds",
            type_hint="float",
            required=False,
            default=6.0,
            description="Maximum backoff wait before retry",
            advance=True,
        ),
        "retry_on_status_codes": ConfigFieldSpec(
            name="retry_on_status_codes",
            display_name="Retryable Status Codes",
            type_hint="list[int]",
            required=False,
            description="HTTP status codes that should trigger a retry",
            advance=True,
        ),
        "retry_on_exception_types": ConfigFieldSpec(
            name="retry_on_exception_types",
            display_name="Retryable Exception Types",
            type_hint="list[str]",
            required=False,
            description="Exception class names (case-insensitive) that should trigger retries",
            advance=True,
        ),
        "non_retry_exception_types": ConfigFieldSpec(
            name="non_retry_exception_types",
            display_name="Non-Retryable Exception Types",
            type_hint="list[str]",
            required=False,
            description="Exception class names (case-insensitive) that should never retry",
            advance=True,
        ),
        "retry_on_error_substrings": ConfigFieldSpec(
            name="retry_on_error_substrings",
            display_name="Retryable Message Substrings",
            type_hint="list[str]",
            required=False,
            description="Substring matches within exception messages that enable retry",
            advance=True,
        ),
    }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "AgentRetryConfig":
        """Build a retry policy from a raw config mapping.

        Missing keys, and keys explicitly set to ``None``, fall back to the
        defaults. Exception-type names and message substrings are lower-cased,
        and empty entries are dropped.

        Raises:
            ConfigError: On wrong value types, values below their minimum, or
                ``max_wait_seconds < min_wait_seconds``.
        """
        mapping = require_mapping(data, path)

        def value_or_default(key: str, default: Any) -> Any:
            # An explicit null means "use the default", consistent with how the
            # list-valued options below treat None.
            raw = mapping.get(key)
            return default if raw is None else raw

        def lowered_names(raw: Any, key: str) -> List[str]:
            items = _coerce_str_list(raw, field_path=extend_path(path, key))
            return [item.lower() for item in items if item]

        enabled = optional_bool(mapping, "enabled", path, default=True)
        if enabled is None:
            enabled = True
        max_attempts = _coerce_positive_int(
            value_or_default("max_attempts", 5),
            field_path=extend_path(path, "max_attempts"),
        )
        min_wait = _coerce_float(
            value_or_default("min_wait_seconds", 1.0),
            field_path=extend_path(path, "min_wait_seconds"),
            minimum=0.0,
        )
        max_wait = _coerce_float(
            value_or_default("max_wait_seconds", 6.0),
            field_path=extend_path(path, "max_wait_seconds"),
            minimum=0.0,
        )
        if max_wait < min_wait:
            raise ConfigError("max_wait_seconds must be >= min_wait_seconds", extend_path(path, "max_wait_seconds"))

        status_codes = mapping.get("retry_on_status_codes")
        if status_codes is None:
            retry_status_codes = list(DEFAULT_RETRYABLE_STATUS_CODES)
        else:
            retry_status_codes = _coerce_int_list(status_codes, field_path=extend_path(path, "retry_on_status_codes"))

        retry_types_raw = mapping.get("retry_on_exception_types")
        if retry_types_raw is None:
            retry_types = [name.lower() for name in DEFAULT_RETRYABLE_EXCEPTION_TYPES]
        else:
            retry_types = lowered_names(retry_types_raw, "retry_on_exception_types")

        # None yields [] here: there is no default deny-list.
        non_retry_types = lowered_names(mapping.get("non_retry_exception_types"), "non_retry_exception_types")

        retry_substrings_raw = mapping.get("retry_on_error_substrings")
        if retry_substrings_raw is None:
            retry_substrings = list(DEFAULT_RETRYABLE_MESSAGE_SUBSTRINGS)
        else:
            retry_substrings = lowered_names(retry_substrings_raw, "retry_on_error_substrings")

        return cls(
            enabled=enabled,
            max_attempts=max_attempts,
            min_wait_seconds=min_wait,
            max_wait_seconds=max_wait,
            retry_on_status_codes=retry_status_codes,
            retry_on_exception_types=retry_types,
            non_retry_exception_types=non_retry_types,
            retry_on_error_substrings=retry_substrings,
            path=path,
        )

    @property
    def is_active(self) -> bool:
        """True when retries are enabled and more than one attempt is allowed."""
        return self.enabled and self.max_attempts > 1

    def should_retry(self, exc: BaseException) -> bool:
        """Return True if ``exc`` should trigger another attempt under this policy.

        Every exception reachable from ``exc`` is examined. A non-retryable type
        anywhere in the chain vetoes the retry before any positive rule is
        consulted; positive rules are checked in order: exception type, HTTP
        status code, message substring.
        """
        if not self.is_active:
            return False

        # Materialise the chain once: the veto pass must see every linked
        # exception before any positive rule may return True.
        chain = [
            (
                self._exception_name_set(error),
                self._extract_status_code(error),
                str(error).lower(),
            )
            for error in self._iter_exception_chain(exc)
        ]

        if self.non_retry_exception_types:
            for names, _, _ in chain:
                if any(name in names for name in self.non_retry_exception_types):
                    return False

        if self.retry_on_exception_types:
            for names, _, _ in chain:
                if any(name in names for name in self.retry_on_exception_types):
                    return True

        if self.retry_on_status_codes:
            for _, status_code, _ in chain:
                if status_code is not None and status_code in self.retry_on_status_codes:
                    return True

        if self.retry_on_error_substrings:
            for _, _, message in chain:
                if message and any(substr in message for substr in self.retry_on_error_substrings):
                    return True

        return False

    def _exception_name_set(self, exc: BaseException) -> set[str]:
        """Lower-cased bare and module-qualified names of every class in ``exc``'s MRO."""
        names: set[str] = set()
        for cls in exc.__class__.mro():
            names.add(cls.__name__.lower())
            names.add(f"{cls.__module__}.{cls.__name__}".lower())
        return names

    @staticmethod
    def _as_status_code(value: Any) -> int | None:
        """Return ``value`` if it is a plain int; bool is excluded because it subclasses int."""
        if isinstance(value, int) and not isinstance(value, bool):
            return value
        return None

    def _extract_status_code(self, exc: BaseException) -> int | None:
        """Best-effort lookup of an integer status code on ``exc``, then on ``exc.response``.

        Returns the first plain-int attribute found, or None.
        """
        for attr in ("status_code", "http_status", "status", "statusCode"):
            code = self._as_status_code(getattr(exc, attr, None))
            if code is not None:
                return code
        # Only touch ``response`` when the exception itself has no code.
        response = getattr(exc, "response", None)
        if response is not None:
            for attr in ("status_code", "status", "statusCode"):
                code = self._as_status_code(getattr(response, attr, None))
                if code is not None:
                    return code
        return None

    def _iter_exception_chain(self, exc: BaseException) -> Iterable[BaseException]:
        """Yield ``exc`` and every exception linked to it, each object at most once.

        Follows ``__cause__``, ``__context__`` and, where the running Python has
        ``BaseExceptionGroup``, the group's ``exceptions``. An explicit stack plus
        an id-based ``seen`` set avoids recursion and guards against cycles.
        """
        seen: set[int] = set()
        stack: List[BaseException] = [exc]
        while stack:
            current = stack.pop()
            if id(current) in seen:
                continue
            seen.add(id(current))
            yield current

            linked: List[BaseException] = []
            cause = getattr(current, "__cause__", None)
            context = getattr(current, "__context__", None)
            if isinstance(cause, BaseException):
                linked.append(cause)
            if isinstance(context, BaseException):
                linked.append(context)
            if _BASE_EXCEPTION_GROUP_TYPE is not None and isinstance(current, _BASE_EXCEPTION_GROUP_TYPE):
                for exc_item in getattr(current, "exceptions", None) or ():
                    if isinstance(exc_item, BaseException):
                        linked.append(exc_item)
            stack.extend(linked)


@dataclass
class AgentConfig(BaseConfig):
    """Configuration for a single model-backed agent.

    Holds the provider and model selection, credentials, call parameters, and
    the optional tooling, thinking, memory, skills and retry sub-configurations.
    """

    provider: str
    # Filled from optional_str(), so it may be None when base_url is omitted.
    base_url: str | None
    name: str
    role: str | None = None
    api_key: str | None = None
    params: Dict[str, Any] = field(default_factory=dict)
    retry: AgentRetryConfig | None = None
    input_mode: AgentInputMode = AgentInputMode.MESSAGES
    tooling: List[ToolingConfig] = field(default_factory=list)
    thinking: ThinkingConfig | None = None
    memories: List[MemoryAttachmentConfig] = field(default_factory=list)
    skills: AgentSkillsConfig | None = None

    # Runtime attributes (attached dynamically)
    token_tracker: Any | None = field(default=None, init=False, repr=False)
    node_id: str | None = field(default=None, init=False, repr=False)

    @staticmethod
    def _parse_config_list(mapping: Mapping[str, Any], key: str, child_cls: Any, path: str) -> List[Any]:
        """Parse ``mapping[key]`` as a list of ``child_cls`` configs; absent or None gives []."""
        raw_items = mapping.get(key)
        if raw_items is None:
            return []
        if not isinstance(raw_items, list):
            raise ConfigError(f"{key} must be a list", extend_path(path, key))
        return [
            child_cls.from_dict(item, path=extend_path(path, f"{key}[{idx}]"))
            for idx, item in enumerate(raw_items)
        ]

    @staticmethod
    def _parse_optional_config(mapping: Mapping[str, Any], key: str, child_cls: Any, path: str) -> Any:
        """Parse ``mapping[key]`` with ``child_cls.from_dict``; absent or None gives None."""
        raw = mapping.get(key)
        if raw is None:
            return None
        return child_cls.from_dict(raw, path=extend_path(path, key))

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, path: str) -> "AgentConfig":
        """Build an ``AgentConfig`` from a raw config mapping.

        ``provider`` and a non-empty ``name`` are required. Nested sections
        (``tooling``, ``thinking``, ``memories``, ``retry``, ``skills``) are parsed
        by their own config classes when present and not None. A blank
        ``input_mode`` keeps the default (messages).

        Raises:
            ConfigError: On missing or invalid fields, an unknown ``input_mode``,
                or a ``tooling``/``memories`` value that is not a list.
        """
        mapping = require_mapping(data, path)
        provider = require_str(mapping, "provider", path)
        base_url = optional_str(mapping, "base_url", path)
        name_value = mapping.get("name")
        if not isinstance(name_value, str) or not name_value.strip():
            raise ConfigError("model.name must be a non-empty string", extend_path(path, "name"))
        model_name = name_value.strip()

        role = optional_str(mapping, "role", path)
        api_key = optional_str(mapping, "api_key", path)
        params = optional_dict(mapping, "params", path) or {}

        input_mode = AgentInputMode.MESSAGES
        raw_input_mode = optional_str(mapping, "input_mode", path)
        normalized_mode = raw_input_mode.strip().lower() if raw_input_mode else ""
        # Whitespace-only values are treated like an empty/absent value.
        if normalized_mode:
            try:
                input_mode = AgentInputMode(normalized_mode)
            except ValueError as exc:
                raise ConfigError(
                    "model.input_mode must be 'prompt' or 'messages'",
                    extend_path(path, "input_mode"),
                ) from exc

        # Parse order is kept stable so the first reported error is unchanged.
        tooling_cfg: List[ToolingConfig] = cls._parse_config_list(mapping, "tooling", ToolingConfig, path)
        thinking_cfg = cls._parse_optional_config(mapping, "thinking", ThinkingConfig, path)
        memories_cfg: List[MemoryAttachmentConfig] = cls._parse_config_list(
            mapping, "memories", MemoryAttachmentConfig, path
        )
        retry_cfg = cls._parse_optional_config(mapping, "retry", AgentRetryConfig, path)
        skills_cfg = cls._parse_optional_config(mapping, "skills", AgentSkillsConfig, path)

        return cls(
            provider=provider,
            base_url=base_url,
            name=model_name,
            role=role,
            api_key=api_key,
            params=params,
            tooling=tooling_cfg,
            thinking=thinking_cfg,
            memories=memories_cfg,
            skills=skills_cfg,
            retry=retry_cfg,
            input_mode=input_mode,
            path=path,
        )

    FIELD_SPECS = {
        "name": ConfigFieldSpec(
            name="name",
            display_name="Model Name",
            type_hint="str",
            required=True,
            description="Specific model name e.g. gpt-4o",
        ),
        "role": ConfigFieldSpec(
            name="role",
            display_name="System Prompt",
            type_hint="text",
            required=False,
            description="Model system prompt",
        ),
        "provider": ConfigFieldSpec(
            name="provider",
            display_name="Model Provider",
            type_hint="str",
            required=True,
            description="Name of a registered provider (openai, gemini, etc.) that selects the underlying client adapter.",
            default="openai",
        ),
        "base_url": ConfigFieldSpec(
            name="base_url",
            display_name="Base URL",
            type_hint="str",
            required=False,
            description="Override the provider's default endpoint; leave empty to use the built-in base URL.",
            advance=True,
            default="${BASE_URL}",
        ),
        "api_key": ConfigFieldSpec(
            name="api_key",
            display_name="API Key",
            type_hint="str",
            required=False,
            description="Credential consumed by the provider client; reference an env var such as ${API_KEY} that matches the selected provider.",
            advance=True,
            default="${API_KEY}",
        ),
        "params": ConfigFieldSpec(
            name="params",
            display_name="Call Parameters",
            type_hint="dict[str, Any]",
            required=False,
            default={},
            description="Call parameters (temperature, top_p, etc.)",
            advance=True,
        ),
        # "input_mode": ConfigFieldSpec(  # currently, many features depend on messages mode, so hide this and force messages
        #     name="input_mode",
        #     display_name="Input Mode",
        #     type_hint="enum:AgentInputMode",
        #     required=False,
        #     default=AgentInputMode.MESSAGES.value,
        #     description="Model input mode: messages (default) or prompt",
        #     enum=[item.value for item in AgentInputMode],
        #     advance=True,
        #     enum_options=enum_options_for(AgentInputMode),
        # ),
        "tooling": ConfigFieldSpec(
            name="tooling",
            display_name="Tool Configuration",
            type_hint="list[ToolingConfig]",
            required=False,
            description="Bound tool configuration list",
            child=ToolingConfig,
            advance=True,
        ),
        "thinking": ConfigFieldSpec(
            name="thinking",
            display_name="Thinking Configuration",
            type_hint="ThinkingConfig",
            required=False,
            description="Thinking process configuration",
            child=ThinkingConfig,
            advance=True,
        ),
        "memories": ConfigFieldSpec(
            name="memories",
            display_name="Memory Attachments",
            type_hint="list[MemoryAttachmentConfig]",
            required=False,
            description="Associated memory references",
            child=MemoryAttachmentConfig,
            advance=True,
        ),
        "skills": ConfigFieldSpec(
            name="skills",
            display_name="Agent Skills",
            type_hint="AgentSkillsConfig",
            required=False,
            description="Agent Skills allowlist and built-in skill activation/file-read tools.",
            child=AgentSkillsConfig,
            advance=True,
        ),
        "retry": ConfigFieldSpec(
            name="retry",
            display_name="Retry Policy",
            type_hint="AgentRetryConfig",
            required=False,
            description="Automatic retry policy for this model",
            child=AgentRetryConfig,
            advance=True,
        ),
    }

    @classmethod
    def field_specs(cls) -> Dict[str, ConfigFieldSpec]:
        """Return field specs with ``provider`` rendered as an enum of registered providers."""
        # Work on a copy: BaseConfig.field_specs() is not visible here and may
        # return the class-level FIELD_SPECS dict itself; writing into it would
        # leak a registry snapshot into shared class state.
        specs = dict(super().field_specs())
        provider_spec = specs.get("provider")
        if provider_spec:
            specs["provider"] = cls._apply_provider_enum(provider_spec)
        return specs

    @staticmethod
    def _apply_provider_enum(provider_spec: ConfigFieldSpec) -> ConfigFieldSpec:
        """Return a copy of ``provider_spec`` whose enum lists the registered providers.

        If no providers are registered, the spec is returned unchanged. The
        default is kept when it names a registered provider; otherwise it is
        replaced by ``_preferred_provider_default``.
        """
        provider_names, metadata = AgentConfig._provider_registry_snapshot()
        if not provider_names:
            return provider_spec

        enum_options: List[EnumOption] = []
        for name in provider_names:
            meta = metadata.get(name) or {}
            enum_options.append(
                EnumOption(
                    value=name,
                    label=meta.get("label") or titleize(name),
                    description=meta.get("summary"),
                )
            )

        default_value = provider_spec.default
        if not default_value or default_value not in provider_names:
            default_value = AgentConfig._preferred_provider_default(provider_names)

        return replace(
            provider_spec,
            enum=provider_names,
            enum_options=enum_options,
            default=default_value,
        )

    @staticmethod
    def _preferred_provider_default(provider_names: List[str]) -> str:
        """Prefer "openai" when registered, else the first name (caller ensures non-empty)."""
        if "openai" in provider_names:
            return "openai"
        return provider_names[0]

    @staticmethod
    def _provider_registry_snapshot() -> tuple[List[str], Dict[str, Dict[str, Any]]]:
        """Return registered provider names and per-provider display metadata.

        Each metadata dict holds ``label`` and ``summary`` from the schema, merged
        with the schema's own ``metadata`` (whose keys take precedence).
        """
        specs = iter_model_provider_schemas()
        names = list(specs.keys())
        metadata: Dict[str, Dict[str, Any]] = {}
        for name, spec in specs.items():
            metadata[name] = {
                "label": spec.label,
                "summary": spec.summary,
                **(spec.metadata or {}),
            }
        return names, metadata
