import hashlib
import json
import logging
import os
import re
import time
from typing import List

from entity.configs import MemoryStoreConfig
from entity.configs.node.memory import SimpleMemoryConfig
from runtime.node.agent.memory.memory_base import (
    MemoryBase,
    MemoryContentSnapshot,
    MemoryItem,
    MemoryWritePayload,
)
import faiss
import numpy as np

logger = logging.getLogger(__name__)

class SimpleMemory(MemoryBase):
    def __init__(self, store: MemoryStoreConfig):
        """Initialise the memory from ``store``.

        Args:
            store: Store configuration; it must be convertible to a
                ``SimpleMemoryConfig`` through ``store.as_config``.

        Raises:
            ValueError: If ``store.as_config(SimpleMemoryConfig)`` returns a
                falsy value.
        """
        simple_config = store.as_config(SimpleMemoryConfig)
        if simple_config:
            super().__init__(store)
        else:
            raise ValueError("SimpleMemory requires a simple memory store configuration")

        self.config = simple_config
        # Location of the JSON file used by load()/save(), as configured.
        self.memory_path = self.config.memory_path

        # Templates wrapped around text before it is condensed and embedded.
        self.retrieve_prompt = "Query: {input}"
        self.update_prompt = "Input: {input}\nOutput: {output}"

        # Bounds applied by _extract_key_content() and update().
        self.min_content_length = 20   # fragments/summaries shorter than this are dropped
        self.max_content_length = 500  # summaries are cut at this many characters
        
    def _extract_key_content(self, content: str) -> str:
        """Extract key content while stripping redundant text."""
        # Remove redundant whitespace
        content = re.sub(r'\s+', ' ', content.strip())
        
        # Skip heavy processing for short snippets
        if len(content) <= 100:
            return content
        
        # Remove common templated instructions
        content = re.sub(r'(?:Agent|Model) Role:.*?\n\n', '', content)
        content = re.sub("(?:You are|\u4f60\u662f\u4e00\u4f4d).*?(?:,|\uff0c)", '', content)
        content = re.sub("(?:User will input|\u7528\u6237\u4f1a\u8f93\u5165).*?(?:,|\uff0c)", '', content)
        content = re.sub("(?:You need to|\u4f60\u9700\u8981).*?(?:,|\uff0c)", '', content)
        
        # Extract key sentences while skipping very short ones
        sentences = re.split(r'[\u3002\uff01\uff1f\uff1b\n]', content)
        key_sentences = [s.strip() for s in sentences if len(s.strip()) >= self.min_content_length]
        
        # Fallback to original content when no sentence survives
        if not key_sentences:
            return content[:self.max_content_length]
        
        # Recombine and limit the number of sentences (max 3)
        extracted_content = '\u3002'.join(key_sentences[:3])
        if len(extracted_content) > self.max_content_length:
            extracted_content = extracted_content[:self.max_content_length] + "..."
            
        return extracted_content.strip()
    
    def _generate_content_hash(self, content: str) -> str:
        """Generate a content hash used for deduplication."""
        return hashlib.md5(content.encode('utf-8')).hexdigest()[:8]

    def load(self) -> None:
        if self.memory_path and os.path.exists(self.memory_path) and self.memory_path.endswith(".json"):
            try:
                with open(self.memory_path) as file:
                    raw_data = json.load(file)
                contents = []
                for raw in raw_data:
                    try:
                        contents.append(MemoryItem.from_dict(raw))
                    except Exception:
                        continue
                self.contents = contents
            except Exception:
                self.contents = []

    def save(self) -> None:
        if self.memory_path and self.memory_path.endswith(".json"):
            os.makedirs(os.path.dirname(self.memory_path), exist_ok=True)
            with open(self.memory_path, "w") as file:
                json.dump([item.to_dict() for item in self.contents], file, indent=2, ensure_ascii=False)

    def retrieve(
        self,
        agent_role: str,
        query: MemoryContentSnapshot,
        top_k: int,
        similarity_threshold: float,
    ) -> List[MemoryItem]:
        if self.count_memories() == 0 or not self.embedding:
            return []
        
        # Build an optimized query for retrieval
        query_text = self.retrieve_prompt.format(input=query.text)
        query_text = self._extract_key_content(query_text)
        
        inputs_embedding = self.embedding.get_embedding(query_text)
        if isinstance(inputs_embedding, list):
            inputs_embedding = np.array(inputs_embedding, dtype=np.float32)
        inputs_embedding = inputs_embedding.reshape(1, -1)
        faiss.normalize_L2(inputs_embedding)

        expected_dim = inputs_embedding.shape[1]

        memory_embeddings = []
        valid_items = []
        for item in self.contents:
            if item.embedding is not None:
                if len(item.embedding) != expected_dim:
                    logger.warning(
                        "Skipping memory item %s: embedding dim %d != expected %d",
                        item.id, len(item.embedding), expected_dim,
                    )
                    continue
                memory_embeddings.append(item.embedding)
                valid_items.append(item)
        
        if not memory_embeddings:
            return []
            
        memory_embeddings = np.array(memory_embeddings, dtype=np.float32)

        # Use an efficient inner-product index
        index = faiss.IndexFlatIP(memory_embeddings.shape[1])
        index.add(memory_embeddings)

        # Retrieve extra candidates for reranking
        retrieval_k = min(top_k * 3, len(valid_items))
        similarities, indices = index.search(inputs_embedding, retrieval_k)
        
        # Filter and rerank the candidates
        candidates = []
        for i in range(len(indices[0])):
            idx = indices[0][i]
            similarity = similarities[0][i]

            if idx != -1 and similarity >= similarity_threshold:
                item = valid_items[idx]
                # Calculate an auxiliary semantic similarity score
                semantic_score = self._calculate_semantic_similarity(query_text, item.content_summary)
                # Combine similarity metrics
                combined_score = 0.7 * similarity + 0.3 * semantic_score
                candidates.append((item, combined_score))
        
        # Sort by the combined score and return the top_k items
        candidates.sort(key=lambda x: x[1], reverse=True)
        results = [item for item, score in candidates[:top_k]]
        
        return results
    
    def _calculate_semantic_similarity(self, query: str, content: str) -> float:
        """Compute a semantic similarity value."""
        # Enhanced semantic similarity computation
        query_lower = query.lower()
        content_lower = content.lower()
        
        # 1. Token overlap (Jaccard similarity)
        query_words = set(query_lower.split())
        content_words = set(content_lower.split())
        
        if not query_words or not content_words:
            jaccard_sim = 0.0
        else:
            intersection = query_words & content_words
            union = query_words | content_words
            jaccard_sim = len(intersection) / len(union) if union else 0.0
        
        # 2. Longest common subsequence similarity
        lcs_sim = self._calculate_lcs_similarity(query_lower, content_lower)
        
        # 3. Keyword match score
        keyword_sim = self._calculate_keyword_similarity(query_lower, content_lower)
        
        # 4. Length penalty factor (avoid overly short/long matches)
        length_factor = self._calculate_length_factor(query_lower, content_lower)
        
        # Weighted final score
        final_score = (0.4 * jaccard_sim + 
                      0.3 * lcs_sim + 
                      0.2 * keyword_sim + 
                      0.1 * length_factor)
        
        return min(final_score, 1.0)
    
    def _calculate_lcs_similarity(self, s1: str, s2: str) -> float:
        """Compute longest common subsequence similarity."""
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i-1] == s2[j-1]:
                    dp[i][j] = dp[i-1][j-1] + 1
                else:
                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])
        
        lcs_length = dp[m][n]
        return lcs_length / max(len(s1), len(s2)) if max(len(s1), len(s2)) > 0 else 0.0
    
    def _calculate_keyword_similarity(self, query: str, content: str) -> float:
        """Compute keyword match similarity."""
        # Extract potential keywords (length >= 2)
        query_keywords = set(word for word in query.split() if len(word) >= 2)
        content_keywords = set(word for word in content.split() if len(word) >= 2)
        
        if not query_keywords:
            return 0.0
        
        matches = query_keywords & content_keywords
        return len(matches) / len(query_keywords)
    
    def _calculate_length_factor(self, query: str, content: str) -> float:
        """Penalize matches that deviate too much in length."""
        query_len = len(query)
        content_len = len(content)
        
        if content_len == 0:
            return 0.0
        
        # Ideal length ratio range
        ideal_ratio_min = 0.5
        ideal_ratio_max = 2.0
        
        ratio = content_len / query_len
        
        if ideal_ratio_min <= ratio <= ideal_ratio_max:
            return 1.0
        elif ratio < ideal_ratio_min:
            return ratio / ideal_ratio_min
        else:
            return max(0.1, ideal_ratio_max / ratio)

    def update(self, payload: MemoryWritePayload) -> None:
        if not self.embedding:
            return

        snapshot = payload.output_snapshot
        if not snapshot or not snapshot.text.strip():
            return

        raw_content = self.update_prompt.format(
            input=payload.inputs_text,
            output=snapshot.text,
        )
        extracted_content = self._extract_key_content(raw_content)

        if len(extracted_content) < self.min_content_length:
            return

        content_hash = self._generate_content_hash(extracted_content)
        for existing_item in self.contents:
            existing_hash = self._generate_content_hash(existing_item.content_summary)
            if existing_hash == content_hash:
                return

        embedding_vector = self.embedding.get_embedding(extracted_content)
        if isinstance(embedding_vector, list):
            embedding_vector = np.array(embedding_vector, dtype=np.float32)
        if embedding_vector is None:
            return
        embedding_array = np.array(embedding_vector, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(embedding_array)

        metadata = {
            "agent_role": payload.agent_role,
            "input_preview": (payload.inputs_text or "")[:200],
            "content_length": len(extracted_content),
            "attachments": snapshot.attachment_overview(),
        }

        memory_item = MemoryItem(
            id=f"{content_hash}_{int(time.time())}",
            content_summary=extracted_content,
            metadata=metadata,
            embedding=embedding_array.tolist()[0],
            input_snapshot=payload.input_snapshot,
            output_snapshot=snapshot,
        )

        self.contents.append(memory_item)

        max_memories = 1000
        if len(self.contents) > max_memories:
            self.contents = self.contents[-max_memories:]
