#!/usr/bin/env python3
"""
Phase 2: Deep Research
Fetches sources, papers, and data points for a slide-deck topic.
Output: structured research-brief.md
"""

import argparse
import json
import os
import sys
import subprocess
import urllib.request
import urllib.parse
from datetime import datetime

# Workspace directory under ~/.openclaw, with "~" expanded to the user's home.
# NOTE(review): no function in this script references it — presumably shared
# with other pipeline phases; confirm before removing.
WORKSPACE = os.path.expanduser("~/.openclaw/workspace")


def arxiv_search(query: str, max_results: int = 8) -> list[dict]:
    """Search arXiv for papers matching *query*, most recently submitted first.

    Args:
        query: Free-text search terms, sent as an ``all:`` field query.
        max_results: Maximum number of entries to request from the API.

    Returns:
        A list of dicts with keys ``title``, ``abstract`` (first 400
        characters), ``url``, ``authors`` (first three names) and
        ``published`` (``YYYY-MM-DD``). Text fields are entity-unescaped and
        whitespace-normalized. Returns an empty list for a blank query, a
        non-positive ``max_results``, or any request/parse failure (failures
        are reported on stderr).
    """
    import html
    import re

    # Nothing sensible to ask the API for; avoid a pointless request.
    if not query.strip() or max_results <= 0:
        return []

    def clean(text: str) -> str:
        # Atom text is XML-escaped (&amp;, &lt;, ...) and hard-wrapped with a
        # newline plus indentation; turn it into plain, single-spaced text.
        return " ".join(html.unescape(text).split())

    encoded = urllib.parse.quote(query)
    url = (
        "https://export.arxiv.org/api/query"
        f"?search_query=all:{encoded}&max_results={max_results}"
        "&sortBy=submittedDate&sortOrder=descending"
    )
    try:
        with urllib.request.urlopen(url, timeout=15) as resp:
            content = resp.read().decode()
        entries = re.findall(r'<entry>(.*?)</entry>', content, re.DOTALL)
        results = []
        for entry in entries:
            arxiv_id = re.search(r'<id>(.*?)</id>', entry)
            entry_id = arxiv_id.group(1).strip() if arxiv_id else ""
            # The API reports malformed queries as a feed holding a single
            # "Error" entry whose id points at /api/errors — not a paper.
            if "arxiv.org/api/errors" in entry_id:
                continue
            title = re.search(r'<title>(.*?)</title>', entry, re.DOTALL)
            summary = re.search(r'<summary>(.*?)</summary>', entry, re.DOTALL)
            authors = re.findall(r'<name>(.*?)</name>', entry)
            published = re.search(r'<published>(.*?)</published>', entry)
            if title and summary:
                results.append({
                    "title": clean(title.group(1)),
                    "abstract": clean(summary.group(1))[:400],
                    "url": entry_id,
                    "authors": [clean(a) for a in authors[:3]],
                    "published": published.group(1).strip()[:10] if published else "",
                })
        return results
    except Exception as e:
        # Best-effort source: report and let the caller continue without it.
        print(f"arXiv search failed: {e}", file=sys.stderr)
        return []


def jina_fetch(url: str, max_chars: int = 3000) -> str:
    """Return up to *max_chars* characters of *url* as rendered by r.jina.ai.

    The page is requested through the Jina reader proxy with a plain-text
    ``Accept`` header and decoded as UTF-8, silently dropping undecodable
    bytes. Nothing is raised: any failure is reported in-band as a
    ``"[Fetch failed: ...]"`` string.
    """
    try:
        reader_request = urllib.request.Request(
            f"https://r.jina.ai/{url}",
            headers={"Accept": "text/plain"},
        )
        with urllib.request.urlopen(reader_request, timeout=15) as response:
            payload = response.read()
        excerpt = payload.decode(errors='ignore')[:max_chars]
    except Exception as exc:
        return f"[Fetch failed: {exc}]"
    else:
        return excerpt


def biorxiv_search(
    query: str,
    max_results: int = 5,
    days: int = 30,
    max_pages: int = 5,
) -> list[dict]:
    """Find recent bioRxiv preprints whose title or abstract mention *query*.

    The public bioRxiv endpoint ``/details/[server]/[interval]/[cursor]/[format]``
    has no keyword search: the path segment after the date interval is a
    numeric pagination cursor, so putting the query there does not filter
    by topic. Instead, this pages through preprints posted in the last
    *days* days (at most *max_pages* requests). It keeps records whose title
    or abstract contains every whitespace-separated term of *query*
    (case-insensitive). Recall is therefore limited to the records scanned.

    Args:
        query: Search terms; all must appear in the title or abstract.
        max_results: Maximum number of preprints to return.
        days: Length of the posting-date window, ending today.
        max_pages: Upper bound on API requests.

    Returns:
        Up to *max_results* dicts with keys ``title``, ``abstract`` (first
        400 characters), ``doi``, ``date``, ``authors`` (first three) and
        ``url``, at most one per DOI. Empty for a blank query or a
        non-positive *max_results*. A request failure is reported on stderr
        and ends the scan early; papers already collected are still returned.
    """
    from datetime import date, timedelta

    terms = query.lower().split()
    if not terms or max_results <= 0:
        return []

    page_size = 100  # records returned per request by the details endpoint
    end = date.today()
    start = end - timedelta(days=days)
    papers: list[dict] = []
    seen_dois: set[str] = set()

    for page in range(max_pages):
        url = (
            "https://api.biorxiv.org/details/biorxiv/"
            f"{start.isoformat()}/{end.isoformat()}/{page * page_size}/json"
        )
        try:
            with urllib.request.urlopen(url, timeout=10) as resp:
                data = json.loads(resp.read())
        except Exception as e:
            # Best-effort source: log and keep whatever was already found.
            print(f"bioRxiv search failed: {e}", file=sys.stderr)
            break

        collection = data.get('collection') if isinstance(data, dict) else None
        if not collection:
            break
        for item in collection:
            title = item.get('title') or ''
            abstract = item.get('abstract') or ''
            doi = item.get('doi') or ''
            haystack = f"{title} {abstract}".lower()
            if not all(term in haystack for term in terms):
                continue
            # The API may list several versions of one preprint; keep one per DOI.
            if doi:
                if doi in seen_dois:
                    continue
                seen_dois.add(doi)
            authors = [a.strip() for a in (item.get('authors') or '').split(';') if a.strip()]
            papers.append({
                "title": title,
                "abstract": abstract[:400],
                "doi": doi,
                "date": item.get('date', ''),
                "authors": authors[:3],
                "url": f"https://www.biorxiv.org/content/{doi}" if doi else "",
            })
            if len(papers) >= max_results:
                return papers
        if len(collection) < page_size:
            break  # last page of the window
    return papers


def write_brief(
    topic: str,
    audience: str,
    goal: str,
    arxiv_papers: list,
    biorxiv_papers: list,
    web_content: dict,
    output_path: str,
):
    """Render the research brief as Markdown and write it to *output_path*.

    Sections, in order:
    - A header with the topic, generation time, audience, and goal.
    - A placeholder "Key Claims" list.
    - One subsection per arXiv paper, or a note when there are none.
    - bioRxiv and web-source sections, each omitted when its input is empty.
    - Fixed "Gaps & Conflicts" and "Bo's Relevant Work" sections.

    The file is written as UTF-8 and missing parent directories are
    created. A summary of the counts is printed to stdout.

    Args:
        topic: Deck topic, used in the title.
        audience: Target audience, shown in the header.
        goal: What the audience should remember, shown in the header.
        arxiv_papers: Dicts with ``title``, ``authors``, ``published``,
            ``url`` and ``abstract`` keys (the shape ``arxiv_search`` returns).
        biorxiv_papers: Dicts with ``title``, ``date``, ``doi``, ``url`` and
            ``abstract`` keys (the shape ``biorxiv_search`` returns).
        web_content: Mapping of source URL to fetched text; each entry is
            truncated to 800 characters in the brief.
        output_path: Destination file; may be a bare filename.
    """
    now = datetime.now().strftime("%Y-%m-%d %H:%M")

    lines = [
        f"# Research Brief: {topic}",
        f"*Generated: {now}*",
        f"*Audience: {audience} | Goal: {goal}*",
        "",
        "---",
        "",
        "## Key Claims (draft — verify before using)",
        "",
        "*(Fill in after reading sources below)*",
        "",
        "1. [Claim 1 — with supporting data point]",
        "2. [Claim 2 — with supporting data point]",
        "3. [Claim 3 — with supporting data point]",
        "",
        "---",
        "",
        "## arXiv Papers",
        "",
    ]

    # Make an empty search visible instead of leaving a bare heading.
    if not arxiv_papers:
        lines += ["*(No arXiv results found)*", ""]
    for i, p in enumerate(arxiv_papers, 1):
        lines += [
            f"### {i}. {p['title']}",
            f"*Authors: {', '.join(p['authors'][:3])} | {p['published']}*",
            f"*URL: {p['url']}*",
            "",
            p['abstract'],
            "",
        ]

    if biorxiv_papers:
        lines += ["---", "", "## bioRxiv Preprints", ""]
        for i, p in enumerate(biorxiv_papers, 1):
            lines += [
                f"### {i}. {p['title']}",
                f"*Date: {p['date']} | DOI: {p['doi']}*",
                f"*URL: {p['url']}*",
                "",
                p['abstract'],
                "",
            ]

    if web_content:
        lines += ["---", "", "## Web Sources", ""]
        for url, content in web_content.items():
            lines += [
                f"### Source: {url}",
                "",
                content[:800],
                "",
            ]

    lines += [
        "---",
        "",
        "## Gaps & Conflicts",
        "",
        "*(Agent: flag here where data is thin, disputed, or missing)*",
        "",
        "- [ ] [Gap 1]",
        "- [ ] [Conflict 1]",
        "",
        "---",
        "",
        "## Bo's Relevant Work",
        "",
        "- **scGPT** (Cui et al., Nature Methods 2024): Trained on 33M cells. Zero-shot cell annotation, perturbation prediction, multi-omic integration. Best results: [fill in from paper]",
        "- **LUMI-6** (Wang & Li, Cell 2026): 20.3% lung epithelial gene editing in vivo using brominated lipids",
        "- **Xaira Therapeutics**: Building causally-rich perturbation datasets for virtual cell models",
        "",
    ]

    # os.path.dirname("brief.md") is "" and os.makedirs("") raises
    # FileNotFoundError, so only create a parent when the path names one.
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # The template and abstracts contain non-ASCII text (e.g. "—"); write
    # UTF-8 explicitly instead of relying on the platform locale encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')

    print(f"✓ Research brief saved: {output_path}")
    print(f"  arXiv papers: {len(arxiv_papers)}")
    print(f"  bioRxiv papers: {len(biorxiv_papers)}")
    print(f"  Web sources: {len(web_content)}")


def main():
    """CLI entry point: research a topic and write a Markdown research brief.

    Steps:
    1. Run three arXiv queries derived from ``--topic``, deduplicated by
       title and keeping at most 8 papers.
    2. Run a bioRxiv search.
    3. When the topic mentions one of a few single-cell keywords, fetch known
       reference pages through the Jina reader. Pass ``--no-web`` to skip this.
    4. Write everything to the path given by ``--brief``.
    """
    parser = argparse.ArgumentParser(
        description="Phase 2: gather papers and sources for a slide-deck topic."
    )
    parser.add_argument("--topic", required=True, help="Slide deck topic")
    parser.add_argument("--audience", default="scientists", help="Target audience")
    parser.add_argument("--goal", default="", help="What audience should remember")
    parser.add_argument("--brief", required=True, help="Output path for research brief")
    parser.add_argument(
        "--no-web",
        action="store_true",
        help="Do not fetch the known domain web pages",
    )
    args = parser.parse_args()

    print(f"\n🔬 Researching: {args.topic}")
    print(f"   Audience: {args.audience}")

    # arXiv search — use multiple queries
    max_arxiv = 8  # number of arXiv papers included in the brief
    print("\n📚 Searching arXiv...")
    queries = [
        args.topic,
        f"{args.topic} foundation model",
        f"{args.topic} machine learning",
    ]
    seen_titles = set()
    arxiv_papers = []
    for q in queries:
        for p in arxiv_search(q, max_results=5):
            if p['title'] not in seen_titles:
                seen_titles.add(p['title'])
                arxiv_papers.append(p)
        # Only the first max_arxiv papers are kept, so once that many are
        # collected further queries cannot change the result.
        if len(arxiv_papers) >= max_arxiv:
            break
    arxiv_papers = arxiv_papers[:max_arxiv]

    # bioRxiv search
    print("📋 Searching bioRxiv...")
    biorxiv_papers = biorxiv_search(args.topic)

    # Known high-value URLs for Bo's domain
    domain_urls = {}
    topic_lower = args.topic.lower()
    if any(w in topic_lower for w in ['virtual cell', 'single cell', 'scgpt', 'perturbation']):
        domain_urls = {
            "https://www.nature.com/articles/s41592-024-02201-0": "scGPT paper (Nature Methods 2024)",
            "https://chanzuckerberg.github.io/cellxgene/": "CellxGene portal",
        }

    web_content = {}
    if domain_urls and not args.no_web:
        print("🌐 Fetching domain sources...")
        for url, label in domain_urls.items():
            text = jina_fetch(url)
            # jina_fetch reports failures in-band; keep them out of the brief.
            if text.startswith("[Fetch failed:"):
                print(f"  skipped {url}: {text}", file=sys.stderr)
                continue
            web_content[url] = f"*{label}*\n\n{text}"

    print(f"\n✓ Found {len(arxiv_papers)} arXiv + {len(biorxiv_papers)} bioRxiv papers")

    write_brief(
        topic=args.topic,
        audience=args.audience,
        goal=args.goal,
        arxiv_papers=arxiv_papers,
        biorxiv_papers=biorxiv_papers,
        web_content=web_content,
        output_path=args.brief,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
