# qhca_neural_io_stub.py
# Neuralink-style I/O stub for phase estimation -> pseudo-hologram reconstruction.
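# Usage: run `python qhca_neural_io_stub.py` (or call demo()) to run the pipeline on
# synthetic data; arrays are saved under out/ and preview figures under figs/.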

import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple

try:
    from scipy.signal import hilbert, butter, sosfiltfilt, detrend
    SCIPY_OK = True
except Exception:
    SCIPY_OK = False

# Output directories: preview_plots() writes PNGs into "figs/" and save_arrays()
# writes .npy files into "out/". They are created at import time so those helpers
# can write without checking first.
# NOTE(review): this is an import-time side effect and the paths are relative to
# the current working directory, not to this file.
os.makedirs("figs", exist_ok=True)
os.makedirs("out", exist_ok=True)

def bandpass(data: np.ndarray, fs: float, low: Optional[float], high: Optional[float],
             order: int = 4) -> np.ndarray:
    """Zero-phase Butterworth filter applied along axis 0 (time).

    A missing cutoff means "no limit on that side": ``low=None`` gives a
    low-pass at ``high``, ``high=None`` gives a high-pass at ``low``, and both
    given gives a band-pass. ``data`` is returned unchanged when SciPy is
    unavailable or both cutoffs are None.

    Args:
        data: samples with time along axis 0.
        fs: sampling rate in Hz.
        low: lower cutoff in Hz, or None.
        high: upper cutoff in Hz, or None.
        order: Butterworth filter order (defaults to 4).

    Returns:
        Filtered array with the same shape as ``data``.

    Raises:
        ValueError: from ``scipy.signal.butter`` when a cutoff is not strictly
            between 0 and fs/2.
    """
    if not SCIPY_OK or (low is None and high is None):
        return data
    nyquist = fs / 2.0
    # Choose the design from which cutoffs were supplied. Substituting 0 Hz for a
    # missing lower edge is invalid for butter (0 < Wn < 1 is required), so a
    # one-sided request must use a low-pass or high-pass design instead.
    if low is None:
        sos = butter(order, high / nyquist, btype='lowpass', output='sos')
    elif high is None:
        sos = butter(order, low / nyquist, btype='highpass', output='sos')
    else:
        sos = butter(order, [low / nyquist, high / nyquist],
                     btype='bandpass', output='sos')
    # sosfiltfilt runs the filter forward and backward, giving zero phase shift.
    return sosfiltfilt(sos, data, axis=0)

def load_lfp_csv(path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load a multichannel LFP recording from a comma-separated file with a header row.

    The first column is the time vector; every remaining column is one channel.

    Args:
        path: path to the CSV file.

    Returns:
        ``(time, channels)`` where ``time`` has shape (T,) and ``channels`` has
        shape (T, C), with C the number of columns after the first.

    Raises:
        ValueError: if the header has no channel column after the time column.
    """
    arr = np.genfromtxt(path, delimiter=',', names=True)
    names = arr.dtype.names
    if names is None or len(names) < 2:
        raise ValueError(
            f"{path}: expected a time column followed by at least one channel column"
        )
    # With a single data row genfromtxt returns a 0-d record; promote it so the
    # (T,) / (T, C) shapes hold for every file length.
    arr = np.atleast_1d(arr)
    time = arr[names[0]]
    channels = np.column_stack([arr[name] for name in names[1:]])
    return time, channels

def load_spikes_csv(path: str, duration: float, fs: float, channels: int) -> Tuple[np.ndarray, np.ndarray]:
    """Bin spike times from a CSV file into per-channel firing rates.

    The CSV has a header row. Its first column is the unit/channel index and its
    second column is the spike time, in the same time units as ``duration`` and
    ``1/fs`` (seconds when ``fs`` is in Hz). Spikes whose unit index is outside
    ``range(channels)`` are ignored.

    Args:
        path: path to the spike CSV.
        duration: length of the binned window, starting at 0.
        fs: bin rate; each bin is ``1/fs`` wide.
        channels: number of output channels.

    Returns:
        ``(t, rates)``. ``t`` has shape (N,) and holds the bin start times
        ``i/fs``, with ``N = round(duration*fs)``. ``rates`` has shape
        (N, channels) and holds spike counts per bin times ``fs`` (spikes/s).
    """
    # atleast_1d: a single spike row would otherwise come back as a 0-d record.
    raw = np.atleast_1d(np.genfromtxt(path, delimiter=',', names=True))
    names = raw.dtype.names
    unit = raw[names[0]].astype(int)
    ts = raw[names[1]].astype(float)
    # Count samples with integers: np.arange(0, duration, 1/fs) can gain a
    # spurious extra sample from floating-point round-off.
    n_samples = max(int(round(duration * fs)), 0)
    t = np.arange(n_samples) / fs
    rates = np.zeros((n_samples, channels), dtype=float)
    if n_samples == 0:
        # Nothing to bin (and no last bin edge to build).
        return t, rates
    bin_edges = np.arange(n_samples + 1) / fs
    for ch in range(channels):
        counts, _ = np.histogram(ts[unit == ch], bins=bin_edges)
        rates[:, ch] = counts * fs  # spikes/s
    return t, rates

def preprocess_signals(x: np.ndarray, fs: float, low: Optional[float], high: Optional[float]) -> np.ndarray:
    """Detrend, band-limit and z-score every column of ``x`` along axis 0.

    Linear detrending is applied only when SciPy is available. Filtering is
    delegated to ``bandpass``. Each column is then centred on its mean and
    divided by its standard deviation plus 1e-9, which avoids division by zero
    for flat channels.
    """
    cleaned = detrend(x, axis=0, type='linear') if SCIPY_OK else x
    cleaned = bandpass(cleaned, fs, low, high)
    centre = np.mean(cleaned, axis=0)
    spread = np.std(cleaned, axis=0) + 1e-9
    return (cleaned - centre) / spread

def estimate_phase(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Instantaneous amplitude and phase via the analytic signal along axis 0.

    Uses ``scipy.signal.hilbert`` when SciPy is available. Otherwise it uses an
    equivalent NumPy FFT construction, so both paths give the same result.

    Args:
        x: real-valued samples with time along axis 0 (e.g. shape (T, C)).

    Returns:
        ``(amp, phase)``, each with the same shape as ``x``. ``amp`` is the
        envelope ``|analytic|`` and ``phase`` is ``angle(analytic)`` in radians.

    Raises:
        ValueError: if ``x`` has no samples along axis 0.
    """
    if SCIPY_OK:
        analytic = hilbert(x, axis=0)
    else:
        n = x.shape[0]
        if n == 0:
            raise ValueError("x must contain at least one sample along axis 0")
        # Same spectral weights as scipy.signal.hilbert: keep DC (and Nyquist
        # for even n) once, double the positive frequencies, zero the negative
        # ones. The inverse must be the complex ifft; irfft would return a real
        # signal whose angle is only ever 0 or pi.
        weights = np.zeros(n)
        weights[0] = 1.0
        if n % 2 == 0:
            weights[n // 2] = 1.0
            weights[1:n // 2] = 2.0
        else:
            weights[1:(n + 1) // 2] = 2.0
        # Shape (n, 1, ..., 1) so the weights scale axis 0 only.
        weights = weights.reshape((n,) + (1,) * (x.ndim - 1))
        analytic = np.fft.ifft(np.fft.fft(x, axis=0) * weights, axis=0)
    amp = np.abs(analytic)
    phase = np.angle(analytic)
    return amp, phase

def channels_to_grid(phase: np.ndarray, grid_side: int) -> np.ndarray:
    """Lay per-channel values out on a square grid, one frame per time step.

    Channels fill the ``grid_side x grid_side`` grid in row-major order. Cells
    beyond the available channels are zero, and channels beyond
    ``grid_side**2`` are dropped.

    Args:
        phase: 2-D array indexed ``(time, channel)``.
        grid_side: side length S of the square grid.

    Returns:
        float64 array of shape (T, S, S).
    """
    n_t, n_ch = phase.shape
    cells = grid_side * grid_side
    n_used = min(n_ch, cells)
    # One vectorized copy for all time steps instead of a per-row Python loop.
    flat = np.zeros((n_t, cells), dtype=float)
    flat[:, :n_used] = phase[:, :n_used]
    return flat.reshape(n_t, grid_side, grid_side)

def pseudo_hologram_intensity(phase_grid: np.ndarray) -> np.ndarray:
    """Far-field-style intensity of a unit-amplitude phase field, per frame.

    For each 2-D frame ``phi``, computes ``|fft2(exp(1j * phi))|**2``. No
    fftshift is applied, so the zero-frequency term sits at index ``[0, 0]``.

    Args:
        phase_grid: phases in radians, shape (T, S, S). Any array with at least
            two dimensions works, because the transform runs over the last two axes.

    Returns:
        Real (floating-point) intensities with the same shape as ``phase_grid``.
        The result is always floating point, even for integer input; the
        previous ``zeros_like`` buffer truncated it to integers.
    """
    field = np.exp(1j * phase_grid)
    # np.fft.fft2 transforms the last two axes, so all frames are done at once.
    spectrum = np.fft.fft2(field)
    return np.abs(spectrum) ** 2

def preview_plots(time: np.ndarray, x: np.ndarray, phase_grid: np.ndarray, holo: np.ndarray, idx: int = -1):
    """Write three diagnostic PNG figures into ``figs/``.

    - ``figs/lfp_preview.png``: up to the first four columns of ``x`` plotted
      against ``time``, column ``i`` shifted up by ``5.0 * i``.
    - ``figs/phase_preview.png``: frame ``phase_grid[idx]`` as an image.
    - ``figs/hologram_intensity.png``: frame ``holo[idx]`` as an image.

    Every figure is closed after it is saved.
    """
    import matplotlib.pyplot as plt

    def _save_frame(frame, title, path):
        # Image of one 2-D frame with a colorbar, origin at the lower left.
        plt.figure()
        plt.colorbar(plt.imshow(frame, origin='lower'))
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path)
        plt.close()

    n_traces = min(4, x.shape[1])
    plt.figure()
    for ch in range(n_traces):
        plt.plot(time, x[:, ch] + ch * 5.0)
    plt.xlabel("Time (s)")
    plt.ylabel("LFP (z-sc, offset)")
    plt.title("LFP Preview (first 4 channels)")
    plt.tight_layout()
    plt.savefig("figs/lfp_preview.png")
    plt.close()

    _save_frame(phase_grid[idx], "Phase Grid Snapshot", "figs/phase_preview.png")
    _save_frame(holo[idx], "Pseudo-Hologram Intensity Snapshot", "figs/hologram_intensity.png")

def save_arrays(prefix: str, **arrays):
    """Save each keyword argument with ``np.save`` as ``out/<prefix>_<name>.npy``.

    Files are written in keyword order. The ``out`` directory must already
    exist; the module creates it at import time.
    """
    targets = ((f"out/{prefix}_{key}.npy", value) for key, value in arrays.items())
    for target, value in targets:
        np.save(target, value)

def demo(fs=1000.0, duration=5.0, channels=64, grid_side=8):
    """Run the whole pipeline on synthetic data and write its outputs.

    Builds ``channels`` copies of an 8 Hz sine, each with independent Gaussian
    noise (std 0.2, RNG seed 42) and a shared 0.05-amplitude 0.5 Hz drift. It
    then preprocesses with a 1-150 Hz band, estimates amplitude and phase, maps
    the phases onto a ``grid_side`` x ``grid_side`` grid and computes
    pseudo-hologram intensities. Finally it saves all arrays with prefix
    ``qhca_demo`` and writes preview figures.

    Returns:
        True once everything has been written.
    """
    n_samples = int(duration * fs)
    t = np.arange(n_samples) / fs
    rng = np.random.default_rng(42)

    carrier = np.sin(2 * np.pi * 8.0 * t)  # 8 Hz
    jitter = 0.2 * rng.standard_normal((t.size, channels))
    drift = 0.05 * np.sin(2 * np.pi * 0.5 * t)
    x = carrier[:, None] + jitter + drift[:, None]

    x_prep = preprocess_signals(x, fs, 1.0, 150.0)
    amp, phase = estimate_phase(x_prep)
    grid = channels_to_grid(phase, grid_side)
    holo = pseudo_hologram_intensity(grid)

    save_arrays(
        "qhca_demo",
        time=t,
        lfp=x,
        lfp_prep=x_prep,
        amp=amp,
        phase=phase,
        phase_grid=grid,
        hologram=holo,
    )
    preview_plots(t, x_prep, grid, holo, idx=-1)
    return True

if __name__ == "__main__":
    # Script entry point: run demo() with its default parameters.
    # Command-line parsing is deliberately left out (argparse is imported at the
    # top but unused) to keep the module notebook-friendly.
    demo()