"""Hook that scans a node workspace for newly created files."""

import hashlib
import logging
import mimetypes
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

from entity.configs import Node
from entity.messages import MessageBlockType
from utils.attachments import AttachmentRecord, AttachmentStore
from utils.human_prompt import PromptChannel


@dataclass
class WorkspaceArtifact:
    """Represents a file artifact detected by the workspace hook."""

    node_id: str
    attachment_id: str
    file_name: str
    relative_path: str
    absolute_path: str
    mime_type: Optional[str]
    size: Optional[int]
    sha256: Optional[str]
    data_uri: Optional[str]
    created_at: float
    change_type: str
    extra: Dict[str, object]


@dataclass
class _FileSignature:
    sha256: str
    size: int


@dataclass
class _TrackedEntry:
    sha256: str
    attachment_id: str
    absolute_path: str
    mime_type: Optional[str]
    size: Optional[int]
    data_uri: Optional[str]


class WorkspaceArtifactHook:
    """Detects workspace file changes for selected node types."""

    def __init__(
        self,
        *,
        attachment_store: AttachmentStore,
        emit_callback: Callable[[List[WorkspaceArtifact]], None],
        node_types: Optional[Sequence[str]] = None,
        exclude_dirs: Optional[Sequence[str]] = None,
        max_files_scanned: int = 500,
        max_bytes_scanned: int = 500 * 1024 * 1024,
        prompt_channel: Optional[PromptChannel] = None,
    ) -> None:
        self.attachment_store = attachment_store
        self.emit_callback = emit_callback
        self.node_types: Set[str] = set(node_types or {"python", "agent"})
        self.exclude_dirs = set(exclude_dirs or {"attachments", "__pycache__"})
        self.max_files_scanned = max_files_scanned
        self.max_bytes_scanned = max_bytes_scanned
        self.logger = logging.getLogger(__name__)
        self._snapshots: Dict[str, Dict[str, _FileSignature]] = {}
        self._last_emitted: Dict[str, _TrackedEntry] = {}
        self.prompt_channel = prompt_channel

    def can_handle(self, node: Node) -> bool:
        return node.node_type in self.node_types

    def get_prompt_channel(self) -> Optional[PromptChannel]:
        return self.prompt_channel

    def before_node(self, node: Node, workspace: Path) -> None:
        if not self.can_handle(node):
            return
        snapshot, _ = self._snapshot(workspace)
        self._snapshots[node.id] = snapshot

    def after_node(
        self,
        node: Node,
        workspace: Path,
        *,
        success: bool,
    ) -> None:
        if not success or not self.can_handle(node):
            self._snapshots.pop(node.id, None)
            return

        before = self._snapshots.pop(node.id, {})
        after, truncated = self._snapshot(workspace)
        if not after and not self._last_emitted:
            return

        changed_paths = [
            Path(path_str)
            for path_str, signature in after.items()
            if path_str not in before or before[path_str].sha256 != signature.sha256
        ]

        artifacts: List[WorkspaceArtifact] = []
        for relative_path in changed_paths:
            signature = after[str(relative_path)]
            full_path = workspace / relative_path
            if not full_path.exists() or not full_path.is_file():
                continue
            try:
                tracked = self._last_emitted.get(str(relative_path))
                change_type = "created" if tracked is None else "updated"
                record = self._register_artifact(
                    full_path,
                    relative_path,
                    node,
                    attachment_id=tracked.attachment_id if tracked else None,
                )
            except Exception as exc:
                self.logger.warning(
                    "Failed to register artifact %s for node %s: %s",
                    relative_path,
                    node.id,
                    exc,
                )
                continue
            artifacts.append(
                self._to_artifact(
                    record,
                    node,
                    relative_path,
                    full_path,
                    change_type=change_type,
                )
            )
            self._last_emitted[str(relative_path)] = _TrackedEntry(
                sha256=signature.sha256,
                attachment_id=record.ref.attachment_id or "",
                absolute_path=str(full_path),
                mime_type=record.ref.mime_type,
                size=record.ref.size,
                data_uri=record.ref.data_uri,
            )

        if not truncated:
            deleted_paths = [
                relative_path
                for relative_path in list(self._last_emitted.keys())
                if relative_path not in after
            ]
            for relative_path in deleted_paths:
                tracked = self._last_emitted.pop(relative_path, None)
                if not tracked:
                    continue
                artifacts.append(
                    WorkspaceArtifact(
                        node_id=node.id,
                        attachment_id=tracked.attachment_id,
                        file_name=Path(relative_path).name,
                        relative_path=relative_path,
                        absolute_path=tracked.absolute_path,
                        mime_type=tracked.mime_type,
                        size=tracked.size,
                        sha256=tracked.sha256,
                        data_uri=tracked.data_uri,
                        created_at=time.time(),
                        change_type="deleted",
                        extra={
                            "hook": "workspace_scan",
                            "relative_path": relative_path,
                        },
                    )
                )

        if artifacts:
            self.emit_callback(artifacts)

    def _snapshot(self, workspace: Path) -> Tuple[Dict[str, _FileSignature], bool]:
        entries: Dict[str, _FileSignature] = {}
        total_bytes = 0
        file_count = 0
        for root, dirs, files in os.walk(workspace):
            rel_root = Path(root).relative_to(workspace)
            dirs[:] = [d for d in dirs if not self._is_excluded(rel_root / d)]
            for filename in files:
                rel_path = rel_root / filename
                if self._is_excluded(rel_path):
                    continue
                full_path = Path(root) / filename
                try:
                    stat = full_path.stat()
                    sha256 = self._hash_file(full_path)
                except OSError:
                    continue
                file_count += 1
                total_bytes += stat.st_size
                entries[str(rel_path)] = _FileSignature(sha256=sha256, size=stat.st_size)
                if file_count >= self.max_files_scanned or total_bytes >= self.max_bytes_scanned:
                    self.logger.warning(
                        "Workspace scan truncated (files=%s total_bytes=%s) for session %s",
                        file_count,
                        total_bytes,
                    )
                    return entries, True
        return entries, False

    def _is_excluded(self, rel_path: Path) -> bool:
        if not rel_path.parts:
            return False
        return rel_path.parts[0] in self.exclude_dirs

    def _register_artifact(
        self,
        full_path: Path,
        relative_path: Path,
        node: Node,
        *,
        attachment_id: Optional[str] = None,
    ) -> AttachmentRecord:
        mime_type = mimetypes.guess_type(relative_path.name)[0] or "application/octet-stream"
        return self.attachment_store.register_file(
            full_path,
            kind=MessageBlockType.from_mime_type(mime_type),
            mime_type=mime_type,
            display_name=full_path.name,
            copy_file=False,
            persist=True,
            deduplicate=False,
            attachment_id=attachment_id,
            extra={
                "node_id": node.id,
                "relative_path": str(relative_path),
                "hook": "workspace_scan",
            },
        )

    def _to_artifact(
        self,
        record: AttachmentRecord,
        node: Node,
        relative_path: Path,
        full_path: Path,
        *,
        change_type: str,
    ) -> WorkspaceArtifact:
        ref = record.ref
        return WorkspaceArtifact(
            node_id=node.id,
            attachment_id=ref.attachment_id or "",
            file_name=ref.name or full_path.name,
            relative_path=str(relative_path),
            absolute_path=str(full_path),
            mime_type=ref.mime_type,
            size=ref.size,
            sha256=ref.sha256,
            data_uri=ref.data_uri,
            created_at=time.time(),
            change_type=change_type,
            extra=dict(record.extra),
        )

    def _hash_file(self, path: Path) -> str:
        hasher = hashlib.sha256()
        with path.open("rb") as handle:
            for chunk in iter(lambda: handle.read(1024 * 1024), b""):
                hasher.update(chunk)
        return hasher.hexdigest()
