#!/usr/bin/env python3
"""
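PPO-style RL locomotion demo video (procedurally animated; no policy is trained).
Four panels show a bipedal stick-figure walker at different simulated training
stages; within each panel the gait quality ramps up over the clip.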
"""
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Circle
import matplotlib.patheffects as pe

# ── Video settings ──────────────────────────────────────────
FPS = 30                     # output frame rate
DURATION = 36                # clip length, seconds
N_FRAMES = FPS * DURATION    # total number of rendered frames
DPI = 120
FIGSIZE = (12, 7)            # figure size in inches

# One panel per simulated training stage. ``reward_curve`` sweeps each
# panel's gait quality (and displayed reward) from ``quality_start`` towards
# ``quality_end`` over the clip; ``color`` tints the panel title, the walker
# and the reward bar.
STAGES = [
    {"label": "Episode 1",   "quality_start": 0.00, "quality_end": 0.05, "color": "#E74C3C"},
    {"label": "Episode 50",  "quality_start": 0.20, "quality_end": 0.45, "color": "#F39C12"},
    {"label": "Episode 200", "quality_start": 0.60, "quality_end": 0.82, "color": "#27AE60"},
    {"label": "Episode 500", "quality_start": 0.90, "quality_end": 0.99, "color": "#3498DB"},
]

# ── Walker kinematics ────────────────────────────────────────
# Segment lengths and head radius, in the panels' data units.
THIGH_L, SHIN_L = 0.38, 0.35   # upper / lower leg length
BODY_H = 0.48                  # hip-to-head distance
HEAD_R = 0.11                  # head circle radius

def walker_joints(t, quality):
    """Return the walker's 2-D joint positions at time ``t`` (seconds).

    The gait is procedural: both legs follow the same sinusoidal pattern half
    a cycle apart, and the hip bobs at twice the stride rate. ``quality``
    (0..1 in this script) scales the hip bob and the torso lean, while
    ``1 - quality`` weights per-angle sine "noise" terms of unrelated
    frequencies plus a downward sag of the hip (floored at 0.45).

    Angles are measured from straight down, positive towards +x; the knee
    angle is relative to the thigh.

    Returns:
        dict mapping 'head', 'hip', 'l_knee', 'r_knee', 'l_foot', 'r_foot'
        to ``(x, y)`` pairs. The hip is always at x = 0 and feet are clamped
        to y >= 0.
    """
    stride_hz = 1.3
    phase = 2 * np.pi * stride_hz * t

    # Hip height bobs twice per stride; the bob grows with quality.
    hip_y = 0.85 + 0.04 * quality * np.sin(2 * phase)

    def gait(ph):
        # (thigh angle, knee angle relative to thigh) of a clean stride.
        thigh = 0.50 * np.sin(ph)
        return thigh, thigh - 0.25 * np.sin(ph + 0.5) + 0.28

    (lt, ls), (rt, rs) = gait(phase), gait(phase + np.pi)

    if quality < 1.0:
        # Early-training noise: every angle gets its own sine term,
        # weighted by how far quality is from perfect.
        slack = 1.0 - quality
        wobble = ((0.9, 4.7, 0.3), (0.8, 7.9, 1.1), (0.9, 5.3, 2.0), (0.8, 8.6, 0.8))
        lt, ls, rt, rs = (
            angle + slack * amp * np.sin(t * rate + offset)
            for angle, (amp, rate, offset) in zip((lt, ls, rt, rs), wobble)
        )
        # Sag the hip so poor walkers crouch/stumble, but never below 0.45.
        hip_y = max(0.45, hip_y - slack * 0.25 * abs(np.sin(t * 3.7)))

    hip_x = 0.0

    def limb(thigh_ang, knee_ang):
        # Forward kinematics of one two-segment leg hanging from the hip.
        knee = (hip_x + THIGH_L * np.sin(thigh_ang),
                hip_y - THIGH_L * np.cos(thigh_ang))
        shin_ang = thigh_ang + knee_ang
        foot = (knee[0] + SHIN_L * np.sin(shin_ang),
                max(0.0, knee[1] - SHIN_L * np.cos(shin_ang)))
        return knee, foot

    l_knee, l_foot = limb(lt, ls)
    r_knee, r_foot = limb(rt, rs)

    # Torso tilts towards -x by an angle proportional to quality.
    lean = 0.12 * quality
    head = (hip_x - BODY_H * np.sin(lean), hip_y + BODY_H * np.cos(lean))

    return {"head": head, "hip": (hip_x, hip_y),
            "l_knee": l_knee, "r_knee": r_knee,
            "l_foot": l_foot, "r_foot": r_foot}

def draw_walker(ax, t, quality, x_scroll, color):
    """Draw the stick-figure walker for time ``t`` onto ``ax``.

    Joint positions come from ``walker_joints(t, quality)`` and every x
    coordinate is shifted by ``-x_scroll`` before plotting. Also draws a thin
    ground line at y = 0. The right leg is drawn first, in a 55 %-brightness
    shade of ``color`` at zorder 2, so the torso and left leg (zorder 4)
    overlap it.
    """
    world = walker_joints(t, quality)
    screen = {name: (xy[0] - x_scroll, xy[1]) for name, xy in world.items()}

    # Ground line (subtle)
    ax.axhline(0, color='#555', lw=1, zorder=0)

    hip, head = screen['hip'], screen['head']
    lknee, rknee = screen['l_knee'], screen['r_knee']
    lfoot, rfoot = screen['l_foot'], screen['r_foot']

    limb_w = 4   # leg line width; the torso is one point thicker
    dot = 7      # marker size of the white joint halo

    rgb = plt.matplotlib.colors.to_rgb(color)
    dim = tuple(ch * 0.55 for ch in rgb) + (1,)

    def bone(a, b, shade, width, z):
        ax.plot([a[0], b[0]], [a[1], b[1]], color=shade, lw=width,
                solid_capstyle='round', zorder=z)

    # Far leg: dimmed and underneath everything else.
    bone(hip, rknee, dim, limb_w, 2)
    bone(rknee, rfoot, dim, limb_w, 2)
    # Torso.
    bone(hip, head, color, limb_w + 1, 4)
    # Near leg.
    bone(hip, lknee, color, limb_w, 4)
    bone(lknee, lfoot, color, limb_w, 4)

    # Joint markers: white halo with a smaller coloured dot on top.
    for x, y in (hip, lknee, rknee):
        ax.plot(x, y, 'o', color='white', ms=dot, zorder=5)
        ax.plot(x, y, 'o', color=color, ms=dot - 3, zorder=6)

    # Head: filled circle with a white centre dot.
    ax.add_patch(Circle(head, HEAD_R, color=color, zorder=6))
    ax.plot(head[0], head[1], 'o', color='white', ms=5, zorder=7)


# ── Reward curve helpers ─────────────────────────────────────
def reward_curve(stage, t_norm):
    """Map normalized clip time ``t_norm`` (0..1) to a stage's quality/reward.

    Evaluates a logistic function on ``6 * t_norm - 3``. Over ``t_norm`` in
    [0, 1] the logistic only covers about [0.047, 0.953], so the result runs
    from slightly above ``quality_start`` to slightly below ``quality_end``.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    lo = stage['quality_start']
    hi = stage['quality_end']
    logistic = 1 / (1 + np.exp(-(t_norm * 6 - 3)))
    return lo + (hi - lo) * logistic


# ── Build figure ─────────────────────────────────────────────
fig = plt.figure(figsize=FIGSIZE, facecolor='#1a1a2e')
fig.suptitle("Proximal Policy Optimization  ·  Walker2D",
             color='white', fontsize=14, fontweight='bold', y=0.97,
             fontfamily='monospace')

# One walker panel per stage on a 2x2 grid, filled row by row from the top.
axes = []
for i, stage in enumerate(STAGES):
    ax = fig.add_subplot(2, 2, i + 1)
    ax.set_facecolor('#0f0f23')
    ax.set_xlim(-2.2, 2.2)
    ax.set_ylim(-0.15, 2.0)
    ax.set_aspect('equal')
    ax.set_xticks([]); ax.set_yticks([])
    for sp in ax.spines.values():
        sp.set_color('#333355')
        sp.set_linewidth(0.8)
    ax.set_title(stage['label'], color=stage['color'],
                 fontsize=10, fontfamily='monospace', pad=4)
    axes.append(ax)

# Lay out the subplot grid *before* adding the free-floating reward axes:
# tight_layout only handles grid subplots and warns about add_axes() axes.
plt.tight_layout(rect=[0, 0.07, 1, 0.95])

# Small reward bars at the bottom of each panel. Subplot rows are counted
# from the TOP (i // 2 == 0 is the top row), but figure coordinates grow
# upward, so the top row needs the higher strip and the bottom row the
# lower one.
reward_axes = []
for i, _stage in enumerate(STAGES):
    row_from_top = i // 2
    col = i % 2
    bottom = 0.04 + (1 - row_from_top) * 0.465
    rax = fig.add_axes([0.08 + col * 0.50, bottom, 0.38, 0.028])
    rax.set_facecolor('#0a0a1a')
    rax.set_xlim(0, 1); rax.set_ylim(0, 1)
    rax.set_xticks([]); rax.set_yticks([])
    for sp in rax.spines.values():
        sp.set_color('#333355')
    reward_axes.append(rax)

# Per-panel state
x_scrolls = [0.0] * len(STAGES)            # accumulated ground scroll per panel
artists_per_panel = [[] for _ in STAGES]   # NOTE(review): not referenced elsewhere in this script
reward_bars = []
reward_texts = []

for rax, stage in zip(reward_axes, STAGES):
    # A single full-width bar whose height is set to the current reward.
    bar, = rax.bar([0.5], [0.0], width=0.98, color=stage['color'],
                   alpha=0.8, align='center', bottom=0)
    reward_bars.append(bar)
    txt = rax.text(0.5, 0.5, "Reward: 0", color='white',
                   ha='center', va='center', fontsize=7,
                   fontfamily='monospace')
    reward_texts.append(txt)

# Step counter text
step_text = fig.text(0.5, 0.005, "Training step: 0", color='#aaaacc',
                     ha='center', fontsize=9, fontfamily='monospace')


def _style_panel(ax, stage):
    """(Re)apply the static look of one walker panel, e.g. after ``ax.cla()``."""
    ax.set_facecolor('#0f0f23')
    ax.set_xlim(-2.2, 2.2)
    ax.set_ylim(-0.15, 2.0)
    ax.set_aspect('equal')
    ax.set_xticks([]); ax.set_yticks([])
    for sp in ax.spines.values():
        sp.set_color('#333355'); sp.set_linewidth(0.8)
    ax.set_title(stage['label'], color=stage['color'],
                 fontsize=10, fontfamily='monospace', pad=4)


def animate(frame):
    """Render animation frame ``frame`` into every walker panel.

    For each stage the panel is cleared and restyled. The stage's quality is
    read from ``reward_curve`` at ``frame / N_FRAMES``, and the panel's ground
    scroll advances by ``1.1 * quality / FPS``. The ground tiles, walker and
    reward bar are then redrawn. Also updates the global training-step
    counter. Returns an empty list because blitting is disabled.
    """
    t_norm = frame / N_FRAMES
    t = frame / FPS

    # FuncAnimation draws frame 0 once on initialisation and again at the
    # start of every save(). Reset the accumulated scroll there so each pass
    # renders the same video, however many times it was drawn before.
    if frame == 0:
        x_scrolls[:] = [0.0] * len(x_scrolls)

    for i, (ax, stage) in enumerate(zip(axes, STAGES)):
        ax.cla()
        _style_panel(ax, stage)

        quality = reward_curve(stage, t_norm)

        # Ground covered so far; better policies walk faster.
        walk_speed = quality * 1.1
        x_scrolls[i] += walk_speed / FPS
        x_s = x_scrolls[i]

        # The camera tracks the walker: the walker is drawn at the origin and
        # the ground tiles slide past instead. Previously the walker itself was
        # shifted by -x_s and drifted out of the +/-2.2 view within seconds.
        # The walker faces -x (its feet trail the knees towards +x and the
        # torso leans towards -x), so the tiles move towards +x.
        offset = x_s % 0.8
        for gx in np.arange(-10, 20, 0.8):
            tile_x = gx + offset
            if -2.5 < tile_x < 2.5:
                ax.axvline(tile_x, color='#1e1e3a', lw=0.5, zorder=0)

        draw_walker(ax, t, quality, 0.0, stage['color'])

        # Reward bar
        reward_bars[i].set_height(quality)
        reward_texts[i].set_text(f"Reward: {quality:.2f}")

    total_steps = int(t_norm * 500_000)
    step_text.set_text(f"Training step: {total_steps:,}")

    return []


ani = animation.FuncAnimation(fig, animate, frames=N_FRAMES,
                              interval=1000 / FPS, blit=False)

# Output path: the first command-line argument if given, otherwise the
# original (machine-specific) default location.
DEFAULT_OUT = '/Users/bowang/.openclaw/workspace/ppo_walker_demo.mp4'
out = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(DEFAULT_OUT)

# Fail early with a clear message instead of deep inside ani.save().
if not animation.writers.is_available('ffmpeg'):
    sys.exit("ffmpeg not found on PATH; install it to render the video.")

out.parent.mkdir(parents=True, exist_ok=True)
writer = animation.FFMpegWriter(fps=FPS, bitrate=2500,
                                extra_args=['-vcodec', 'libx264', '-pix_fmt', 'yuv420p'])
print(f"Rendering {N_FRAMES} frames at {FPS}fps → {out}")
ani.save(str(out), writer=writer, dpi=DPI)
plt.close(fig)
print("Done!")
