#!/usr/bin/env python3
"""
arxiv_daily.py - Daily arXiv scanner for Bo Wang
Fetches recent trending arXiv papers via the HuggingFace Daily Papers API, ranks them by AI / biology / medicine keyword interests, and prints the top 5 (abstracts only, no PDFs).
"""

import urllib.request
import urllib.parse
import xml.etree.ElementTree as ET
import re
import sys
from datetime import datetime, timedelta, timezone

# ── Interest scoring ──────────────────────────────────────────────────────────

# Regex patterns that score() searches for in the lower-cased title + abstract;
# each pattern that matches adds 3 points (once, however often it occurs).
# A "." matches any single character except a newline, so r"single.cell"
# covers both "single-cell" and "single cell". Overlapping patterns stack:
# text containing "single-cell foundation" matches both r"single.cell" and
# r"single.cell foundation".
HIGH_PRIORITY = [
    r"single.cell", r"scrna.?seq", r"scatac", r"spatial transcriptomics",
    r"cell atlas", r"drug discovery", r"drug design", r"molecular generation",
    r"protein structure", r"foundation model", r"virtual cell", r"perturbation",
    r"gene regulatory network", r"cell type annotation", r"multiom", r"multi.om",
    r"scgpt", r"single.cell foundation", r"cancer genomics", r"epigenomics",
    r"crispr", r"rna.seq", r"trajectory inference",
]

# Secondary interests: each matching pattern adds 1 point in score().
# These are plain substring regexes, so e.g. r"genomics" also matches inside
# "epigenomics" / "cancer genomics"; such papers collect points from both lists.
MEDIUM_PRIORITY = [
    r"large language model", r"\bllm\b", r"transformer", r"reinforcement learning",
    r"\bagent\b", r"multi.agent", r"generative model", r"diffusion model",
    r"genomics", r"proteomics", r"biomarker", r"clinical", r"medical imaging",
    r"federated learning", r"zero.shot", r"few.shot", r"self.supervised",
    r"representation learning", r"neural network", r"deep learning",
    r"drug target", r"molecular docking", r"sequence model",
]

# arXiv category codes of interest.
# NOTE(review): not referenced by any other code in this file; presumably
# meant for an arXiv API query whose Atom results parse_papers() would handle.
# Confirm, or remove it.
CATEGORIES = [
    "cs.LG", "cs.AI", "cs.CL", "cs.NE",
    "q-bio.GN", "q-bio.QM", "q-bio.BM", "q-bio.MN", "q-bio.CB",
    "stat.ML",
]

# ── Fetch ─────────────────────────────────────────────────────────────────────

def fetch_papers_huggingface(max_results=100):
    """Fetch trending arXiv papers from the HuggingFace Daily Papers API.

    Queries the ``daily_papers`` endpoint for today and the two previous
    local calendar dates, newest first, de-duplicating papers by arXiv id.

    Args:
        max_results: Maximum number of papers to return (papers from more
            recent days are kept first). ``None`` means no limit.

    Returns:
        A list of paper dicts with keys ``title`` and ``abstract`` (whitespace
        runs collapsed to single spaces), ``url`` (arXiv abstract page),
        ``published`` (the queried day at midnight UTC, ISO 8601), ``authors``
        (at most four names), ``categories`` (always empty for this source)
        and ``upvotes`` (0 when missing or null). A day whose request or JSON
        payload fails is skipped, with a warning on stderr. Malformed entries
        are skipped individually. An empty list means nothing was retrieved.
    """
    import http.client
    import json
    from datetime import date

    papers = []
    seen = set()
    # Today plus the two previous days, newest first.
    for delta in range(3):
        if max_results is not None and len(papers) >= max_results:
            break  # already have enough; skip further network requests
        day = (date.today() - timedelta(days=delta)).strftime("%Y-%m-%d")
        url = f"https://huggingface.co/api/daily_papers?date={day}"
        req = urllib.request.Request(url, headers={"User-Agent": "arxiv-daily-bot/1.0"})
        try:
            with urllib.request.urlopen(req, timeout=15) as resp:
                items = json.loads(resp.read().decode("utf-8"))
        except (OSError, ValueError, http.client.HTTPException) as err:
            # Best effort: a failing day must not sink the others, but say why.
            # (URLError/HTTPError/timeouts are OSErrors; bad JSON/UTF-8 are ValueErrors.)
            print(f"⚠️ HuggingFace daily_papers {day} failed: {err}", file=sys.stderr)
            continue
        if not isinstance(items, list):
            continue

        for item in items:
            paper = item.get("paper") if isinstance(item, dict) else None
            if not isinstance(paper, dict):
                continue
            pid = paper.get("id")
            if not isinstance(pid, str) or not pid or pid in seen:
                continue
            seen.add(pid)
            authors = [
                a.get("name", "")
                for a in (paper.get("authors") or [])
                if isinstance(a, dict)
            ]
            papers.append({
                # Collapse wrapped lines, matching parse_papers() output.
                "title":      re.sub(r"\s+", " ", paper.get("title") or "").strip(),
                "abstract":   re.sub(r"\s+", " ", paper.get("summary") or "").strip(),
                "url":        f"https://arxiv.org/abs/{pid}",
                "published":  day + "T00:00:00+00:00",
                "authors":    authors[:4],
                "categories": [],
                # A null upvote count would break numeric sort keys downstream.
                "upvotes":    paper.get("upvotes") or 0,
            })

    if max_results is not None:
        del papers[max_results:]
    return papers


def fetch_papers(max_results=100):
    """Return papers from the HuggingFace Daily Papers feed, or None.

    Only fetch_papers_huggingface() is consulted here; the arXiv Atom parser
    is not called. An empty result is reported as ``None`` so callers can
    treat it as "all sources failed".
    """
    hf_papers = fetch_papers_huggingface(max_results)
    return hf_papers or None

# ── Parse ─────────────────────────────────────────────────────────────────────

def parse_papers(xml_data):
    """Parse an arXiv API Atom feed into a list of paper dicts.

    Args:
        xml_data: The Atom XML document, as ``str`` or ``bytes``.

    Returns:
        One dict per ``<entry>``, in feed order, with keys:
        ``title`` / ``abstract`` (whitespace runs collapsed to one space),
        ``url`` (the entry ``<id>`` with a leading ``http://`` upgraded to
        ``https://``), ``published`` (text as given), ``authors`` (first four
        non-empty author names) and ``categories`` (sorted, de-duplicated
        primary and secondary category terms).

    Raises:
        xml.etree.ElementTree.ParseError: if *xml_data* is not well-formed XML.
    """
    ns_atom = "http://www.w3.org/2005/Atom"
    ns_arxiv = "http://arxiv.org/schemas/atom"

    def atom(tag):
        # Clark-notation name for an element in the Atom namespace.
        return f"{{{ns_atom}}}{tag}"

    def text_of(el):
        # Missing element or empty text -> "". Internal whitespace is
        # collapsed because the feed wraps long titles/abstracts over lines.
        if el is None or not el.text:
            return ""
        return re.sub(r"\s+", " ", el.text).strip()

    root = ET.fromstring(xml_data)
    papers = []
    for entry in root.findall(atom("entry")):
        url = text_of(entry.find(atom("id")))
        if url.startswith("http://"):
            url = "https://" + url[len("http://"):]

        # A <name> element may be absent or empty; skip those rather than
        # emitting None, which would break ", ".join() when formatting.
        names = (text_of(a.find(atom("name"))) for a in entry.findall(atom("author")))
        authors = [name for name in names if name]

        cats = {c.get("term", "") for c in entry.findall(atom("category"))}
        primary = entry.find(f"{{{ns_arxiv}}}primary_category")
        if primary is not None:
            cats.add(primary.get("term", ""))
        cats.discard("")

        papers.append({
            "title":      text_of(entry.find(atom("title"))),
            "abstract":   text_of(entry.find(atom("summary"))),
            "url":        url,
            "published":  text_of(entry.find(atom("published"))),
            "authors":    authors[:4],
            "categories": sorted(cats),
        })
    return papers

# ── Score ─────────────────────────────────────────────────────────────────────

def score(paper):
    """Return the keyword-interest score of *paper*.

    The title and abstract are joined with a space, lower-cased and searched
    with every regex pattern. Each HIGH_PRIORITY pattern that matches is worth
    3 points and each MEDIUM_PRIORITY pattern 1 point. A pattern counts at
    most once, however many times it occurs.
    """
    haystack = " ".join((paper["title"], paper["abstract"])).lower()
    weighted_lists = ((HIGH_PRIORITY, 3), (MEDIUM_PRIORITY, 1))
    return sum(
        weight
        for patterns, weight in weighted_lists
        for pattern in patterns
        if re.search(pattern, haystack)
    )

# ── Format ────────────────────────────────────────────────────────────────────

def _format_authors(names):
    """Join the first two *names* with ", ", adding " et al." when there are more."""
    text = ", ".join(names[:2])
    if len(names) > 2:
        text += " et al."
    return text


def _truncate(text, limit):
    """Cut *text* to *limit* characters, appending "…" only if something was cut."""
    if len(text) <= limit:
        return text.rstrip()
    return text[:limit].rstrip() + "…"


def format_report(papers, mode="telegram"):
    """Render *papers* as a dated, numbered report string.

    Args:
        papers: Paper dicts with ``title``, ``abstract``, ``url``,
            ``authors`` and ``categories`` keys.
        mode: ``"telegram"`` for Markdown-decorated output (bold titles,
            italic authors, backticked categories, 280-char snippets). Any
            other value gives the plain-text / email layout (400-char snippets).

    Returns:
        A header line followed by one entry per paper, entries separated by
        blank lines. The abstract snippet gets "…" only when it was actually
        truncated. Empty author lists, category lists and abstracts are
        omitted instead of being rendered as empty markup (``__``, ``[]``).

    NOTE(review): titles and authors are inserted unescaped into Telegram
    Markdown. If the message is sent with parse_mode=Markdown, a ``*``, ``_``
    or backtick in them may break parsing. Confirm against the sender.
    """
    today = datetime.now().strftime("%B %d, %Y")
    telegram = mode == "telegram"
    if telegram:
        lines = [f"📚 *arXiv Daily Scout — {today}*\n"]
    else:  # email / plain
        lines = [f"arXiv Daily Scout — {today}\n{'='*50}\n"]

    for i, p in enumerate(papers, 1):
        authors = _format_authors(p["authors"])
        categories = p["categories"][:3]
        if telegram:
            cats = " ".join(f"`{c}`" for c in categories)
            byline = "  ".join(part for part in (f"_{authors}_" if authors else "", cats) if part)
            entry = [
                f"*{i}. {p['title']}*",
                byline,
                _truncate(p["abstract"], 280),
                f"🔗 {p['url']}",
            ]
        else:
            cats = f"[{', '.join(categories)}]" if categories else ""
            byline = " ".join(part for part in (authors, cats) if part)
            entry = [f"{i}. {p['title']}"] + [
                "   " + part
                for part in (byline, _truncate(p["abstract"], 400), p["url"])
                if part
            ]
        # Trailing "\n" plus the "\n" joiner leaves a blank line between entries.
        lines.append("\n".join(part for part in entry if part) + "\n")
    return "\n".join(lines)

# ── Main ──────────────────────────────────────────────────────────────────────

def _published_utc(paper):
    """Return ``paper["published"]`` as a timezone-aware datetime, or None.

    A ``Z`` suffix is rewritten to ``+00:00`` because datetime.fromisoformat
    accepts ``Z`` only from Python 3.11. Naive timestamps are taken as UTC so
    they can be compared with an aware cutoff. Missing or unparsable values
    give None instead of crashing the recency filter.
    """
    raw = (paper.get("published") or "").replace("Z", "+00:00")
    try:
        when = datetime.fromisoformat(raw)
    except ValueError:
        return None
    return when if when.tzinfo is not None else when.replace(tzinfo=timezone.utc)


def main(argv=None):
    """Fetch, score and print the five most relevant papers.

    Args:
        argv: Command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. The first argument, if present, is the report
            mode passed to format_report (default ``"telegram"``).

    Returns:
        Process exit status: 0 on success, 1 when no papers were fetched.
    """
    argv = sys.argv[1:] if argv is None else argv
    mode = argv[0] if argv else "telegram"

    papers = fetch_papers(max_results=100)
    if not papers:
        print("⚠️ All paper sources failed or returned no results.", file=sys.stderr)
        return 1

    for p in papers:
        p["score"] = score(p)

    # Prefer papers published in the last 2 days; fall back to all if sparse.
    cutoff = datetime.now(timezone.utc) - timedelta(days=2)
    recent = []
    for p in papers:
        when = _published_utc(p)
        if when is not None and when > cutoff:
            recent.append(p)
    pool = recent if len(recent) >= 5 else papers

    # Rank by score, breaking ties by upvotes. A missing/None upvote count
    # sorts as 0 instead of raising TypeError when compared with an int.
    top5 = sorted(pool, key=lambda x: (x["score"], x.get("upvotes") or 0), reverse=True)[:5]
    print(format_report(top5, mode=mode))
    return 0


if __name__ == "__main__":
    sys.exit(main())
