From 3463adefc53570fdc0c7f12676cb8753c6f42541 Mon Sep 17 00:00:00 2001 From: andrianj Date: Thu, 7 May 2026 17:28:52 +0200 Subject: [PATCH 01/14] Update pyBer analysis and UI tools --- pyBer/analysis_core.py | 114 ++- pyBer/ethovision_process_gui.py | 3 +- pyBer/gui_postprocessing.py | 962 +++++++++++++++----- pyBer/gui_preprocessing.py | 168 +++- pyBer/main.py | 217 ++++- pyBer/numeric_controls.py | 207 +++++ pyBer/styles.py | 79 ++ pyBer/temporal_modeling.py | 1446 +++++++++++++++++++++++++++++++ 8 files changed, 2911 insertions(+), 285 deletions(-) create mode 100644 pyBer/numeric_controls.py create mode 100644 pyBer/temporal_modeling.py diff --git a/pyBer/analysis_core.py b/pyBer/analysis_core.py index 38c883d..7205ada 100644 --- a/pyBer/analysis_core.py +++ b/pyBer/analysis_core.py @@ -285,6 +285,7 @@ class ProcessedTrial: output: Optional[np.ndarray] = None output_label: str = "" output_context: str = "" + outputs: Dict[str, np.ndarray] = field(default_factory=dict) artifact_regions_sec: Optional[List[Tuple[float, float]]] = None artifact_regions_auto_sec: Optional[List[Tuple[float, float]]] = None @@ -302,8 +303,9 @@ class ExportSelection: dio: bool = True baseline_sig: bool = True baseline_ref: bool = True + output_modes: List[str] = field(default_factory=list) - def to_dict(self) -> Dict[str, bool]: + def to_dict(self) -> Dict[str, Any]: return { "raw": bool(self.raw), "isobestic": bool(self.isobestic), @@ -311,12 +313,16 @@ def to_dict(self) -> Dict[str, bool]: "dio": bool(self.dio), "baseline_sig": bool(self.baseline_sig), "baseline_ref": bool(self.baseline_ref), + "output_modes": list(self.output_modes or []), } @classmethod def from_dict(cls, data: Optional[Dict[str, object]]) -> "ExportSelection": if not isinstance(data, dict): return cls() + modes = data.get("output_modes", []) + if not isinstance(modes, list): + modes = [] return cls( raw=bool(data.get("raw", True)), isobestic=bool(data.get("isobestic", True)), @@ -324,6 +330,7 @@ def from_dict(cls, data: Optional[Dict[str, object]]) -> "ExportSelection": dio=bool(data.get("dio", True)), baseline_sig=bool(data.get("baseline_sig", True)), baseline_ref=bool(data.get("baseline_ref", True)), + output_modes=[str(m).strip() for m in modes if str(m or "").strip()], ) @@ -367,6 +374,67 @@ def output_label_type(label: str) -> str: return "output" +def output_label_key(label: str) -> str: + """Return a stable, readable key for one output mode in multi-output exports.""" + lab = (label or "").strip() + key = re.sub(r"[^A-Za-z0-9]+", "_", lab).strip("_").lower() + replacements = { + "dff_motion_corrected_with_fitted_ref": "dff_mc_fitted_ref", + "zscore_motion_corrected_with_fitted_ref": "zscore_mc_fitted_ref", + "prominence_normalized_motion_corrected_with_fitted_ref": "prominence_mc_fitted_ref", + "dff_motion_corrected_via_subtraction": "dff_mc_subtraction", + "zscore_motion_corrected_via_subtraction": "zscore_mc_subtraction", + "dff_non_motion_corrected": "dff_non_mc", + "zscore_non_motion_corrected": "zscore_non_mc", + "zscore_subtractions": "zscore_subtractions", + "raw_signal_465": "raw_signal_465", + } + return replacements.get(key, key or "output") + + +def _unique_export_name(name: str, used: set) -> str: + base = str(name or "output").strip() or "output" + out = base + i = 2 + while out in used: + out = f"{base}_{i}" + i += 1 + used.add(out) + return out + + +def _output_items_for_export( + processed: ProcessedTrial, + selection: ExportSelection, +) -> List[Tuple[str, np.ndarray]]: + """Return output traces in requested export 
order, falling back to the selected output.""" + source = getattr(processed, "outputs", None) or {} + items: List[Tuple[str, np.ndarray]] = [] + + if source: + requested = [str(m).strip() for m in (selection.output_modes or []) if str(m or "").strip()] + seen = set() + for label in requested: + if label in source and label not in seen: + items.append((label, np.asarray(source[label], float))) + seen.add(label) + for label, values in source.items(): + if label not in seen: + items.append((str(label), np.asarray(values, float))) + seen.add(label) + + if not items and processed.output is not None: + label = str(processed.output_label or "output") + items.append((label, np.asarray(processed.output, float))) + + return items + + +def _csv_output_values(label: str, values: np.ndarray, t: np.ndarray) -> np.ndarray: + values = values if values.size == t.size else np.full_like(t, np.nan) + return np.asarray(values, float) + + def export_processed_csv( path: str, processed: ProcessedTrial, @@ -389,10 +457,15 @@ def export_processed_csv( if selection.dio and processed.dio is not None and processed.dio.size == t.size: dio = np.asarray(processed.dio, float) - out_col = output_label_type(processed.output_label) + output_items = _output_items_for_export(processed, selection) if selection.output else [] with open(path, "w", newline="") as f: w = csv.writer(f) + w.writerow([f"# output_label: {processed.output_label}"]) + if processed.output_context: + w.writerow([f"# output_context: {processed.output_context}"]) + if output_items: + w.writerow([f"# output_modes: {json.dumps([label for label, _ in output_items])}"]) if metadata: for k, v in metadata.items(): w.writerow([f"# {k}: {v}"]) @@ -401,8 +474,20 @@ def export_processed_csv( columns.append(("raw", raw)) if selection.isobestic: columns.append(("isobestic", iso)) - if selection.output: - columns.append((out_col, out)) + if selection.output and output_items: + if len(output_items) == 1: + label, values = output_items[0] + values = _csv_output_values(label, values, t) + columns.append((output_label_type(label), values)) + else: + used_names = {name for name, _ in columns} + primary_label, primary_values = output_items[0] + primary_values = _csv_output_values(primary_label, primary_values, t) + columns.append((_unique_export_name("output", used_names), primary_values)) + for label, values in output_items: + values = _csv_output_values(label, values, t) + col = _unique_export_name(f"output__{output_label_key(label)}", used_names) + columns.append((col, values)) # Primary trigger name (e.g. 
"DIO01") primary_name = str(processed.dio_name) if processed.dio_name else "dio" @@ -435,16 +520,32 @@ def export_processed_h5( g = f.create_group("data") g.create_dataset("time", data=np.asarray(processed.time, float), compression="gzip") out_type = output_label_type(processed.output_label) + output_items = _output_items_for_export(processed, selection) if selection.output else [] g.attrs["output_label"] = str(processed.output_label) g.attrs["output_context"] = str(processed.output_context) g.attrs["output_type"] = str(out_type) + if output_items: + g.attrs["output_modes"] = json.dumps([label for label, _ in output_items]) g.attrs["fs_actual"] = float(processed.fs_actual) g.attrs["fs_used"] = float(processed.fs_used) g.attrs["fs_target"] = float(processed.fs_target) g.attrs["export_selection"] = json.dumps(selection.to_dict()) - if selection.output: - g.create_dataset("output", data=np.asarray(processed.output, float), compression="gzip") + if selection.output and output_items: + t = np.asarray(processed.time, float) + _primary_label, primary_output = output_items[0] + primary_output = primary_output if primary_output.size == t.size else np.full_like(t, np.nan) + g.create_dataset("output", data=np.asarray(primary_output, float), compression="gzip") + + if len(output_items) > 1: + out_group = g.create_group("outputs") + used = set() + for label, values in output_items: + values = values if values.size == t.size else np.full_like(t, np.nan) + ds_name = _unique_export_name(output_label_key(label), used) + ds = out_group.create_dataset(ds_name, data=np.asarray(values, float), compression="gzip") + ds.attrs["label"] = str(label) + ds.attrs["output_type"] = str(output_label_type(label)) raw_sig = np.asarray(processed.raw_signal if processed.raw_signal is not None else np.full_like(processed.time, np.nan), float) raw_ref = np.asarray(processed.raw_reference if processed.raw_reference is not None else np.full_like(processed.time, np.nan), float) @@ -1567,6 +1668,7 @@ def process_trial( output=out, output_label=mode, output_context=output_context, + outputs={mode: np.asarray(out, float)} if out is not None else {}, artifact_regions_sec=regions_from_mask(t, mask), artifact_regions_auto_sec=auto_regions, fs_actual=float(fs), diff --git a/pyBer/ethovision_process_gui.py b/pyBer/ethovision_process_gui.py index af66169..8a139db 100644 --- a/pyBer/ethovision_process_gui.py +++ b/pyBer/ethovision_process_gui.py @@ -32,6 +32,8 @@ import pyqtgraph as pg +from analysis_core import coerce_time_value + MISSING_MARKERS = {"", " ", "-", "NaN", "nan", "NAN", "n/a", "N/A", None} @@ -208,7 +210,6 @@ def clean_sheet( time_orig = df[time_col].copy() df[time_col] = pd.to_numeric(df[time_col], errors="coerce") if df[time_col].isna().all(): - from pyBer.analysis_core import coerce_time_value df[time_col] = time_orig.astype(str).apply(coerce_time_value) df = df.loc[df[time_col].notna()].copy() df = df.sort_values(time_col).reset_index(drop=True) diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py index 15f78f5..ed14494 100644 --- a/pyBer/gui_postprocessing.py +++ b/pyBer/gui_postprocessing.py @@ -16,8 +16,9 @@ from pyqtgraph.dockarea import DockArea, Dock import h5py -from analysis_core import ProcessedTrial +from analysis_core import ProcessedTrial, coerce_time_value from ethovision_process_gui import clean_sheet +from temporal_modeling import TemporalModelingWidget _DOCK_STATE_VERSION = 3 _POST_DOCK_STATE_KEY = "post_main_dock_state_v4" @@ -28,14 +29,15 @@ _PRE_DOCK_PREFIX = "pre." 
_BEHAVIOR_PARSE_BINARY = "binary_columns" _BEHAVIOR_PARSE_TIMESTAMPS = "timestamp_columns" -_FIXED_POST_RIGHT_SECTIONS = frozenset({"setup", "spatial", "psth", "export"}) -_FIXED_POST_VISIBLE_SECTIONS = frozenset({"setup", "spatial", "psth", "export"}) -_FIXED_POST_RIGHT_TAB_ORDER = ("setup", "psth", "spatial", "export") +_FIXED_POST_RIGHT_SECTIONS = frozenset({"setup", "spatial", "psth", "export", "temporal"}) +_FIXED_POST_VISIBLE_SECTIONS = frozenset({"setup", "spatial", "psth", "export", "temporal"}) +_FIXED_POST_RIGHT_TAB_ORDER = ("setup", "psth", "spatial", "temporal", "export") _POST_RIGHT_PANEL_MIN_WIDTH = 420 _FIXED_POST_RIGHT_TAB_TITLES: Dict[str, str] = { "setup": "Setup", "psth": "PSTH", "spatial": "Spatial", + "temporal": "Temporal", "export": "Export", } _USE_PG_DOCKAREA_POST_LAYOUT = True @@ -248,7 +250,6 @@ def _detect_time_column(df, fallback_to_first: bool = False) -> Optional[str]: def _numeric_column_array(df, col_name: str) -> np.ndarray: import pandas as pd - from pyBer.analysis_core import coerce_time_value col_key = None for c in df.columns: @@ -508,6 +509,7 @@ def __init__(self, parent=None) -> None: self._event_labels: List[pg.TextItem] = [] self._event_regions: List[pg.LinearRegionItem] = [] self._signal_peak_lines: List[pg.InfiniteLine] = [] + self._signal_noise_items: List[object] = [] self._pre_region: Optional[pg.LinearRegionItem] = None self._post_region: Optional[pg.LinearRegionItem] = None self._settings = QtCore.QSettings("FiberPhotometryApp", "DoricProcessor") @@ -624,6 +626,7 @@ def _build_ui(self) -> None: vsrc.addWidget(self.btn_refresh_dio) grp_align = QtWidgets.QGroupBox("Behavior / Events") + grp_align.setSizePolicy(QtWidgets.QSizePolicy.Policy.Preferred, QtWidgets.QSizePolicy.Policy.Expanding) fal = QtWidgets.QFormLayout(grp_align) fal.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows) fal.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop) @@ -685,12 +688,20 @@ def _build_ui(self) -> None: # Preprocessed files list self.list_preprocessed = FileDropList() - self.list_preprocessed.setMaximumHeight(120) + self.list_preprocessed.setMinimumHeight(180) + self.list_preprocessed.setSizePolicy( + QtWidgets.QSizePolicy.Policy.Expanding, + QtWidgets.QSizePolicy.Policy.Expanding, + ) self.list_preprocessed.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.MultiSelection) # Behaviors list self.list_behaviors = FileDropList() - self.list_behaviors.setMaximumHeight(120) + self.list_behaviors.setMinimumHeight(180) + self.list_behaviors.setSizePolicy( + QtWidgets.QSizePolicy.Policy.Expanding, + QtWidgets.QSizePolicy.Policy.Expanding, + ) self.list_behaviors.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.MultiSelection) # Control buttons for ordering @@ -724,6 +735,8 @@ def _build_ui(self) -> None: beh_col = QtWidgets.QVBoxLayout() beh_col.addWidget(self.list_behaviors) beh_col.addWidget(self.btn_remove_beh) + pre_col.setStretch(0, 1) + beh_col.setStretch(0, 1) lists_layout.addLayout(pre_col) lists_layout.addLayout(beh_col) @@ -764,16 +777,20 @@ def _build_ui(self) -> None: fal.addRow(self.lbl_trans_gap, self.spin_transition_gap) - grp_opt = QtWidgets.QGroupBox("PSTH Options") - fopt = QtWidgets.QFormLayout(grp_opt) - fopt.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows) - fopt.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop) + # ── Shared QSS for PSTH subsection headers ── + _psth_section_qss = ( + "QGroupBox { font-weight: 700; 
font-size: 11px; " + "border: 1px solid rgba(255,255,255,0.07); border-radius: 6px; " + "margin-top: 10px; padding: 14px 8px 8px 8px; }" + "QGroupBox::title { subcontrol-origin: margin; left: 10px; " + "padding: 0 6px; color: #8899b0; }" + ) + # ── Widget creation (unchanged logic, reordered for sections) ── self.spin_pre = QtWidgets.QDoubleSpinBox(); self.spin_pre.setRange(0.1, 60); self.spin_pre.setValue(2.0); self.spin_pre.setDecimals(2) - self.spin_post= QtWidgets.QDoubleSpinBox(); self.spin_post.setRange(0.1, 120); self.spin_post.setValue(5.0); self.spin_post.setDecimals(2) - self.spin_b0 = QtWidgets.QDoubleSpinBox(); self.spin_b0.setRange(-60, 0); self.spin_b0.setValue(-1.0); self.spin_b0.setDecimals(2) - self.spin_b1 = QtWidgets.QDoubleSpinBox(); self.spin_b1.setRange(-60, 0); self.spin_b1.setValue(0.0); self.spin_b1.setDecimals(2) - + self.spin_post = QtWidgets.QDoubleSpinBox(); self.spin_post.setRange(0.1, 120); self.spin_post.setValue(5.0); self.spin_post.setDecimals(2) + self.spin_b0 = QtWidgets.QDoubleSpinBox(); self.spin_b0.setRange(-60, 0); self.spin_b0.setValue(-1.0); self.spin_b0.setDecimals(2) + self.spin_b1 = QtWidgets.QDoubleSpinBox(); self.spin_b1.setRange(-60, 0); self.spin_b1.setValue(0.0); self.spin_b1.setDecimals(2) self.spin_resample = QtWidgets.QDoubleSpinBox(); self.spin_resample.setRange(1, 1000); self.spin_resample.setValue(50); self.spin_resample.setDecimals(1) self.spin_smooth = QtWidgets.QDoubleSpinBox(); self.spin_smooth.setRange(0, 5); self.spin_smooth.setValue(0.0); self.spin_smooth.setDecimals(2) @@ -787,6 +804,7 @@ def _build_ui(self) -> None: self.spin_group_window = QtWidgets.QDoubleSpinBox(); self.spin_group_window.setRange(0.0, 1e6); self.spin_group_window.setValue(0.0); self.spin_group_window.setDecimals(3) self.spin_dur_min = QtWidgets.QDoubleSpinBox(); self.spin_dur_min.setRange(0, 1e6); self.spin_dur_min.setValue(0.0); self.spin_dur_min.setDecimals(2) self.spin_dur_max = QtWidgets.QDoubleSpinBox(); self.spin_dur_max.setRange(0, 1e6); self.spin_dur_max.setValue(0.0); self.spin_dur_max.setDecimals(2) + self.cb_metrics = QtWidgets.QCheckBox("Enable PSTH metrics") self.cb_metrics.setChecked(True) self.btn_hide_metrics = QtWidgets.QToolButton() @@ -800,163 +818,133 @@ def _build_ui(self) -> None: self.spin_metric_post0 = QtWidgets.QDoubleSpinBox(); self.spin_metric_post0.setRange(0, 120); self.spin_metric_post0.setValue(0.0); self.spin_metric_post0.setDecimals(2) self.spin_metric_post1 = QtWidgets.QDoubleSpinBox(); self.spin_metric_post1.setRange(0, 120); self.spin_metric_post1.setValue(1.0); self.spin_metric_post1.setDecimals(2) + self.cb_global_metrics = QtWidgets.QCheckBox("Enable global metrics") + self.cb_global_metrics.setChecked(True) + self.spin_global_start = QtWidgets.QDoubleSpinBox(); self.spin_global_start.setRange(-1e6, 1e6); self.spin_global_start.setValue(0.0); self.spin_global_start.setDecimals(2) + self.spin_global_end = QtWidgets.QDoubleSpinBox(); self.spin_global_end.setRange(-1e6, 1e6); self.spin_global_end.setValue(0.0); self.spin_global_end.setDecimals(2) + self.cb_global_amp = QtWidgets.QCheckBox("Peak amplitude") + self.cb_global_amp.setChecked(True) + self.cb_global_freq = QtWidgets.QCheckBox("Transient frequency") + self.cb_global_freq.setChecked(True) + self.lbl_global_metrics = QtWidgets.QLabel("Global metrics: -") + self.lbl_global_metrics.setProperty("class", "hint") + for w in ( self.spin_pre, self.spin_post, self.spin_b0, self.spin_b1, self.spin_resample, self.spin_smooth, - self.spin_event_start, self.spin_event_end, 
self.spin_group_window, self.spin_dur_min, self.spin_dur_max, - self.spin_metric_pre0, self.spin_metric_pre1, self.spin_metric_post0, self.spin_metric_post1, + self.spin_event_start, self.spin_event_end, self.spin_group_window, + self.spin_dur_min, self.spin_dur_max, + self.spin_metric_pre0, self.spin_metric_pre1, + self.spin_metric_post0, self.spin_metric_post1, + self.spin_global_start, self.spin_global_end, ): w.setMinimumWidth(60) w.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) - win_row = QtWidgets.QGridLayout() - win_row.setHorizontalSpacing(6) - win_row.setContentsMargins(0, 0, 0, 0) - win_pre = QtWidgets.QLabel("Pre:") - win_post = QtWidgets.QLabel("Post:") - win_pre.setMinimumWidth(35) - win_post.setMinimumWidth(35) - win_row.addWidget(win_pre, 0, 0) - win_row.addWidget(self.spin_pre, 0, 1) - win_row.addWidget(win_post, 0, 2) - win_row.addWidget(self.spin_post, 0, 3) - win_row.setColumnStretch(1, 1) - win_row.setColumnStretch(3, 1) - win_widget = QtWidgets.QWidget(); win_widget.setLayout(win_row) - - base_row = QtWidgets.QGridLayout() - base_row.setHorizontalSpacing(6) - base_row.setContentsMargins(0, 0, 0, 0) - base_start = QtWidgets.QLabel("Start:") - base_end = QtWidgets.QLabel("End:") - base_start.setMinimumWidth(45) - base_end.setMinimumWidth(35) - base_row.addWidget(base_start, 0, 0) - base_row.addWidget(self.spin_b0, 0, 1) - base_row.addWidget(base_end, 0, 2) - base_row.addWidget(self.spin_b1, 0, 3) - base_row.setColumnStretch(1, 1) - base_row.setColumnStretch(3, 1) - base_widget = QtWidgets.QWidget(); base_widget.setLayout(base_row) - - metric_pre_row = QtWidgets.QGridLayout() - metric_pre_row.setHorizontalSpacing(6) - metric_pre_row.setContentsMargins(0, 0, 0, 0) - metric_pre_start = QtWidgets.QLabel("Start:") - metric_pre_end = QtWidgets.QLabel("End:") - metric_pre_start.setMinimumWidth(45) - metric_pre_end.setMinimumWidth(35) - metric_pre_row.addWidget(metric_pre_start, 0, 0) - metric_pre_row.addWidget(self.spin_metric_pre0, 0, 1) - metric_pre_row.addWidget(metric_pre_end, 0, 2) - metric_pre_row.addWidget(self.spin_metric_pre1, 0, 3) - metric_pre_row.setColumnStretch(1, 1) - metric_pre_row.setColumnStretch(3, 1) - metric_pre_widget = QtWidgets.QWidget(); metric_pre_widget.setLayout(metric_pre_row) - - metric_post_row = QtWidgets.QGridLayout() - metric_post_row.setHorizontalSpacing(6) - metric_post_row.setContentsMargins(0, 0, 0, 0) - metric_post_start = QtWidgets.QLabel("Start:") - metric_post_end = QtWidgets.QLabel("End:") - metric_post_start.setMinimumWidth(45) - metric_post_end.setMinimumWidth(35) - metric_post_row.addWidget(metric_post_start, 0, 0) - metric_post_row.addWidget(self.spin_metric_post0, 0, 1) - metric_post_row.addWidget(metric_post_end, 0, 2) - metric_post_row.addWidget(self.spin_metric_post1, 0, 3) - metric_post_row.setColumnStretch(1, 1) - metric_post_row.setColumnStretch(3, 1) - metric_post_widget = QtWidgets.QWidget(); metric_post_widget.setLayout(metric_post_row) - - fopt.addRow("Window (s)", win_widget) - fopt.addRow("Baseline (s)", base_widget) - fopt.addRow("Resample (Hz)", self.spin_resample) + # ── Helper: dual-spin row ── + def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): + g = QtWidgets.QGridLayout() + g.setHorizontalSpacing(6); g.setContentsMargins(0, 0, 0, 0) + la = QtWidgets.QLabel(lbl_a); la.setMinimumWidth(35) + lb = QtWidgets.QLabel(lbl_b); lb.setMinimumWidth(35) + g.addWidget(la, 0, 0); g.addWidget(w_a, 0, 1) + g.addWidget(lb, 0, 2); g.addWidget(w_b, 0, 3) + g.setColumnStretch(1, 1); 
g.setColumnStretch(3, 1) + w = QtWidgets.QWidget(); w.setLayout(g); return w + + win_widget = _dual_row("Pre:", self.spin_pre, "Post:", self.spin_post) + base_widget = _dual_row("Start:", self.spin_b0, "End:", self.spin_b1) + metric_pre_widget = _dual_row("Start:", self.spin_metric_pre0, "End:", self.spin_metric_pre1) + metric_post_widget = _dual_row("Start:", self.spin_metric_post0, "End:", self.spin_metric_post1) + global_widget = _dual_row("Start:", self.spin_global_start, "End:", self.spin_global_end) + + # ═══════════════════════════════════════════════════════ + # Section 1 — Window & Baseline + # ═══════════════════════════════════════════════════════ + grp_window = QtWidgets.QGroupBox("Window && baseline") + grp_window.setStyleSheet(_psth_section_qss) + fw = QtWidgets.QFormLayout(grp_window) + fw.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows) + fw.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop) + fw.addRow("Window (s)", win_widget) + fw.addRow("Baseline (s)", base_widget) + fw.addRow("Resample (Hz)", self.spin_resample) + fw.addRow("Smooth sigma (s)", self.spin_smooth) + + # ═══════════════════════════════════════════════════════ + # Section 2 — Event filters + # ═══════════════════════════════════════════════════════ + grp_filt = QtWidgets.QGroupBox("Event filters") + grp_filt.setStyleSheet(_psth_section_qss) + ff = QtWidgets.QFormLayout(grp_filt) + ff.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows) + ff.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop) filt_row = QtWidgets.QHBoxLayout() - filt_row.setContentsMargins(0, 0, 0, 0) - filt_row.setSpacing(6) - filt_row.addWidget(self.cb_filter_events) - filt_row.addStretch(1) + filt_row.setContentsMargins(0, 0, 0, 0); filt_row.setSpacing(6) + filt_row.addWidget(self.cb_filter_events); filt_row.addStretch(1) filt_row.addWidget(self.btn_hide_filters) filt_widget = QtWidgets.QWidget(); filt_widget.setLayout(filt_row) - fopt.addRow(filt_widget) - self.lbl_event_start = QtWidgets.QLabel("Event index start (1-based)") - self.lbl_event_end = QtWidgets.QLabel("Event index end (0=all)") - self.lbl_group_window = QtWidgets.QLabel("Group events within (s) (0=off)") - self.lbl_dur_min = QtWidgets.QLabel("Event duration min (s)") - self.lbl_dur_max = QtWidgets.QLabel("Event duration max (s)") - fopt.addRow(self.lbl_event_start, self.spin_event_start) - fopt.addRow(self.lbl_event_end, self.spin_event_end) - fopt.addRow(self.lbl_group_window, self.spin_group_window) - fopt.addRow(self.lbl_dur_min, self.spin_dur_min) - fopt.addRow(self.lbl_dur_max, self.spin_dur_max) - fopt.addRow("Gaussian smooth sigma (s)", self.spin_smooth) + ff.addRow(filt_widget) + self.lbl_event_start = QtWidgets.QLabel("Start index (1-based)") + self.lbl_event_end = QtWidgets.QLabel("End index (0 = all)") + self.lbl_group_window = QtWidgets.QLabel("Group within (s)") + self.lbl_dur_min = QtWidgets.QLabel("Duration min (s)") + self.lbl_dur_max = QtWidgets.QLabel("Duration max (s)") + ff.addRow(self.lbl_event_start, self.spin_event_start) + ff.addRow(self.lbl_event_end, self.spin_event_end) + ff.addRow(self.lbl_group_window, self.spin_group_window) + ff.addRow(self.lbl_dur_min, self.spin_dur_min) + ff.addRow(self.lbl_dur_max, self.spin_dur_max) + + # ═══════════════════════════════════════════════════════ + # Section 3 — PSTH metrics + # ═══════════════════════════════════════════════════════ + grp_met = QtWidgets.QGroupBox("PSTH metrics") + 
grp_met.setStyleSheet(_psth_section_qss) + fm = QtWidgets.QFormLayout(grp_met) + fm.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows) + fm.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop) met_row = QtWidgets.QHBoxLayout() - met_row.setContentsMargins(0, 0, 0, 0) - met_row.setSpacing(6) - met_row.addWidget(self.cb_metrics) - met_row.addStretch(1) + met_row.setContentsMargins(0, 0, 0, 0); met_row.setSpacing(6) + met_row.addWidget(self.cb_metrics); met_row.addStretch(1) met_row.addWidget(self.btn_hide_metrics) met_widget = QtWidgets.QWidget(); met_widget.setLayout(met_row) - fopt.addRow(met_widget) + fm.addRow(met_widget) self.lbl_metric = QtWidgets.QLabel("Metric") - self.lbl_metric_pre = QtWidgets.QLabel("Metric pre (s)") - self.lbl_metric_post = QtWidgets.QLabel("Metric post (s)") - fopt.addRow(self.lbl_metric, self.combo_metric) - fopt.addRow(self.lbl_metric_pre, metric_pre_widget) - fopt.addRow(self.lbl_metric_post, metric_post_widget) - - self.cb_global_metrics = QtWidgets.QCheckBox("Enable global metrics") - self.cb_global_metrics.setChecked(True) - fopt.addRow(self.cb_global_metrics) - - self.spin_global_start = QtWidgets.QDoubleSpinBox() - self.spin_global_start.setRange(-1e6, 1e6) - self.spin_global_start.setValue(0.0) - self.spin_global_start.setDecimals(2) - self.spin_global_end = QtWidgets.QDoubleSpinBox() - self.spin_global_end.setRange(-1e6, 1e6) - self.spin_global_end.setValue(0.0) - self.spin_global_end.setDecimals(2) - self.spin_global_start.setMinimumWidth(60) - self.spin_global_end.setMinimumWidth(60) - self.spin_global_start.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) - self.spin_global_end.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) - - global_row = QtWidgets.QGridLayout() - global_row.setHorizontalSpacing(6) - global_row.setContentsMargins(0, 0, 0, 0) - global_row.addWidget(QtWidgets.QLabel("Start:"), 0, 0) - global_row.addWidget(self.spin_global_start, 0, 1) - global_row.addWidget(QtWidgets.QLabel("End:"), 0, 2) - global_row.addWidget(self.spin_global_end, 0, 3) - global_row.setColumnStretch(1, 1) - global_row.setColumnStretch(3, 1) - global_widget = QtWidgets.QWidget() - global_widget.setLayout(global_row) - fopt.addRow("Global range (s)", global_widget) - - self.cb_global_amp = QtWidgets.QCheckBox("Peak amplitude") - self.cb_global_amp.setChecked(True) - self.cb_global_freq = QtWidgets.QCheckBox("Transient frequency") - self.cb_global_freq.setChecked(True) + self.lbl_metric_pre = QtWidgets.QLabel("Pre window (s)") + self.lbl_metric_post = QtWidgets.QLabel("Post window (s)") + fm.addRow(self.lbl_metric, self.combo_metric) + fm.addRow(self.lbl_metric_pre, metric_pre_widget) + fm.addRow(self.lbl_metric_post, metric_post_widget) + + # ═══════════════════════════════════════════════════════ + # Section 4 — Global metrics + # ═══════════════════════════════════════════════════════ + grp_global = QtWidgets.QGroupBox("Global metrics") + grp_global.setStyleSheet(_psth_section_qss) + fg = QtWidgets.QFormLayout(grp_global) + fg.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows) + fg.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop) + fg.addRow(self.cb_global_metrics) + fg.addRow("Range (s)", global_widget) global_opts = QtWidgets.QHBoxLayout() - global_opts.setContentsMargins(0, 0, 0, 0) - global_opts.setSpacing(6) + global_opts.setContentsMargins(0, 0, 0, 0); 
global_opts.setSpacing(6) global_opts.addWidget(self.cb_global_amp) global_opts.addWidget(self.cb_global_freq) global_opts.addStretch(1) - global_opts_widget = QtWidgets.QWidget() - global_opts_widget.setLayout(global_opts) - fopt.addRow("Global metrics", global_opts_widget) - - self.lbl_global_metrics = QtWidgets.QLabel("Global metrics: -") - self.lbl_global_metrics.setProperty("class", "hint") - fopt.addRow("", self.lbl_global_metrics) - - for w in (self.spin_global_start, self.spin_global_end): - w.setMinimumWidth(60) - w.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) + global_opts_widget = QtWidgets.QWidget(); global_opts_widget.setLayout(global_opts) + fg.addRow("Compute", global_opts_widget) + fg.addRow("", self.lbl_global_metrics) + + # Container for all subsections (replaces old grp_opt) + grp_opt = QtWidgets.QWidget() + _psth_vbox = QtWidgets.QVBoxLayout(grp_opt) + _psth_vbox.setContentsMargins(0, 0, 0, 0) + _psth_vbox.setSpacing(4) + _psth_vbox.addWidget(grp_window) + _psth_vbox.addWidget(grp_filt) + _psth_vbox.addWidget(grp_met) + _psth_vbox.addWidget(grp_global) self.btn_compute = QtWidgets.QPushButton("Postprocessing (compute PSTH)") self.btn_compute.setProperty("class", "compactPrimarySmall") @@ -1007,6 +995,16 @@ def _build_ui(self) -> None: self.spin_peak_prominence.setRange(0.0, 1e6) self.spin_peak_prominence.setValue(0.5) self.spin_peak_prominence.setDecimals(4) + self.cb_peak_auto_mad = QtWidgets.QCheckBox("Auto transient threshold (MAD noise)") + self.cb_peak_auto_mad.setChecked(False) + self.cb_peak_auto_mad.setToolTip( + "Estimate trace noise as 1.4826 x MAD after baseline/smoothing and use " + "noise sigma times the multiplier below as the minimum peak prominence." + ) + self.spin_peak_mad_multiplier = QtWidgets.QDoubleSpinBox() + self.spin_peak_mad_multiplier.setRange(0.5, 50.0) + self.spin_peak_mad_multiplier.setValue(5.0) + self.spin_peak_mad_multiplier.setDecimals(2) self.spin_peak_height = QtWidgets.QDoubleSpinBox() self.spin_peak_height.setRange(0.0, 1e6) self.spin_peak_height.setValue(0.0) @@ -1033,8 +1031,19 @@ def _build_ui(self) -> None: self.spin_peak_auc_window.setRange(0.0, 30.0) self.spin_peak_auc_window.setValue(0.5) self.spin_peak_auc_window.setDecimals(3) + self.cb_peak_norm_prominence = QtWidgets.QCheckBox("Baseline-prominence normalized amplitude") + self.cb_peak_norm_prominence.setChecked(False) + self.cb_peak_norm_prominence.setToolTip( + "Report peak amplitude scaled by the mean of the top baseline peak prominences." + ) self.cb_peak_overlay = QtWidgets.QCheckBox("Show detected peaks on trace") self.cb_peak_overlay.setChecked(True) + self.cb_peak_noise_overlay = QtWidgets.QCheckBox("Show noise trace / MAD threshold overlay") + self.cb_peak_noise_overlay.setChecked(False) + self.cb_peak_noise_overlay.setToolTip( + "After peak detection, overlay the preprocessed detection trace, robust noise band, " + "and effective prominence threshold used for the visible file."
+ ) self.btn_detect_peaks = QtWidgets.QPushButton("Detect peaks") self.btn_detect_peaks.setProperty("class", "compactPrimarySmall") self.btn_export_peaks = QtWidgets.QPushButton("Export peaks CSV") @@ -1046,14 +1055,18 @@ def _build_ui(self) -> None: f_signal.addRow("File", self.combo_signal_file) f_signal.addRow("Method", self.combo_signal_method) f_signal.addRow("Min prominence", self.spin_peak_prominence) + f_signal.addRow(self.cb_peak_auto_mad) + f_signal.addRow("MAD multiplier", self.spin_peak_mad_multiplier) f_signal.addRow("Min height (0=off)", self.spin_peak_height) f_signal.addRow("Min distance (s)", self.spin_peak_distance) f_signal.addRow("Smooth sigma (s)", self.spin_peak_smooth) f_signal.addRow("Baseline handling", self.combo_peak_baseline) f_signal.addRow("Baseline window (s)", self.spin_peak_baseline_window) + f_signal.addRow(self.cb_peak_norm_prominence) f_signal.addRow("Rate bin (s)", self.spin_peak_rate_bin) f_signal.addRow("AUC window (+/- s)", self.spin_peak_auc_window) f_signal.addRow(self.cb_peak_overlay) + f_signal.addRow(self.cb_peak_noise_overlay) signal_btn_row = QtWidgets.QHBoxLayout() signal_btn_row.addWidget(self.btn_detect_peaks) signal_btn_row.addWidget(self.btn_export_peaks) @@ -1258,7 +1271,7 @@ def _build_ui(self) -> None: setup_layout.setContentsMargins(6, 6, 6, 6) setup_layout.setSpacing(8) setup_layout.addWidget(grp_src) - setup_layout.addWidget(grp_align) + setup_layout.addWidget(grp_align, stretch=1) setup_btn_row = QtWidgets.QHBoxLayout() self.btn_setup_load = QtWidgets.QPushButton("Load") self.btn_setup_load.setProperty("class", "compactPrimarySmall") @@ -1303,6 +1316,11 @@ def _build_ui(self) -> None: behavior_layout.addWidget(self.lbl_behavior_summary) behavior_layout.addStretch(1) + self.section_temporal = TemporalModelingWidget() + self.section_temporal.statusMessage.connect( + lambda msg, ms: self.statusUpdate.emit(msg, ms) + ) + self.section_spatial = QtWidgets.QWidget() spatial_layout = QtWidgets.QVBoxLayout(self.section_spatial) spatial_layout.setContentsMargins(6, 6, 6, 6) @@ -1362,6 +1380,7 @@ def _build_ui(self) -> None: self.btn_panel_export = QtWidgets.QPushButton("Export") self.btn_panel_signal = QtWidgets.QPushButton("Signal") self.btn_panel_behavior = QtWidgets.QPushButton("Behavior") + self.btn_panel_temporal = QtWidgets.QPushButton("Temporal") self._section_buttons = { "setup": self.btn_panel_setup, "psth": self.btn_panel_psth, @@ -1369,6 +1388,7 @@ def _build_ui(self) -> None: "export": self.btn_panel_export, "signal": self.btn_panel_signal, "behavior": self.btn_panel_behavior, + "temporal": self.btn_panel_temporal, } for b in self._section_buttons.values(): b.setCheckable(True) @@ -1378,7 +1398,7 @@ def _build_ui(self) -> None: # and put workflow actions in a thin transport bar above the plots. 
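# Editor's sketch (illustrative, not part of the patch): the computation
# behind the "Auto transient threshold (MAD noise)" option wired above.
# The robust noise sigma is the Gaussian-consistent estimate 1.4826 * MAD,
# and the effective find_peaks prominence is sigma times the "MAD multiplier"
# spin value (UI default 5.0). All names below are local to this sketch.
def _mad_auto_prominence_sketch():
    import numpy as np
    from scipy.signal import find_peaks

    rng = np.random.default_rng(0)
    trace = rng.normal(0.0, 0.2, 5000)  # synthetic baseline noise
    trace[::500] += 3.0                 # sparse large transients

    sigma = 1.4826 * np.median(np.abs(trace - np.median(trace)))
    peaks, _props = find_peaks(trace, prominence=5.0 * sigma)
    return peaks, sigma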
from styles import ( _make_icon, _paint_sliders, _paint_chart, _paint_grid, - _paint_export, _paint_pulse, _paint_paw, + _paint_export, _paint_pulse, _paint_paw, _paint_temporal, ) _post_rail_meta = { "setup": ("Setup", _paint_sliders), @@ -1387,6 +1407,7 @@ def _build_ui(self) -> None: "export": ("Export panel", _paint_export), "signal": ("Signal events", _paint_pulse), "behavior": ("Behavior", _paint_paw), + "temporal": ("Temporal modeling (GLM / FLMM)", _paint_temporal), } for key, btn in self._section_buttons.items(): tip, painter = _post_rail_meta[key] @@ -1404,7 +1425,7 @@ def _build_ui(self) -> None: rail_layout = QtWidgets.QVBoxLayout(self._post_side_rail) rail_layout.setContentsMargins(8, 10, 8, 10) rail_layout.setSpacing(6) - for key in ("setup", "psth", "spatial", "signal", "behavior", "export"): + for key in ("setup", "psth", "spatial", "temporal", "signal", "behavior", "export"): rail_layout.addWidget(self._section_buttons[key], 0, QtCore.Qt.AlignmentFlag.AlignHCenter) rail_layout.addStretch(1) @@ -1520,6 +1541,7 @@ def _build_ui(self) -> None: self.curve_trace = self.plot_trace.plot(pen=pg.mkPen(self._style["trace"], width=1.1)) self.curve_behavior = self.plot_trace.plot(pen=pg.mkPen(self._style["behavior"], width=1.0)) + self.curve_behavior.setVisible(False) self.curve_peak_markers = self.plot_trace.plot( pen=None, symbol="o", @@ -1889,6 +1911,10 @@ def _build_ui(self) -> None: self.cb_peak_overlay.toggled.connect(self._refresh_signal_overlay) self.combo_signal_source.currentIndexChanged.connect(self._refresh_signal_file_combo) self.combo_signal_scope.currentIndexChanged.connect(self._refresh_signal_file_combo) + self.combo_signal_file.currentIndexChanged.connect(self._on_signal_file_changed) + self.cb_peak_auto_mad.toggled.connect(self._update_peak_auto_mad_enabled) + self.cb_peak_noise_overlay.toggled.connect(self._refresh_signal_overlay) + self.cb_peak_norm_prominence.toggled.connect(lambda _checked=False: self._save_settings()) self.tab_sources.currentChanged.connect(self._refresh_signal_file_combo) self.tab_visual_mode.currentChanged.connect(self._on_visual_mode_changed) self.combo_individual_file.currentIndexChanged.connect(self._on_individual_file_changed) @@ -2025,6 +2051,7 @@ def _section_widget_map(self) -> Dict[str, Tuple[str, QtWidgets.QWidget]]: "setup": ("Setup", self.section_setup), "psth": ("PSTH", self.section_psth), "spatial": ("Spatial", self.section_spatial), + "temporal": ("Temporal Modeling", self.section_temporal), "export": ("Export", self.section_export), "signal": ("Signal Event Analyzer", self.section_signal), "behavior": ("Behavior Analysis", self.section_behavior), @@ -2376,6 +2403,7 @@ def _setup_section_popups(self) -> None: "signal": ("Signal Event Analyzer", self.section_signal), "behavior": ("Behavior Analysis", self.section_behavior), "spatial": ("Spatial", self.section_spatial), + "temporal": ("Temporal Modeling", self.section_temporal), "export": ("Export", self.section_export), } for key, (title, widget) in section_map.items(): @@ -3415,6 +3443,50 @@ def _refresh_signal_file_combo(self) -> None: has_multi = self.combo_signal_file.count() > 1 self.combo_signal_scope.setEnabled(has_multi) self.combo_signal_file.setEnabled(self.combo_signal_scope.currentText() == "Per file") + self._refresh_signal_overlay() + + def _on_signal_file_changed(self, _index: int = 0) -> None: + if not hasattr(self, "combo_signal_file"): + return + if self.combo_signal_source.currentText().startswith("Use PSTH input trace"): + self._refresh_signal_overlay() + 
return + if self.combo_signal_scope.currentText() != "Per file": + self._refresh_signal_overlay() + return + + file_id = self.combo_signal_file.currentText().strip() + if not file_id: + self._refresh_signal_overlay() + return + try: + idx = self.combo_individual_file.findText(file_id) + if idx >= 0 and self.combo_individual_file.currentIndex() != idx: + self.combo_individual_file.setCurrentIndex(idx) + else: + self._update_trace_preview() + except Exception: + self._update_trace_preview() + self._refresh_signal_overlay() + + def _current_signal_overlay_file_id(self) -> str: + if self.combo_signal_source.currentText().startswith("Use PSTH input trace"): + return "psth_trace" + if self.combo_signal_scope.currentText() == "Per file": + file_id = self.combo_signal_file.currentText().strip() + if file_id: + return file_id + try: + if self.tab_visual_mode.currentIndex() == 0: + file_id = self.combo_individual_file.currentText().strip() + if file_id: + return file_id + except Exception: + pass + if self._processed: + proc = self._processed[0] + return os.path.splitext(os.path.basename(proc.path))[0] if proc.path else "import" + return "" def _refresh_individual_file_combo(self) -> None: if not hasattr(self, "combo_individual_file"): @@ -3490,17 +3562,21 @@ def _update_data_availability(self) -> None: self.combo_signal_scope, self.combo_signal_file, self.combo_signal_method, - self.spin_peak_prominence, + self.cb_peak_auto_mad, + self.spin_peak_mad_multiplier, self.spin_peak_height, self.spin_peak_distance, self.spin_peak_smooth, self.combo_peak_baseline, self.spin_peak_baseline_window, + self.cb_peak_norm_prominence, self.spin_peak_rate_bin, self.spin_peak_auc_window, self.cb_peak_overlay, + self.cb_peak_noise_overlay, ): w.setEnabled(has_processed) + self._update_peak_auto_mad_enabled(queue=False) for w in ( self.btn_compute_behavior, self.btn_export_behavior_metrics, @@ -3600,6 +3676,14 @@ def _update_global_metrics_enabled(self) -> None: self._apply_view_layout() self._queue_settings_save() + def _update_peak_auto_mad_enabled(self, _checked: object = None, *, queue: bool = True) -> None: + has_processed = bool(self._processed) + auto_mad = bool(self.cb_peak_auto_mad.isChecked()) + self.spin_peak_prominence.setEnabled(has_processed and not auto_mad) + self.spin_peak_mad_multiplier.setEnabled(has_processed and auto_mad) + if queue: + self._queue_settings_save() + def _queue_settings_save(self, *_args: object) -> None: if self._is_restoring_settings: return @@ -3656,6 +3740,7 @@ def _wire_settings_autosave(self) -> None: self.spin_global_start, self.spin_global_end, self.spin_peak_prominence, + self.spin_peak_mad_multiplier, self.spin_peak_height, self.spin_peak_distance, self.spin_peak_smooth, @@ -3679,7 +3764,10 @@ def _wire_settings_autosave(self) -> None: self.cb_global_metrics, self.cb_global_amp, self.cb_global_freq, + self.cb_peak_auto_mad, + self.cb_peak_norm_prominence, self.cb_peak_overlay, + self.cb_peak_noise_overlay, self.cb_behavior_aligned, self.cb_spatial_clip, self.cb_spatial_time_filter, @@ -3832,7 +3920,6 @@ def _find_col(names: List[str]) -> Optional[int]: iso_vals = [] dio_vals = [] - from pyBer.analysis_core import coerce_time_value for r in data_rows: if time_idx is None or output_idx is None: continue @@ -5112,55 +5199,10 @@ def _update_trace_preview(self) -> None: self._refresh_signal_overlay() def _update_behavior_overlay(self, proc: ProcessedTrial) -> None: - if not proc or proc.time is None: - self.curve_behavior.setData([], []) - return - if not 
self.combo_align.currentText().startswith("Behavior"): - self.curve_behavior.setData([], []) - return - info = self._match_behavior_source(proc) - if not info: - self.curve_behavior.setData([], []) - return - behaviors = info.get("behaviors") or {} - beh = self.combo_behavior_name.currentText().strip() - if not beh and behaviors: - beh = next(iter(behaviors.keys())) - if beh not in behaviors: - self.curve_behavior.setData([], []) - return - t_proc = np.asarray(proc.time, float) - if t_proc.size == 0: - self.curve_behavior.setData([], []) - return - kind = str(info.get("kind", _BEHAVIOR_PARSE_BINARY)) - if kind == _BEHAVIOR_PARSE_TIMESTAMPS: - events = np.asarray(behaviors[beh], float) - events = events[np.isfinite(events)] - if events.size == 0: - self.curve_behavior.setData([], []) - return - events = np.sort(np.unique(events)) - marker = np.zeros_like(t_proc, dtype=float) - for ev in events: - pos = int(np.searchsorted(t_proc, ev, side="left")) - if pos <= 0: - idx = 0 - elif pos >= t_proc.size: - idx = t_proc.size - 1 - else: - idx = pos if abs(float(t_proc[pos] - ev)) <= abs(float(t_proc[pos - 1] - ev)) else (pos - 1) - marker[idx] = 1.0 - self.curve_behavior.setData(t_proc, marker, connect="finite", skipFiniteCheck=True) - return - - t = np.asarray(info.get("time", np.array([], float)), float) - if t.size == 0: - self.curve_behavior.setData([], []) - return - b = np.asarray(behaviors[beh], float) - b_interp = np.interp(t_proc, t, b) - self.curve_behavior.setData(t_proc, b_interp, connect="finite", skipFiniteCheck=True) + # The trace preview should only show the processed signal trace. + # Behavior/event data remain available for alignment and analysis, + # but the binary edge overlay is intentionally not rendered here. + self.curve_behavior.setData([], []) def _resolve_signal_detection_targets(self) -> List[Tuple[str, np.ndarray, np.ndarray]]: targets: List[Tuple[str, np.ndarray, np.ndarray]] = [] @@ -5195,14 +5237,14 @@ def _resolve_signal_detection_targets(self) -> List[Tuple[str, np.ndarray, np.nd targets.append((file_id, np.asarray(proc.time, float), np.asarray(proc.output, float))) return targets - def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: t = np.asarray(t, float) y = np.asarray(y, float) m = np.isfinite(t) & np.isfinite(y) t = t[m] y = y[m] if t.size < 3: - return np.array([], float), np.array([], float) + return np.array([], float), np.array([], float), np.array([], float) dt = float(np.nanmedian(np.diff(t))) if t.size > 2 else np.nan if not np.isfinite(dt) or dt <= 0: @@ -5214,6 +5256,7 @@ def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np if win % 2 == 0: win += 1 + y_trace = y.copy() y_proc = y.copy() if baseline_mode.endswith("rolling median"): try: @@ -5237,7 +5280,121 @@ def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np except Exception: pass - return t, y_proc + return t, y_proc, y_trace + + def _signal_baseline_prominence_stats( + self, + t: np.ndarray, + y: np.ndarray, + min_prominence: float, + ) -> Dict[str, float]: + t = np.asarray(t, float) + y = np.asarray(y, float) + finite = np.isfinite(t) & np.isfinite(y) + if t.size < 3 or y.size != t.size or not np.any(finite): + return { + "scale": np.nan, + "baseline_median": np.nan, + "n_baseline_peaks": 0.0, + "baseline_duration_s": 0.0, + "scale_source": "unavailable", + 
"mad_noise_sigma": np.nan, + } + + t_finite = t[finite] + t0 = float(np.nanmin(t_finite)) + baseline_window = max(0.1, float(self.spin_peak_baseline_window.value())) + keep = finite & (t <= t0 + baseline_window) + if np.sum(keep) < 3: + keep = finite + + baseline = y[keep] + baseline_median = float(np.nanmedian(baseline)) if baseline.size else np.nan + if not np.isfinite(baseline_median): + return { + "scale": np.nan, + "baseline_median": np.nan, + "n_baseline_peaks": 0.0, + "baseline_duration_s": 0.0, + "scale_source": "unavailable", + "mad_noise_sigma": np.nan, + } + + centered_baseline = np.asarray(baseline, float) - baseline_median + mad_stats = self._signal_mad_noise_stats(centered_baseline) + mad_sigma = float(mad_stats.get("noise_sigma", np.nan)) + if not np.isfinite(mad_sigma) or mad_sigma <= 1e-12: + full_centered = np.asarray(y[finite], float) - float(np.nanmedian(y[finite])) + full_mad_stats = self._signal_mad_noise_stats(full_centered) + mad_sigma = float(full_mad_stats.get("noise_sigma", np.nan)) + try: + from scipy.signal import find_peaks + peaks, props = find_peaks(centered_baseline, prominence=max(0.0, float(min_prominence))) + except Exception: + peaks = np.array([], int) + props = {} + proms = np.asarray(props.get("prominences", np.array([], float)), float) + proms = proms[np.isfinite(proms) & (proms > 1e-12)] + if proms.size == 0: + scale = mad_sigma if np.isfinite(mad_sigma) and mad_sigma > 1e-12 else np.nan + scale_source = "mad_noise_fallback" if np.isfinite(scale) else "unavailable" + else: + top_count = max(1, int(np.ceil(proms.size * 0.10))) + scale = float(np.nanmean(np.sort(proms)[-top_count:])) + scale_source = "baseline_peak_prominence" + + duration = float(np.nanmax(t[keep]) - np.nanmin(t[keep])) if np.sum(keep) >= 2 else 0.0 + return { + "scale": scale, + "baseline_median": baseline_median, + "n_baseline_peaks": float(proms.size), + "baseline_duration_s": duration, + "scale_source": scale_source, + "mad_noise_sigma": mad_sigma, + } + + @staticmethod + def _signal_mad_noise_stats(y: np.ndarray) -> Dict[str, float]: + arr = np.asarray(y, float) + arr = arr[np.isfinite(arr)] + if arr.size < 5: + return {"center": np.nan, "mad": np.nan, "noise_sigma": np.nan, "n_samples": float(arr.size)} + + center = float(np.nanmedian(arr)) + abs_dev = np.abs(arr - center) + mad = float(np.nanmedian(abs_dev)) + sigma = 1.4826 * mad + + # Re-estimate from the central mass so large transients do not inflate the noise estimate. 
+ if np.isfinite(sigma) and sigma > 1e-12: + keep = abs_dev <= (3.0 * sigma) + if np.sum(keep) >= max(5, int(0.10 * arr.size)): + core = arr[keep] + center = float(np.nanmedian(core)) + mad = float(np.nanmedian(np.abs(core - center))) + sigma = 1.4826 * mad + + if not np.isfinite(sigma) or sigma <= 1e-12: + q25, q75 = np.nanpercentile(arr, [25.0, 75.0]) + sigma = float((q75 - q25) / 1.349) if np.isfinite(q25) and np.isfinite(q75) else np.nan + if not np.isfinite(sigma) or sigma <= 1e-12: + sigma = float(np.nanstd(arr)) + if not np.isfinite(sigma) or sigma <= 1e-12: + sigma = np.nan + + return { + "center": center, + "mad": mad, + "noise_sigma": sigma, + "n_samples": float(arr.size), + } + + @staticmethod + def _trapz_area(y: np.ndarray, x: np.ndarray) -> float: + try: + return float(np.trapezoid(y, x)) + except AttributeError: + return float(np.trapz(y, x)) def _refresh_signal_overlay(self) -> None: for ln in self._signal_peak_lines: @@ -5246,21 +5403,32 @@ def _refresh_signal_overlay(self) -> None: except Exception: pass self._signal_peak_lines = [] + for item in self._signal_noise_items: + try: + self.plot_trace.removeItem(item) + except Exception: + pass + self._signal_noise_items = [] self.curve_peak_markers.setData([], []) - if not self.cb_peak_overlay.isChecked(): - return if not self.last_signal_events or not self._processed: return - current_file = os.path.splitext(os.path.basename(self._processed[0].path))[0] if self._processed[0].path else "import" + current_file = self._current_signal_overlay_file_id() + self._draw_signal_noise_overlay(current_file) + + if not self.cb_peak_overlay.isChecked(): + return + file_ids = self.last_signal_events.get("file_ids", []) times = np.asarray(self.last_signal_events.get("peak_times_sec", np.array([], float)), float) - heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float) + heights = np.asarray(self.last_signal_events.get("peak_trace_values", np.array([], float)), float) + if heights.size != times.size: + heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float) if times.size == 0 or heights.size == 0: return - if file_ids and len(file_ids) == times.size: - mask = np.asarray([fid == current_file or fid == "psth_trace" for fid in file_ids], bool) + if current_file and file_ids and len(file_ids) == times.size: + mask = np.asarray([str(fid) == current_file for fid in file_ids], bool) times = times[mask] heights = heights[mask] if times.size == 0: @@ -5275,6 +5443,55 @@ def _refresh_signal_overlay(self) -> None: self.plot_trace.addItem(ln) self._signal_peak_lines.append(ln) + def _draw_signal_noise_overlay(self, current_file: str) -> None: + if not getattr(self, "cb_peak_noise_overlay", None) or not self.cb_peak_noise_overlay.isChecked(): + return + overlays = self.last_signal_events.get("noise_overlay_by_file", {}) if self.last_signal_events else {} + if not isinstance(overlays, dict) or not overlays: + return + overlay = overlays.get(str(current_file or "")) + if overlay is None and len(overlays) == 1: + overlay = next(iter(overlays.values())) + if not isinstance(overlay, dict): + return + + t = np.asarray(overlay.get("time", np.array([], float)), float) + y = np.asarray(overlay.get("detection_trace", np.array([], float)), float) + if t.size != y.size or t.size < 2: + return + + step = max(1, int(np.ceil(t.size / 5000))) + trace_item = pg.PlotDataItem( + t[::step], + y[::step], + pen=pg.mkPen((80, 220, 220, 150), width=1.0, style=QtCore.Qt.PenStyle.DashLine), + 
name="detection trace", + ) + trace_item.setZValue(8) + self.plot_trace.addItem(trace_item) + self._signal_noise_items.append(trace_item) + + center = float(overlay.get("center", np.nan)) + sigma = float(overlay.get("noise_sigma", np.nan)) + used_prominence = float(overlay.get("used_prominence", np.nan)) + for y0, color, width, style in ( + (center, (80, 220, 220, 170), 1.0, QtCore.Qt.PenStyle.DotLine), + (center + sigma, (80, 220, 220, 110), 0.8, QtCore.Qt.PenStyle.DotLine), + (center - sigma, (80, 220, 220, 110), 0.8, QtCore.Qt.PenStyle.DotLine), + (center + used_prominence, (255, 210, 80, 190), 1.2, QtCore.Qt.PenStyle.DashLine), + ): + if not np.isfinite(y0): + continue + ln = pg.InfiniteLine( + pos=float(y0), + angle=0, + pen=pg.mkPen(color, width=width, style=style), + movable=False, + ) + ln.setZValue(9) + self.plot_trace.addItem(ln) + self._signal_noise_items.append(ln) + def _detect_signal_events(self) -> None: self.last_signal_events = None targets = self._resolve_signal_detection_targets() @@ -5287,20 +5504,57 @@ def _detect_signal_events(self) -> None: all_times: List[float] = [] all_idx: List[int] = [] all_heights: List[float] = [] + all_signal_heights: List[float] = [] + all_trace_values: List[float] = [] all_proms: List[float] = [] + all_norm_proms: List[float] = [] + all_norm_scales: List[float] = [] + all_mad_sigmas: List[float] = [] + all_auto_prominence_thresholds: List[float] = [] all_widths_sec: List[float] = [] all_auc: List[float] = [] all_file_ids: List[str] = [] + normalization_by_file: Dict[str, Dict[str, float]] = {} + mad_threshold_by_file: Dict[str, Dict[str, float]] = {} + noise_overlay_by_file: Dict[str, Dict[str, object]] = {} + normalize_amplitude = bool(self.cb_peak_norm_prominence.isChecked()) + auto_mad = bool(self.cb_peak_auto_mad.isChecked()) + mad_multiplier = float(self.spin_peak_mad_multiplier.value()) for file_id, t_raw, y_raw in targets: - t, y = self._preprocess_signal_for_peaks(t_raw, y_raw) + t, y, y_trace = self._preprocess_signal_for_peaks(t_raw, y_raw) if t.size < 5: continue dt = float(np.nanmedian(np.diff(t))) if not np.isfinite(dt) or dt <= 0: continue - prominence = max(0.0, float(self.spin_peak_prominence.value())) + manual_prominence = max(0.0, float(self.spin_peak_prominence.value())) + prominence = manual_prominence + mad_stats = self._signal_mad_noise_stats(y) + mad_sigma = float(mad_stats.get("noise_sigma", np.nan)) + auto_prominence = np.nan + if auto_mad: + if np.isfinite(mad_sigma) and mad_sigma > 1e-12: + auto_prominence = max(0.0, mad_multiplier * mad_sigma) + prominence = auto_prominence + mad_threshold_by_file[str(file_id)] = { + **mad_stats, + "multiplier": mad_multiplier, + "auto_prominence": auto_prominence, + "fallback_manual_prominence": manual_prominence, + "used_prominence": prominence, + } + noise_overlay_by_file[str(file_id)] = { + "time": t.copy(), + "detection_trace": y.copy(), + "center": float(mad_stats.get("center", np.nan)), + "mad": float(mad_stats.get("mad", np.nan)), + "noise_sigma": mad_sigma, + "auto_prominence": auto_prominence, + "manual_prominence": manual_prominence, + "used_prominence": prominence, + } min_height = float(self.spin_peak_height.value()) min_distance_sec = max(0.0, float(self.spin_peak_distance.value())) min_dist_samples = max(1, int(round(min_distance_sec / dt))) if min_distance_sec > 0 else None @@ -5325,8 +5579,24 @@ def _detect_signal_events(self) -> None: if peaks.size == 0: continue - p_heights = y[peaks] + p_signal_heights = np.asarray(y[peaks], float) + p_trace_values = 
np.asarray(y_trace[peaks], float) if y_trace.size == y.size else p_signal_heights.copy() p_proms = np.asarray(props.get("prominences", np.full(peaks.size, np.nan)), float) + p_heights = p_signal_heights.copy() + p_norm_proms = np.full(peaks.size, np.nan, float) + p_norm_scales = np.full(peaks.size, np.nan, float) + + if normalize_amplitude: + norm_stats = self._signal_baseline_prominence_stats(t, y, prominence) + normalization_by_file[str(file_id)] = norm_stats + scale = float(norm_stats.get("scale", np.nan)) + baseline_median = float(norm_stats.get("baseline_median", np.nan)) + if np.isfinite(scale) and scale > 1e-12 and np.isfinite(baseline_median): + p_heights = (p_signal_heights - baseline_median) / scale + p_norm_proms = p_proms / scale + p_norm_scales[:] = scale + else: + p_heights = np.full(peaks.size, np.nan, float) try: widths_samp = peak_widths(y, peaks, rel_height=0.5)[0] widths_sec = np.asarray(widths_samp, float) * dt @@ -5345,12 +5615,18 @@ def _detect_signal_events(self) -> None: if i1 - i0 < 2: auc_vals.append(np.nan) continue - auc_vals.append(float(np.trapz(y[i0:i1], t[i0:i1]))) + auc_vals.append(self._trapz_area(y[i0:i1], t[i0:i1])) all_times.extend(t[peaks].tolist()) all_idx.extend(peaks.tolist()) all_heights.extend(np.asarray(p_heights, float).tolist()) + all_signal_heights.extend(np.asarray(p_signal_heights, float).tolist()) + all_trace_values.extend(np.asarray(p_trace_values, float).tolist()) all_proms.extend(np.asarray(p_proms, float).tolist()) + all_norm_proms.extend(np.asarray(p_norm_proms, float).tolist()) + all_norm_scales.extend(np.asarray(p_norm_scales, float).tolist()) + all_mad_sigmas.extend([mad_sigma] * peaks.size) + all_auto_prominence_thresholds.extend([auto_prominence] * peaks.size) all_widths_sec.extend(np.asarray(widths_sec, float).tolist()) all_auc.extend(np.asarray(auc_vals, float).tolist()) all_file_ids.extend([file_id] * peaks.size) @@ -5364,7 +5640,13 @@ def _detect_signal_events(self) -> None: peak_times = np.asarray(all_times, float) peak_idx = np.asarray(all_idx, int) peak_heights = np.asarray(all_heights, float) + peak_signal_heights = np.asarray(all_signal_heights, float) + peak_trace_values = np.asarray(all_trace_values, float) peak_proms = np.asarray(all_proms, float) + peak_norm_proms = np.asarray(all_norm_proms, float) + peak_norm_scales = np.asarray(all_norm_scales, float) + peak_mad_sigmas = np.asarray(all_mad_sigmas, float) + peak_auto_prominence_thresholds = np.asarray(all_auto_prominence_thresholds, float) peak_widths_sec = np.asarray(all_widths_sec, float) peak_auc = np.asarray(all_auc, float) @@ -5372,7 +5654,13 @@ def _detect_signal_events(self) -> None: peak_times = peak_times[sort_idx] peak_idx = peak_idx[sort_idx] peak_heights = peak_heights[sort_idx] + peak_signal_heights = peak_signal_heights[sort_idx] + peak_trace_values = peak_trace_values[sort_idx] peak_proms = peak_proms[sort_idx] + peak_norm_proms = peak_norm_proms[sort_idx] + peak_norm_scales = peak_norm_scales[sort_idx] + peak_mad_sigmas = peak_mad_sigmas[sort_idx] + peak_auto_prominence_thresholds = peak_auto_prominence_thresholds[sort_idx] peak_widths_sec = peak_widths_sec[sort_idx] peak_auc = peak_auc[sort_idx] all_file_ids = [all_file_ids[i] for i in sort_idx] @@ -5389,31 +5677,115 @@ def _detect_signal_events(self) -> None: "peak_frequency_per_min": freq_per_min, "mean_inter_peak_interval_s": float(np.nanmean(ipi)) if ipi.size else np.nan, "mean_auc": float(np.nanmean(peak_auc)) if np.any(np.isfinite(peak_auc)) else np.nan, + "baseline_prominence_normalized": 
bool(normalize_amplitude), + "mad_auto_threshold_enabled": bool(auto_mad), } + if auto_mad: + metrics.update( + { + "mad_multiplier": float(mad_multiplier), + "mean_mad_noise_sigma": ( + float(np.nanmean(peak_mad_sigmas)) if np.any(np.isfinite(peak_mad_sigmas)) else np.nan + ), + "mean_auto_prominence_threshold": ( + float(np.nanmean(peak_auto_prominence_thresholds)) + if np.any(np.isfinite(peak_auto_prominence_thresholds)) + else np.nan + ), + "mad_threshold_files_with_estimate": float( + sum( + 1 + for stats in mad_threshold_by_file.values() + if np.isfinite(float(stats.get("noise_sigma", np.nan))) + ) + ), + } + ) + if normalize_amplitude: + metrics.update( + { + "mean_raw_amplitude": float(np.nanmean(peak_signal_heights)), + "median_raw_amplitude": float(np.nanmedian(peak_signal_heights)), + "mean_normalized_prominence": ( + float(np.nanmean(peak_norm_proms)) if np.any(np.isfinite(peak_norm_proms)) else np.nan + ), + "mean_baseline_prominence_scale": ( + float(np.nanmean(peak_norm_scales)) if np.any(np.isfinite(peak_norm_scales)) else np.nan + ), + "baseline_prominence_files_with_scale": float( + sum( + 1 + for stats in normalization_by_file.values() + if np.isfinite(float(stats.get("scale", np.nan))) + ) + ), + "baseline_prominence_files_with_peak_scale": float( + sum( + 1 + for stats in normalization_by_file.values() + if str(stats.get("scale_source", "")) == "baseline_peak_prominence" + ) + ), + "baseline_prominence_files_with_mad_fallback": float( + sum( + 1 + for stats in normalization_by_file.values() + if str(stats.get("scale_source", "")) == "mad_noise_fallback" + ) + ), + } + ) self.last_signal_events = { "peak_times_sec": peak_times, "peak_indices": peak_idx, "peak_heights": peak_heights, + "peak_signal_heights": peak_signal_heights, + "peak_trace_values": peak_trace_values, "peak_prominences": peak_proms, + "peak_normalized_prominences": peak_norm_proms, + "peak_baseline_prominence_scale": peak_norm_scales, + "peak_mad_noise_sigma": peak_mad_sigmas, + "peak_auto_prominence_threshold": peak_auto_prominence_thresholds, "peak_widths_sec": peak_widths_sec, "peak_auc": peak_auc, "file_ids": all_file_ids, "derived_metrics": metrics, + "normalization_by_file": normalization_by_file, + "mad_threshold_by_file": mad_threshold_by_file, + "noise_overlay_by_file": noise_overlay_by_file, "params": { "method": self.combo_signal_method.currentText(), "prominence": float(self.spin_peak_prominence.value()), + "auto_mad_threshold": bool(auto_mad), + "mad_multiplier": float(mad_multiplier), "min_height": float(self.spin_peak_height.value()), "min_distance_sec": float(self.spin_peak_distance.value()), "smooth_sigma_sec": float(self.spin_peak_smooth.value()), "baseline_mode": self.combo_peak_baseline.currentText(), "baseline_window_sec": float(self.spin_peak_baseline_window.value()), + "baseline_prominence_normalized": bool(normalize_amplitude), "rate_bin_sec": float(self.spin_peak_rate_bin.value()), "auc_half_window_sec": float(self.spin_peak_auc_window.value()), }, } - self.statusUpdate.emit(f"Detected {peak_times.size} peak(s).", 5000) + msg = f"Detected {peak_times.size} peak(s)." + if auto_mad: + auto_vals = [ + float(stats.get("auto_prominence", np.nan)) + for stats in mad_threshold_by_file.values() + if np.isfinite(float(stats.get("auto_prominence", np.nan))) + ] + if auto_vals: + msg += f" Auto-MAD prominence ~{float(np.nanmean(auto_vals)):.4g}." + else: + msg += " Auto-MAD estimate unavailable; manual prominence used." 
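+        # Per-file normalization can fail (no baseline peaks and no usable MAD
+        # estimate), so the checks below surface the fallback path in the
+        # status bar rather than letting NaN amplitudes pass silently.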
+ if normalize_amplitude and not any(np.isfinite(float(s.get("scale", np.nan))) for s in normalization_by_file.values()): + msg += " Baseline prominence scale unavailable." + elif normalize_amplitude and any(str(s.get("scale_source", "")) == "mad_noise_fallback" for s in normalization_by_file.values()): + msg += " Normalization used MAD fallback for files without baseline peaks." + self.statusUpdate.emit(msg, 5000) self._refresh_signal_overlay() self._render_signal_event_plots() self._update_signal_metrics_table() @@ -5431,6 +5803,11 @@ def _render_signal_event_plots(self) -> None: peak_heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float) if peak_times.size == 0 or peak_heights.size == 0: return + metrics = self.last_signal_events.get("derived_metrics", {}) or {} + if bool(metrics.get("baseline_prominence_normalized", False)): + self.plot_peak_amp.setLabel("bottom", "Prominence-normalized amplitude") + else: + self.plot_peak_amp.setLabel("bottom", "Amplitude") def _bar_hist(plot: pg.PlotWidget, values: np.ndarray, color: Tuple[int, int, int]) -> None: vals = np.asarray(values, float) @@ -5461,10 +5838,13 @@ def _update_signal_metrics_table(self) -> None: if not self.last_signal_events: return metrics = self.last_signal_events.get("derived_metrics", {}) or {} + normalized = bool(metrics.get("baseline_prominence_normalized", False)) + amp_label = "mean amplitude (prom-norm)" if normalized else "mean amplitude" + med_amp_label = "median amplitude (prom-norm)" if normalized else "median amplitude" rows = [ ("number of peaks", metrics.get("number_of_peaks", np.nan)), - ("mean amplitude", metrics.get("mean_amplitude", np.nan)), - ("median amplitude", metrics.get("median_amplitude", np.nan)), + (amp_label, metrics.get("mean_amplitude", np.nan)), + (med_amp_label, metrics.get("median_amplitude", np.nan)), ("amplitude std", metrics.get("amplitude_std", np.nan)), ("mean prominence", metrics.get("mean_prominence", np.nan)), ("mean width at half prom (s)", metrics.get("mean_width_half_prom_s", np.nan)), @@ -5472,6 +5852,26 @@ def _update_signal_metrics_table(self) -> None: ("mean inter-peak interval (s)", metrics.get("mean_inter_peak_interval_s", np.nan)), ("mean AUC", metrics.get("mean_auc", np.nan)), ] + if normalized: + rows.extend( + [ + ("mean raw amplitude", metrics.get("mean_raw_amplitude", np.nan)), + ("mean normalized prominence", metrics.get("mean_normalized_prominence", np.nan)), + ("baseline prominence scale", metrics.get("mean_baseline_prominence_scale", np.nan)), + ("files with scale", metrics.get("baseline_prominence_files_with_scale", np.nan)), + ("files using baseline peaks", metrics.get("baseline_prominence_files_with_peak_scale", np.nan)), + ("files using MAD fallback", metrics.get("baseline_prominence_files_with_mad_fallback", np.nan)), + ] + ) + if bool(metrics.get("mad_auto_threshold_enabled", False)): + rows.extend( + [ + ("MAD multiplier", metrics.get("mad_multiplier", np.nan)), + ("mean MAD noise sigma", metrics.get("mean_mad_noise_sigma", np.nan)), + ("mean auto prominence", metrics.get("mean_auto_prominence_threshold", np.nan)), + ("files with MAD estimate", metrics.get("mad_threshold_files_with_estimate", np.nan)), + ] + ) for key, value in rows: r = self.tbl_signal_metrics.rowCount() self.tbl_signal_metrics.insertRow(r) @@ -5492,8 +5892,33 @@ def _export_signal_events_csv(self) -> None: self._remember_export_dir(out_dir) peak_times = np.asarray(self.last_signal_events.get("peak_times_sec", np.array([], float)), float) peak_heights = 
np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float) + peak_signal_heights = np.asarray( + self.last_signal_events.get("peak_signal_heights", peak_heights), + float, + ) + peak_trace_values = np.asarray( + self.last_signal_events.get("peak_trace_values", peak_signal_heights), + float, + ) peak_proms = np.asarray(self.last_signal_events.get("peak_prominences", np.array([], float)), float) + peak_norm_proms = np.asarray( + self.last_signal_events.get("peak_normalized_prominences", np.full_like(peak_proms, np.nan)), + float, + ) + peak_norm_scales = np.asarray( + self.last_signal_events.get("peak_baseline_prominence_scale", np.full_like(peak_heights, np.nan)), + float, + ) + peak_mad_sigmas = np.asarray( + self.last_signal_events.get("peak_mad_noise_sigma", np.full_like(peak_heights, np.nan)), + float, + ) + peak_auto_prominence = np.asarray( + self.last_signal_events.get("peak_auto_prominence_threshold", np.full_like(peak_heights, np.nan)), + float, + ) peak_widths = np.asarray(self.last_signal_events.get("peak_widths_sec", np.array([], float)), float) + peak_auc = np.asarray(self.last_signal_events.get("peak_auc", np.array([], float)), float) file_ids = self.last_signal_events.get("file_ids", []) if peak_times.size == 0: return @@ -5504,7 +5929,22 @@ def _export_signal_events_csv(self) -> None: import csv with open(out_path, "w", newline="") as f: w = csv.writer(f) - w.writerow(["peak_time_sec", "height", "prominence", "width_sec", "file_id"]) + w.writerow( + [ + "peak_time_sec", + "height", + "prominence", + "width_sec", + "auc", + "trace_value", + "signal_height", + "normalized_prominence", + "baseline_prominence_scale", + "mad_noise_sigma", + "auto_prominence_threshold", + "file_id", + ] + ) for i in range(peak_times.size): fid = file_ids[i] if isinstance(file_ids, list) and i < len(file_ids) else "" w.writerow( @@ -5513,6 +5953,13 @@ def _export_signal_events_csv(self) -> None: float(peak_heights[i]) if i < peak_heights.size else np.nan, float(peak_proms[i]) if i < peak_proms.size else np.nan, float(peak_widths[i]) if i < peak_widths.size else np.nan, + float(peak_auc[i]) if i < peak_auc.size else np.nan, + float(peak_trace_values[i]) if i < peak_trace_values.size else np.nan, + float(peak_signal_heights[i]) if i < peak_signal_heights.size else np.nan, + float(peak_norm_proms[i]) if i < peak_norm_proms.size else np.nan, + float(peak_norm_scales[i]) if i < peak_norm_scales.size else np.nan, + float(peak_mad_sigmas[i]) if i < peak_mad_sigmas.size else np.nan, + float(peak_auto_prominence[i]) if i < peak_auto_prominence.size else np.nan, fid, ] ) @@ -5953,6 +6400,18 @@ def _compute_psth(self) -> None: self._update_metric_regions() self._update_status_strip() self._save_settings() + # Feed data to temporal modeling widget + try: + self.section_temporal.set_data( + processed_trials=self._processed, + psth_mat=mat_display, + psth_tvec=tvec, + event_times=self._last_events, + file_ids=self._all_file_ids, + per_file_mats=self._per_file_mats, + ) + except Exception: + pass except Exception as e: self.statusUpdate.emit(f"Postprocessing error: {e}", 5000) self._update_status_strip() @@ -6657,7 +7116,13 @@ def _save_signal_events_h5(self, parent: h5py.Group) -> None: "peak_times_sec", "peak_indices", "peak_heights", + "peak_signal_heights", + "peak_trace_values", "peak_prominences", + "peak_normalized_prominences", + "peak_baseline_prominence_scale", + "peak_mad_noise_sigma", + "peak_auto_prominence_threshold", "peak_widths_sec", "peak_auc", ): @@ -6665,6 +7130,16 @@ 
def _save_signal_events_h5(self, parent: h5py.Group) -> None: self._write_h5_str_list(group, "file_ids", [str(v) for v in self.last_signal_events.get("file_ids", []) or []]) self._write_h5_json(group, "derived_metrics_json", dict(self.last_signal_events.get("derived_metrics", {}) or {})) self._write_h5_json(group, "params_json", dict(self.last_signal_events.get("params", {}) or {})) + self._write_h5_json_any( + group, + "normalization_by_file_json", + dict(self.last_signal_events.get("normalization_by_file", {}) or {}), + ) + self._write_h5_json_any( + group, + "mad_threshold_by_file_json", + dict(self.last_signal_events.get("mad_threshold_by_file", {}) or {}), + ) def _load_signal_events_h5(self, parent: Optional[h5py.Group]) -> Optional[Dict[str, object]]: if parent is None: @@ -6683,12 +7158,20 @@ def _num(name: str) -> np.ndarray: "peak_times_sec": _num("peak_times_sec"), "peak_indices": _num("peak_indices"), "peak_heights": _num("peak_heights"), + "peak_signal_heights": _num("peak_signal_heights"), + "peak_trace_values": _num("peak_trace_values"), "peak_prominences": _num("peak_prominences"), + "peak_normalized_prominences": _num("peak_normalized_prominences"), + "peak_baseline_prominence_scale": _num("peak_baseline_prominence_scale"), + "peak_mad_noise_sigma": _num("peak_mad_noise_sigma"), + "peak_auto_prominence_threshold": _num("peak_auto_prominence_threshold"), "peak_widths_sec": _num("peak_widths_sec"), "peak_auc": _num("peak_auc"), "file_ids": self._read_h5_str_list(group, "file_ids"), "derived_metrics": self._read_h5_json(group, "derived_metrics_json"), "params": self._read_h5_json(group, "params_json"), + "normalization_by_file": self._read_h5_json_any(group, "normalization_by_file_json", {}), + "mad_threshold_by_file": self._read_h5_json_any(group, "mad_threshold_by_file_json", {}), } return out @@ -7188,10 +7671,7 @@ def _confirm_discard_current_project(self) -> bool: ) return ask == QtWidgets.QMessageBox.StandardButton.Yes - def _new_project(self) -> None: - if not self._confirm_discard_current_project(): - return - + def _reset_project_state(self) -> None: was_restoring = self._is_restoring_settings self._is_restoring_settings = True try: @@ -7199,6 +7679,7 @@ def _new_project(self) -> None: self._processed = [] self._behavior_sources = {} self._pending_project_recompute_from_current = False + self._dio_cache.clear() self.lbl_group.setText("(none)") self.lbl_beh.setText("(none)") self.lbl_behavior_msg.setText("") @@ -7220,6 +7701,16 @@ def _new_project(self) -> None: self._project_recovered_from_autosave = False self._clear_project_autosave_cache(delete_file=True) self._update_status_strip() + + def reset_for_new_preprocessing_project(self) -> None: + self._reset_project_state() + self.statusUpdate.emit("Cleared postprocessing project state.", 5000) + + def _new_project(self) -> None: + if not self._confirm_discard_current_project(): + return + + self._reset_project_state() self.statusUpdate.emit("Started a new postprocessing project.", 5000) def _import_project_source_paths(self, recent_paths: Dict[str, object]) -> bool: @@ -7417,14 +7908,18 @@ def _collect_settings(self) -> Dict[str, object]: "signal_file": self.combo_signal_file.currentText(), "signal_method": self.combo_signal_method.currentText(), "signal_prominence": float(self.spin_peak_prominence.value()), + "signal_auto_mad": self.cb_peak_auto_mad.isChecked(), + "signal_mad_multiplier": float(self.spin_peak_mad_multiplier.value()), "signal_height": float(self.spin_peak_height.value()), "signal_distance": 
float(self.spin_peak_distance.value()), "signal_smooth": float(self.spin_peak_smooth.value()), "signal_baseline": self.combo_peak_baseline.currentText(), "signal_baseline_window": float(self.spin_peak_baseline_window.value()), + "signal_norm_prominence": self.cb_peak_norm_prominence.isChecked(), "signal_rate_bin": float(self.spin_peak_rate_bin.value()), "signal_auc_window": float(self.spin_peak_auc_window.value()), "signal_overlay": self.cb_peak_overlay.isChecked(), + "signal_noise_overlay": self.cb_peak_noise_overlay.isChecked(), "behavior_analysis_name": self.combo_behavior_analysis.currentText(), "behavior_analysis_bin": float(self.spin_behavior_bin.value()), "behavior_analysis_aligned": self.cb_behavior_aligned.isChecked(), @@ -7520,6 +8015,11 @@ def _set_combo(combo: QtWidgets.QComboBox, val: object) -> None: _set_combo(self.combo_signal_method, data.get("signal_method")) if "signal_prominence" in data: self.spin_peak_prominence.setValue(float(data["signal_prominence"])) + if "signal_auto_mad" in data: + self.cb_peak_auto_mad.setChecked(bool(data["signal_auto_mad"])) + if "signal_mad_multiplier" in data: + self.spin_peak_mad_multiplier.setValue(float(data["signal_mad_multiplier"])) + self._update_peak_auto_mad_enabled(queue=False) if "signal_height" in data: self.spin_peak_height.setValue(float(data["signal_height"])) if "signal_distance" in data: @@ -7529,12 +8029,16 @@ def _set_combo(combo: QtWidgets.QComboBox, val: object) -> None: _set_combo(self.combo_peak_baseline, data.get("signal_baseline")) if "signal_baseline_window" in data: self.spin_peak_baseline_window.setValue(float(data["signal_baseline_window"])) + if "signal_norm_prominence" in data: + self.cb_peak_norm_prominence.setChecked(bool(data["signal_norm_prominence"])) if "signal_rate_bin" in data: self.spin_peak_rate_bin.setValue(float(data["signal_rate_bin"])) if "signal_auc_window" in data: self.spin_peak_auc_window.setValue(float(data["signal_auc_window"])) if "signal_overlay" in data: self.cb_peak_overlay.setChecked(bool(data["signal_overlay"])) + if "signal_noise_overlay" in data: + self.cb_peak_noise_overlay.setChecked(bool(data["signal_noise_overlay"])) _set_combo(self.combo_behavior_analysis, data.get("behavior_analysis_name")) if "behavior_analysis_bin" in data: self.spin_behavior_bin.setValue(float(data["behavior_analysis_bin"])) diff --git a/pyBer/gui_preprocessing.py b/pyBer/gui_preprocessing.py index c2451b7..01614f3 100644 --- a/pyBer/gui_preprocessing.py +++ b/pyBer/gui_preprocessing.py @@ -74,6 +74,12 @@ def _compact_combo(combo: QtWidgets.QComboBox, min_chars: int = 6) -> None: "Raw signal (465)": "output = filtered/resampled 465 signal", } +_DFF_OUTPUT_MODES = { + "dFF (non motion corrected)", + "dFF (motion corrected via subtraction)", + "dFF (motion corrected with fitted ref)", +} + def _system_locale() -> QtCore.QLocale: return QtCore.QLocale.system() @@ -1458,6 +1464,10 @@ def _build_help_texts(self) -> Dict[str, str]: "Choose one or more DIO channels to export.\n" "If none are checked, export uses the current overlay trigger." ), + "auto_export": ( + "When enabled, the Export button writes selected files directly beside their source raw data.\n" + "All available analog channels are exported with the current analysis parameters." 
+ ), } def _show_help(self, key: str, title: str) -> None: @@ -1747,6 +1757,11 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: self.btn_metadata.clicked.connect(self.metadataRequested.emit) self.btn_advanced.clicked.connect(self.advancedOptionsRequested.emit) + self.chk_auto_export = QtWidgets.QCheckBox("Use source folder + all channels") + self.chk_auto_export.setChecked(False) + self.chk_auto_export.setToolTip( + "Skip the folder picker and export all analog channels beside each source file." + ) self.chk_export_raw = QtWidgets.QCheckBox("Raw 465") self.chk_export_raw.setChecked(True) self.chk_export_iso = QtWidgets.QCheckBox("Isobestic 405") @@ -1765,9 +1780,20 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: self.list_export_dio = CheckableListWidget() self.list_export_dio.setMaximumHeight(84) self.list_export_dio.setMinimumHeight(54) + self.list_export_outputs = CheckableListWidget() + self.list_export_outputs.setMaximumHeight(116) + self.list_export_outputs.setMinimumHeight(72) + self.list_export_outputs.setToolTip("Select one or more processed output traces to write during export.") self._pending_export_channel_names: List[str] = [] self._pending_export_trigger_names: List[str] = [] + self._pending_export_output_modes: List[str] = [self.combo_output.currentText()] + self._export_outputs_follow_current_mode = True + self.list_export_outputs.set_items( + [(mode, mode) for mode in OUTPUT_MODES], + checked_values=self._pending_export_output_modes, + ) self.list_export_dio.setEnabled(self.chk_export_dio.isChecked()) + self.list_export_outputs.setEnabled(self.chk_export_output.isChecked()) self.export_options_group = QtWidgets.QGroupBox("Export fields") export_form = QtWidgets.QFormLayout(self.export_options_group) @@ -1784,6 +1810,8 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: export_checks.addWidget(self.chk_export_baseline_sig, 2, 0) export_checks.addWidget(self.chk_export_baseline_ref, 2, 1) export_form.addRow(export_checks) + export_form.addRow(self._label_with_help("Auto export", "auto_export"), self.chk_auto_export) + export_form.addRow(self._label_with_help("Output traces", "output_mode"), self.list_export_outputs) export_form.addRow(self._label_with_help("AN channels", "export_analog_channels"), self.list_export_channels) export_form.addRow(self._label_with_help("DIO channels", "export_dio_channels"), self.list_export_dio) @@ -1837,6 +1865,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: self._update_output_definition() self._update_output_controls() self._update_smoothing_controls(emit_signal=False) + self._update_auto_export_controls() self._update_section_summaries() def _update_artifact_enabled(self) -> None: @@ -2016,6 +2045,7 @@ def emit_noargs(*_args) -> None: self.spin_lam_y.valueChanged.connect(lambda *_: self._update_lambda_preview()) self.combo_output.currentIndexChanged.connect(lambda *_: self._update_output_definition()) self.combo_output.currentIndexChanged.connect(lambda *_: self._update_output_controls()) + self.combo_output.currentIndexChanged.connect(lambda *_: self._sync_export_outputs_to_current_mode()) self.combo_ref_fit.currentIndexChanged.connect(lambda *_: self._update_output_controls()) # Keep collapsed card summaries synchronized with current values. 
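
The wiring above makes the export-outputs list shadow the main output combo only
until the user edits the list directly: _sync_export_outputs_to_current_mode keeps
mirroring the combo, while _on_export_output_selection_changed (both in the hunks
below) pins the manual choice. A minimal, Qt-free sketch of this follow-until-touched
pattern; the class and method names are illustrative, not the panel's API:

    class FollowingSelection:
        def __init__(self, current_mode: str) -> None:
            self._modes = [current_mode]
            self._follow = True  # mirror the mode combo until the user intervenes

        def on_current_mode_changed(self, mode: str) -> None:
            # The combo moved while still in "follow" state: keep mirroring it.
            if self._follow:
                self._modes = [mode]

        def on_user_edited(self, checked_modes: list) -> None:
            # An explicit edit pins the selection; the combo no longer drives it.
            self._modes = list(checked_modes)
            self._follow = False

        def modes(self, fallback: str) -> list:
            return list(self._modes) or [fallback]

set_export_output_modes(..., follow_current=True) re-arms the mirroring in the same
way, and any change signal from the list disarms it.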
@@ -2031,11 +2061,16 @@ def emit_noargs(*_args) -> None: self.combo_smoothing.currentIndexChanged.connect(lambda *_: self._update_smoothing_controls()) self.cb_invert.stateChanged.connect(emit_noargs) self.cb_show_artifact_overlay.toggled.connect(lambda v: self.artifactOverlayToggled.emit(bool(v))) + self.chk_auto_export.toggled.connect(self._update_auto_export_controls) + self.chk_auto_export.toggled.connect(emit_noargs) self.chk_export_dio.toggled.connect(self.list_export_dio.setEnabled) + self.chk_export_output.toggled.connect(self.list_export_outputs.setEnabled) self.list_export_channels.changed.connect(self._on_export_channel_selection_changed) self.list_export_channels.changed.connect(emit_noargs) self.list_export_dio.changed.connect(self._on_export_trigger_selection_changed) self.list_export_dio.changed.connect(emit_noargs) + self.list_export_outputs.changed.connect(self._on_export_output_selection_changed) + self.list_export_outputs.changed.connect(emit_noargs) for cb in ( self.chk_export_raw, self.chk_export_iso, @@ -2054,6 +2089,9 @@ def _toggle_advanced_baseline(self) -> None: "Hide advanced baseline options" if not is_visible else "Show advanced baseline options" ) + def _update_auto_export_controls(self, *_args) -> None: + self.list_export_channels.setEnabled(not self.auto_export_enabled()) + def set_config_state_hooks( self, exporter: Optional[Callable[[], Dict[str, object]]], @@ -2070,16 +2108,23 @@ def export_selection(self) -> ExportSelection: dio=self.chk_export_dio.isChecked(), baseline_sig=self.chk_export_baseline_sig.isChecked(), baseline_ref=self.chk_export_baseline_ref.isChecked(), + output_modes=self.export_output_modes(), ) def export_selection_summary(self) -> str: selection = self.export_selection() parts: List[str] = [] - chans = self.export_channel_names() - if chans: - parts.append(f"ANx{len(chans)}") + if self.auto_export_enabled(): + n_channels = self.list_export_channels.count() + parts.append(f"ANx{n_channels}" if n_channels else "all AN") + parts.append("source folder") + else: + chans = self.export_channel_names() + if chans: + parts.append(f"ANx{len(chans)}") if selection.output: - parts.append("output") + modes = self.export_output_modes() + parts.append(f"outputsx{len(modes)}" if len(modes) > 1 else "output") if selection.dio: dio_names = self.export_trigger_names() parts.append(f"DIOx{len(dio_names)}" if dio_names else "DIO") @@ -2106,6 +2151,10 @@ def set_export_selection(self, selection: ExportSelection) -> None: self.chk_export_dio.setChecked(bool(selection.dio)) self.chk_export_baseline_sig.setChecked(bool(selection.baseline_sig)) self.chk_export_baseline_ref.setChecked(bool(selection.baseline_ref)) + if selection.output_modes: + self.set_export_output_modes(selection.output_modes, follow_current=False) + else: + self.set_export_output_modes([self.combo_output.currentText()], follow_current=True) def export_channel_names(self) -> List[str]: checked = self.list_export_channels.checked_values() @@ -2115,12 +2164,28 @@ def export_trigger_names(self) -> List[str]: checked = self.list_export_dio.checked_values() return checked or list(self._pending_export_trigger_names) + def export_output_modes(self) -> List[str]: + checked = self.list_export_outputs.checked_values() + modes = checked or list(self._pending_export_output_modes) + modes = [mode for mode in modes if mode in OUTPUT_MODES] + return modes or [self.combo_output.currentText()] + def _on_export_channel_selection_changed(self) -> None: self._pending_export_channel_names = 
self.list_export_channels.checked_values() def _on_export_trigger_selection_changed(self) -> None: self._pending_export_trigger_names = self.list_export_dio.checked_values() + def _on_export_output_selection_changed(self) -> None: + self._pending_export_output_modes = self.list_export_outputs.checked_values() + self._export_outputs_follow_current_mode = False + self._update_section_summaries() + + def _sync_export_outputs_to_current_mode(self) -> None: + if not getattr(self, "_export_outputs_follow_current_mode", True): + return + self.set_export_output_modes([self.combo_output.currentText()], follow_current=True) + def set_available_export_channels(self, channels: List[str], preferred: Optional[List[str]] = None) -> None: current = list(preferred if preferred is not None else self.export_channel_names()) items = [(name, name) for name in channels or [] if str(name or "").strip()] @@ -2143,6 +2208,20 @@ def set_export_trigger_names(self, trigger_names: List[str]) -> None: self._pending_export_trigger_names = names self.list_export_dio.set_checked_values(names) + def set_export_output_modes(self, output_modes: List[str], follow_current: bool = False) -> None: + modes = [str(mode or "").strip() for mode in output_modes or [] if str(mode or "").strip() in OUTPUT_MODES] + if not modes: + modes = [self.combo_output.currentText()] + self._pending_export_output_modes = modes + self._export_outputs_follow_current_mode = bool(follow_current) + self.list_export_outputs.set_checked_values(modes) + + def auto_export_enabled(self) -> bool: + return bool(self.chk_auto_export.isChecked()) + + def set_auto_export_enabled(self, enabled: bool) -> None: + self.chk_auto_export.setChecked(bool(enabled)) + def _save_config(self) -> None: """Save current preprocessing parameters to a JSON file.""" params = self.get_params() @@ -2416,8 +2495,10 @@ class PlotDashboard(QtWidgets.QWidget): def __init__(self, parent=None) -> None: super().__init__(parent) self._sync_guard = False + self._last_xrange: Optional[Tuple[float, float]] = None self._artifact_overlay_visible = True self._artifact_thresholds_visible = True + self._dio_overlay_visible = False self._plot_background_mode = "dark" self._plot_grid_visible = True self._artifact_regions: List[pg.LinearRegionItem] = [] @@ -2479,17 +2560,9 @@ def _build_ui(self) -> None: for w in (self.plot_raw, self.plot_proc, self.plot_out): _optimize_plot(w) - # Primary axis for 465 signal + # Raw traces share the same primary y-axis. 
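+        # (The 405 curve below reuses this axis; DIO overlays get their own
+        # right-hand ViewBox via _add_dio_axis.)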
self.curve_465 = self.plot_raw.plot(pen=pg.mkPen((80, 250, 160), width=1.3)) - - # Twin axis for 405 (isobestic) signal - share Y axis with 465 - self.plot_raw_pi = self.plot_raw.getPlotItem() - self.plot_raw_pi.showAxis("right") - self.plot_raw_pi.getAxis("right").setLabel("405 (isobestic)", color=(160, 120, 255)) - - # For true twin axis, we plot both curves on the same plot area - # The 405 curve will use the same Y axis scaling as 465 - self.curve_405 = self.plot_raw.plot(pen=pg.mkPen((160, 120, 255, 128), width=1.2)) # Alpha for isobestic + self.curve_405 = self.plot_raw.plot(pen=pg.mkPen((160, 120, 255, 128), width=1.2)) pen_env = pg.mkPen((240, 200, 90), width=1.0, style=QtCore.Qt.PenStyle.DashLine) self.curve_thr_hi = self.plot_raw.plot(pen=pen_env) @@ -2524,6 +2597,7 @@ def _build_ui(self) -> None: self.vb_dio_raw, self.curve_dio_raw = self._add_dio_axis(self.plot_raw, "A/D") self.vb_dio_proc, self.curve_dio_proc = self._add_dio_axis(self.plot_proc, "A/D") self.vb_dio_out, self.curve_dio_out = self._add_dio_axis(self.plot_out, "A/D") + self._set_dio_overlay_visible(False) self._align_plot_axis_layouts() self.lbl_log = QtWidgets.QLabel("") @@ -2605,7 +2679,7 @@ def _align_plot_axis_layouts(self) -> None: # Fixed side gutters keep the three view boxes the same width, so shared x ranges # produce vertically aligned ticks and grid lines across the stacked plots. left_width = 64 - right_width = 56 + right_width = 56 if self._dio_overlay_visible else 0 bottom_height = 24 for plot in (self.plot_raw, self.plot_proc, self.plot_out): pi = plot.getPlotItem() @@ -2624,6 +2698,34 @@ def _align_plot_axis_layouts(self) -> None: except Exception: pass + def _set_dio_overlay_visible(self, visible: bool, label: str = "A/D") -> None: + self._dio_overlay_visible = bool(visible) + for plot, vb, curve in ( + (self.plot_raw, self.vb_dio_raw, self.curve_dio_raw), + (self.plot_proc, self.vb_dio_proc, self.curve_dio_proc), + (self.plot_out, self.vb_dio_out, self.curve_dio_out), + ): + pi = plot.getPlotItem() + axis = pi.getAxis("right") + if axis is not None: + try: + axis.setLabel(label if self._dio_overlay_visible else "") + except Exception: + pass + try: + pi.showAxis("right", self._dio_overlay_visible) + except Exception: + pass + try: + vb.setVisible(self._dio_overlay_visible) + except Exception: + pass + try: + curve.setVisible(self._dio_overlay_visible) + except Exception: + pass + self._align_plot_axis_layouts() + def _add_dio_axis(self, plot: pg.PlotWidget, label: str): pi = plot.getPlotItem() vb = pg.ViewBox() @@ -2651,6 +2753,7 @@ def _emit_xrange_from_any(self, _vb, x_range) -> None: return try: x0, x1 = x_range + self._last_xrange = (float(x0), float(x1)) self.xRangeChanged.emit(float(x0), float(x1)) except Exception: pass @@ -2684,6 +2787,9 @@ def _on_thresholds_toggled(self, checked: bool) -> None: self.artifactThresholdsToggled.emit(bool(checked)) def set_xrange_all(self, x0: float, x1: float) -> None: + if not np.isfinite(x0) or not np.isfinite(x1) or float(x1) <= float(x0): + return + self._last_xrange = (float(x0), float(x1)) self._sync_guard = True try: self.plot_raw.setXRange(x0, x1, padding=0) @@ -2692,6 +2798,20 @@ def set_xrange_all(self, x0: float, x1: float) -> None: finally: self._sync_guard = False + def current_xrange(self) -> Optional[Tuple[float, float]]: + try: + (x0, x1), _ = self.plot_raw.getViewBox().viewRange() + if np.isfinite(x0) and np.isfinite(x1) and x1 > x0: + return (float(x0), float(x1)) + except Exception: + pass + if self._last_xrange is None: + return None + 
x0, x1 = self._last_xrange + if np.isfinite(x0) and np.isfinite(x1) and x1 > x0: + return (float(x0), float(x1)) + return None + def set_full_xrange(self, t: np.ndarray) -> None: if t is None or np.asarray(t).size < 2: return @@ -2915,6 +3035,7 @@ def _set_dio(self, t: np.ndarray, dio: Optional[np.ndarray], name: str = "") -> self.curve_dio_raw.setData([], []) self.curve_dio_proc.setData([], []) self.curve_dio_out.setData([], []) + self._set_dio_overlay_visible(False) return tt = np.asarray(t, float) @@ -2925,16 +3046,7 @@ def _set_dio(self, t: np.ndarray, dio: Optional[np.ndarray], name: str = "") -> self.curve_dio_raw.setData(tt, yy, connect="finite", skipFiniteCheck=True) self.curve_dio_proc.setData(tt, yy, connect="finite", skipFiniteCheck=True) self.curve_dio_out.setData(tt, yy, connect="finite", skipFiniteCheck=True) - - if name: - self.plot_raw.getPlotItem().getAxis("right").setLabel(f"A/D ({name})") - self.plot_proc.getPlotItem().getAxis("right").setLabel(f"A/D ({name})") - self.plot_out.getPlotItem().getAxis("right").setLabel(f"A/D ({name})") - else: - self.plot_raw.getPlotItem().getAxis("right").setLabel("A/D") - self.plot_proc.getPlotItem().getAxis("right").setLabel("A/D") - self.plot_out.getPlotItem().getAxis("right").setLabel("A/D") - self._align_plot_axis_layouts() + self._set_dio_overlay_visible(True, label=(f"A/D ({name})" if name else "A/D")) # -------------------- Compatibility API expected by main.py -------------------- @@ -3123,11 +3235,12 @@ def show_output(self, *args, **kwargs) -> None: n = min(t.size, y.size) t, y = t[:n], y[:n] + label = _first_not_none(kwargs, "label", "output_label", default="Output") + self.curve_out.setData(t, y, connect="finite", skipFiniteCheck=True) if not preserve_view: self.set_full_xrange(t) - label = _first_not_none(kwargs, "label", "output_label", default="Output") context = _first_not_none(kwargs, "output_context", "label_context", default="") title = f"Output: {label}" if not context else f"Output: {label} | {context}" self.plot_out.setTitle(title) @@ -3140,6 +3253,7 @@ def show_output(self, *args, **kwargs) -> None: def update_plots(self, processed: ProcessedTrial, preserve_view: bool = False) -> None: t = np.asarray(processed.time, float) + kept_xrange = self.current_xrange() if preserve_view else None self.show_raw( t, processed.raw_signal, processed.raw_reference, dio=processed.dio, dio_name=processed.dio_name, @@ -3161,4 +3275,6 @@ def update_plots(self, processed: ProcessedTrial, preserve_view: bool = False) - preserve_view=preserve_view, ) self._update_artifact_overlays(t, processed.raw_signal, processed.artifact_regions_sec) + if kept_xrange is not None: + self.set_xrange_all(*kept_xrange) diff --git a/pyBer/main.py b/pyBer/main.py index 47e0afe..18ab29e 100644 --- a/pyBer/main.py +++ b/pyBer/main.py @@ -17,6 +17,69 @@ import sys from typing import Callable, Dict, List, Optional, Tuple + +_DLL_DIR_HANDLES = [] + + +def _bootstrap_windows_conda_runtime() -> None: + if os.name != "nt": + return + + os.environ.setdefault("PYTHONNOUSERSITE", "1") + + try: + import site + user_site = os.path.normcase(os.path.abspath(site.getusersitepackages())) + except Exception: + user_site = "" + + appdata_python = "" + appdata = os.environ.get("APPDATA", "") + if appdata: + appdata_python = os.path.normcase(os.path.abspath(os.path.join(appdata, "Python"))) + + def _is_user_site_path(path: str) -> bool: + if not path: + return False + try: + norm = os.path.normcase(os.path.abspath(path)) + except Exception: + return False + if user_site and 
(norm == user_site or norm.startswith(user_site + os.sep)): + return True + return bool(appdata_python and (norm == appdata_python or norm.startswith(appdata_python + os.sep))) + + sys.path[:] = [path for path in sys.path if not _is_user_site_path(path)] + script_dir = os.path.dirname(os.path.abspath(__file__)) + if script_dir and script_dir not in sys.path: + sys.path.insert(0, script_dir) + + prefix = os.environ.get("CONDA_PREFIX") or sys.prefix + dll_dirs = [ + prefix, + os.path.join(prefix, "Library", "mingw-w64", "bin"), + os.path.join(prefix, "Library", "usr", "bin"), + os.path.join(prefix, "Library", "bin"), + os.path.join(prefix, "Scripts"), + ] + existing = [path for path in dll_dirs if path and os.path.isdir(path)] + + if hasattr(os, "add_dll_directory"): + for path in existing: + try: + _DLL_DIR_HANDLES.append(os.add_dll_directory(path)) + except Exception: + pass + + old_path = os.environ.get("PATH", "") + old_parts = [os.path.normcase(os.path.abspath(p)) for p in old_path.split(os.pathsep) if p] + prepend = [p for p in existing if os.path.normcase(os.path.abspath(p)) not in old_parts] + if prepend: + os.environ["PATH"] = os.pathsep.join(prepend + [old_path]) + + +_bootstrap_windows_conda_runtime() + from PySide6 import QtCore, QtGui, QtWidgets import pyqtgraph as pg from pyqtgraph.dockarea import DockArea, Dock @@ -24,6 +87,7 @@ from analysis_core import ( ExportSelection, + OUTPUT_MODES, PhotometryProcessor, ProcessingParams, LoadedDoricFile, @@ -37,6 +101,7 @@ zscore_median_std, safe_divide, _lowpass_sos, + coerce_time_value, ) from gui_preprocessing import ( FileQueuePanel, @@ -47,7 +112,9 @@ AdvancedOptionsDialog, ) from gui_postprocessing import PostProcessingPanel +from numeric_controls import install_spinbox_scrubbers from styles import ( + apply_app_palette, app_qss, _make_icon, _paint_database, @@ -2257,6 +2324,7 @@ def _export_preprocessing_ui_state_for_config(self) -> Dict[str, object]: "export_selection": self.param_panel.export_selection().to_dict(), "export_channel_names": self.param_panel.export_channel_names(), "export_trigger_names": self.param_panel.export_trigger_names(), + "auto_export_to_source_dir": bool(self.param_panel.auto_export_enabled()), "panel_layout": self._collect_panel_layout_payload(), } @@ -2311,6 +2379,8 @@ def _import_preprocessing_ui_state_from_config(self, ui_state: Dict[str, object] self.param_panel.set_export_channel_names(list(ui_state.get("export_channel_names") or [])) if "export_trigger_names" in ui_state: self.param_panel.set_export_trigger_names(list(ui_state.get("export_trigger_names") or [])) + if "auto_export_to_source_dir" in ui_state: + self.param_panel.set_auto_export_enabled(_to_bool(ui_state.get("auto_export_to_source_dir"), False)) self._update_export_summary_label() panel_layout = ui_state.get("panel_layout") if isinstance(panel_layout, dict): @@ -2585,6 +2655,12 @@ def _restore_settings(self) -> None: self.plots.set_artifact_thresholds_visible(bool(show_thresholds)) except Exception: pass + try: + auto_export = _to_bool(self.settings.value("auto_export_to_source_dir", False), False) + self.param_panel.set_auto_export_enabled(auto_export) + self._update_export_summary_label() + except Exception: + pass try: default_bg = "white" if self._app_theme_mode == "light" else "dark" plot_bg = self.settings.value("pre_plot_background", default_bg, type=str) @@ -2704,6 +2780,10 @@ def _save_settings(self) -> None: self.settings.setValue("artifact_thresholds_visible", bool(self.plots.artifact_thresholds_visible())) except Exception: 
pass + try: + self.settings.setValue("auto_export_to_source_dir", bool(self.param_panel.auto_export_enabled())) + except Exception: + pass try: self.settings.setValue("pre_plot_background", str(self.plots.plot_background_mode())) self.settings.setValue("pre_plot_grid", bool(self.plots.plot_grid_visible())) @@ -2882,6 +2962,10 @@ def _new_preprocessing_project(self) -> None: return self._clear_preprocessing_project_state() self._pre_project_path = None + try: + self.post_tab.reset_for_new_preprocessing_project() + except Exception: + pass self._show_status_message("Started a new preprocessing project.", 5000) def _keyed_regions_to_project( @@ -3083,6 +3167,10 @@ def _load_preprocessing_project_from_path(self, path: str) -> None: self._clear_preprocessing_project_state() self._pre_project_path = path + try: + self.post_tab.reset_for_new_preprocessing_project() + except Exception: + pass self._apply_preprocessing_config_payload(payload.get("preprocessing_config")) session_mapping = payload.get("csv_mapping_session") @@ -3174,7 +3262,6 @@ def _is_csv_time_column(self, value: object) -> bool: return norm in {"time", "t", "timestamp", "times", "timesec", "times", "timems"} or "timestamp" in norm def _parse_csv_float(self, value: object) -> float: - from pyBer.analysis_core import coerce_time_value text = str(value or "").strip() if not text or text.lower() in {"nan", "none", "null", "na"}: return np.nan @@ -4240,6 +4327,7 @@ def _apply_app_theme(self, theme_mode: object, persist: bool = True) -> None: self.act_app_theme_light.blockSignals(False) try: + apply_app_palette(QtWidgets.QApplication.instance(), mode) self.setStyleSheet(app_qss(mode)) except Exception: pass @@ -4761,6 +4849,13 @@ def _mask_arr(arr: Optional[np.ndarray]) -> Optional[np.ndarray]: processed.baseline_sig = _mask_arr(processed.baseline_sig) processed.baseline_ref = _mask_arr(processed.baseline_ref) processed.output = _mask_arr(processed.output) + if hasattr(processed, "outputs") and processed.outputs: + masked_outputs = {} + for label, values in processed.outputs.items(): + masked = _mask_arr(values) + if masked is not None: + masked_outputs[str(label)] = masked + processed.outputs = masked_outputs # Mask triggers too if requested by convention, but here we keep them as-is or NaN them if hasattr(processed, "triggers") and processed.triggers: @@ -5046,7 +5141,7 @@ def _add_manual_region_from_selector(self) -> None: self._manual_regions_by_key[key] = regs start_s, end_s = self._time_window_bounds() self.artifact_panel.set_regions(self._clip_regions_to_window(regs, start_s, end_s)) - self._trigger_preview() + self._trigger_preview(preserve_view=True) def _add_manual_region_from_drag(self, t0: float, t1: float) -> None: if self._box_select_callback: @@ -5073,7 +5168,7 @@ def _clear_manual_regions_current(self) -> None: self._manual_exclude_by_key[key] = [] self._pending_box_region_by_key.pop(key, None) self.artifact_panel.set_regions([]) - self._trigger_preview() + self._trigger_preview(preserve_view=True) def _request_box_select(self, callback: Callable[[float, float], None]) -> None: self._box_select_callback = callback @@ -5132,7 +5227,7 @@ def _assign_pending_box_to_artifact(self) -> None: self._manual_regions_by_key[key] = regs start_s, end_s = self._time_window_bounds() self.artifact_panel.set_regions(self._clip_regions_to_window(regs, start_s, end_s)) - self._trigger_preview() + self._trigger_preview(preserve_view=True) def _assign_pending_box_to_cut(self) -> None: region = self._consume_pending_box_region() @@ -5200,7 
+5295,7 @@ def _contains(target: Tuple[float, float], arr: List[Tuple[float, float]]) -> bo prev_ignore = self._manual_exclude_by_key.get(key, []) self._manual_regions_by_key[key] = self._merge_regions_with_window(prev_manual, manual_add, start_s, end_s) self._manual_exclude_by_key[key] = self._merge_regions_with_window(prev_ignore, manual_ignore, start_s, end_s) - self._trigger_preview() + self._trigger_preview(preserve_view=True) def _toggle_artifacts_panel(self) -> None: if self._use_pg_dockarea_pre_layout: @@ -5310,6 +5405,56 @@ def _remember_export_dir(self, out_dir: str, origin_dir: str) -> None: except Exception: pass + def _process_trial_for_export( + self, + trial: LoadedTrial, + params: ProcessingParams, + export_selection: ExportSelection, + manual_regions_sec: List[Tuple[float, float]], + manual_exclude_regions_sec: List[Tuple[float, float]], + ) -> ProcessedTrial: + modes: List[str] = [] + if export_selection.output: + for mode in export_selection.output_modes or [params.output_mode]: + mode = str(mode or "").strip() + if mode in OUTPUT_MODES and mode not in modes: + modes.append(mode) + if not modes: + modes = [params.output_mode if params.output_mode in OUTPUT_MODES else OUTPUT_MODES[0]] + + primary = params.output_mode if params.output_mode in modes else modes[0] + ordered_modes = [primary] + [mode for mode in modes if mode != primary] + base_processed: Optional[ProcessedTrial] = None + outputs: Dict[str, np.ndarray] = {} + + for mode in ordered_modes: + mode_params = ProcessingParams.from_dict(params.to_dict()) + mode_params.output_mode = mode + processed = self.processor.process_trial( + trial=trial, + params=mode_params, + manual_regions_sec=manual_regions_sec, + manual_exclude_regions_sec=manual_exclude_regions_sec, + preview_mode=False, + ) + if base_processed is None: + base_processed = processed + if processed.output is not None: + outputs[str(processed.output_label or mode)] = np.asarray(processed.output, float) + + if base_processed is None: + fallback_params = ProcessingParams.from_dict(params.to_dict()) + base_processed = self.processor.process_trial( + trial=trial, + params=fallback_params, + manual_regions_sec=manual_regions_sec, + manual_exclude_regions_sec=manual_exclude_regions_sec, + preview_mode=False, + ) + if export_selection.output: + base_processed.outputs = outputs + return base_processed + def _export_selected_or_all(self) -> None: selected = self._selected_paths() if not selected: @@ -5317,28 +5462,41 @@ def _export_selected_or_all(self) -> None: if not selected: return + auto_export = bool(self.param_panel.auto_export_enabled()) origin_dir = self._export_origin_dir(selected) - start_dir = self._export_start_dir(selected) - out_dir = QtWidgets.QFileDialog.getExistingDirectory(self, "Select export folder", start_dir) - if not out_dir: - return - self._remember_export_dir(out_dir, origin_dir) + out_dir = "" + if not auto_export: + start_dir = self._export_start_dir(selected) + out_dir = QtWidgets.QFileDialog.getExistingDirectory(self, "Select export folder", start_dir) + if not out_dir: + return + self._remember_export_dir(out_dir, origin_dir) params = self.param_panel.get_params() export_selection = self.param_panel.export_selection() - export_channel_names = self.param_panel.export_channel_names() + export_channel_names = [] if auto_export else self.param_panel.export_channel_names() export_trigger_names = self.param_panel.export_trigger_names() - # Process/export each selected file, for the currently selected channel. 
+ # Process/export each selected file. Auto export writes beside each source + # file and intentionally exports every analog channel with the same params. n_total = 0 + exported_dirs = set() for path in selected: doric = self._loaded_files.get(path) if not doric: continue - channels = [name for name in export_channel_names if name in doric.channels] - if not channels: - fallback = self._current_channel if (self._current_channel in doric.channels) else (doric.channels[0] if doric.channels else None) - channels = [fallback] if fallback else [] + if auto_export: + channels = list(doric.channels) + else: + channels = [name for name in export_channel_names if name in doric.channels] + if not channels: + fallback = self._current_channel if (self._current_channel in doric.channels) else (doric.channels[0] if doric.channels else None) + channels = [fallback] if fallback else [] + path_out_dir = out_dir + if auto_export: + path_out_dir = os.path.dirname(path) + if not path_out_dir or not os.path.isdir(path_out_dir): + path_out_dir = origin_dir if origin_dir and os.path.isdir(origin_dir) else os.getcwd() dio_names = [name for name in export_trigger_names if name in doric.trigger_by_name] if not export_selection.dio: dio_names = [None] @@ -5373,10 +5531,11 @@ def _export_one(proc: ProcessedTrial, suffix: str = "") -> None: stem = safe_stem_from_metadata(path, ch, meta) if suffix: stem = f"{stem}_{suffix}" - csv_path = os.path.join(out_dir, f"{stem}.csv") - h5_path = os.path.join(out_dir, f"{stem}.h5") + csv_path = os.path.join(path_out_dir, f"{stem}.csv") + h5_path = os.path.join(path_out_dir, f"{stem}.h5") export_processed_csv(csv_path, proc, metadata=meta, selection=export_selection) export_processed_h5(h5_path, proc, metadata=meta, selection=export_selection) + exported_dirs.add(path_out_dir) n_total += 1 try: @@ -5388,21 +5547,21 @@ def _export_one(proc: ProcessedTrial, suffix: str = "") -> None: if sec_trial is None: continue sec_params = ProcessingParams.from_dict(sec.get("params", {})) if isinstance(sec.get("params"), dict) else params - processed = self.processor.process_trial( + processed = self._process_trial_for_export( trial=sec_trial, params=sec_params, + export_selection=export_selection, manual_regions_sec=manual, manual_exclude_regions_sec=manual_exclude, - preview_mode=False, ) _export_one(processed, suffix=f"sec{i}_{s0:.2f}_{s1:.2f}") else: - processed = self.processor.process_trial( + processed = self._process_trial_for_export( trial=trial, params=params, + export_selection=export_selection, manual_regions_sec=manual, manual_exclude_regions_sec=manual_exclude, - preview_mode=False, ) _export_one(processed) except Exception as e: @@ -5411,7 +5570,16 @@ def _export_one(proc: ProcessedTrial, suffix: str = "") -> None: "Export error", f"Failed export:\n{path} [{ch}] [{primary_trigger or 'no DIO'}]\n\n{e}", ) - self._show_status_message(f"Export complete: {n_total} recording(s) written to {out_dir}") + if auto_export: + if len(exported_dirs) == 1: + target = next(iter(exported_dirs)) + elif exported_dirs: + target = f"{len(exported_dirs)} source folders" + else: + target = "source folders" + else: + target = out_dir + self._show_status_message(f"Export complete: {n_total} recording(s) written to {target}") # optional: update post tab list by loading exported results? 
(user can load later) @@ -5897,6 +6065,8 @@ def main() -> None: pg.setConfigOptions(antialias=False) smoke_test = str(os.environ.get("PYBER_SMOKE_TEST", "")).strip().lower() in {"1", "true", "yes", "on"} app = QtWidgets.QApplication([]) + apply_app_palette(app, "dark") + spinbox_scrubber = install_spinbox_scrubbers(app) icon_path = _pyber_icon_path() try: if os.path.isfile(icon_path): @@ -5917,6 +6087,7 @@ def main() -> None: except Exception: splash = None w = MainWindow() + spinbox_scrubber.scan(w) if smoke_test: try: diff --git a/pyBer/numeric_controls.py b/pyBer/numeric_controls.py new file mode 100644 index 0000000..1321445 --- /dev/null +++ b/pyBer/numeric_controls.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import Optional + +from PySide6 import QtCore, QtGui, QtWidgets + + +def _event_global_pos(event: QtCore.QEvent) -> QtCore.QPoint: + try: + return event.globalPosition().toPoint() # Qt 6 + except Exception: + try: + return event.globalPos() # Qt 5 compatibility for older shims + except Exception: + return QtCore.QPoint() + + +def _event_local_pos(event: QtCore.QEvent) -> QtCore.QPoint: + try: + return event.position().toPoint() # Qt 6 + except Exception: + try: + return event.pos() + except Exception: + return QtCore.QPoint() + + +class SpinBoxScrubber(QtCore.QObject): + """Turn spin boxes into arrowless, draggable numeric controls. + + Users can still click into the field and type exact values. Dragging left/right + changes the value using the spin box's native stepping, so existing signal + wiring and validation continue to work. + """ + + _CONFIGURED_PROP = "_pyber_spin_scrubber_configured" + + def __init__(self, parent: Optional[QtCore.QObject] = None) -> None: + super().__init__(parent) + self._press_spin: Optional[QtWidgets.QAbstractSpinBox] = None + self._press_pos = QtCore.QPoint() + self._press_local_pos = QtCore.QPoint() + self._last_steps = 0 + self._dragging = False + self._override_cursor = False + + def scan(self, root: QtCore.QObject) -> None: + if isinstance(root, QtWidgets.QAbstractSpinBox): + self._configure_spinbox(root) + if isinstance(root, QtWidgets.QWidget): + for spin in root.findChildren(QtWidgets.QAbstractSpinBox): + self._configure_spinbox(spin) + + def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool: + etype = event.type() + if etype == QtCore.QEvent.Type.Show: + self._configure_object_tree(obj) + + spin = self._spinbox_for_object(obj) + if spin is None: + return False + if etype in ( + QtCore.QEvent.Type.MouseButtonPress, + QtCore.QEvent.Type.Wheel, + QtCore.QEvent.Type.FocusIn, + QtCore.QEvent.Type.KeyPress, + QtCore.QEvent.Type.Show, + ): + self._configure_spinbox(spin) + elif not bool(spin.property(self._CONFIGURED_PROP)): + return False + + if etype == QtCore.QEvent.Type.MouseButtonPress: + if not self._left_button_event(event) or not spin.isEnabled(): + return False + self._press_spin = spin + self._press_pos = _event_global_pos(event) + self._press_local_pos = _event_local_pos(event) + self._last_steps = 0 + self._dragging = False + return False + + if etype == QtCore.QEvent.Type.MouseMove and self._press_spin is spin: + if not self._left_button_held(event): + return False + dx_global = _event_global_pos(event).x() - self._press_pos.x() + dx_local = _event_local_pos(event).x() - self._press_local_pos.x() + dx = dx_global if abs(dx_global) >= abs(dx_local) else dx_local + if not self._dragging: + if abs(dx) < 5: + return False + self._dragging = True + 
spin.setFocus(QtCore.Qt.FocusReason.MouseFocusReason) + self._set_override_cursor(QtCore.Qt.CursorShape.SizeHorCursor) + + steps = self._steps_from_drag(dx, event) + delta = steps - self._last_steps + if delta: + spin.stepBy(delta) + self._last_steps = steps + return True + + if etype == QtCore.QEvent.Type.MouseButtonRelease and self._press_spin is spin: + was_dragging = self._dragging + self._press_spin = None + self._last_steps = 0 + self._dragging = False + self._restore_override_cursor() + return was_dragging + + if etype == QtCore.QEvent.Type.Leave and self._press_spin is spin and not self._left_button_held(event): + self._press_spin = None + self._last_steps = 0 + self._dragging = False + self._restore_override_cursor() + + return False + + def _configure_object_tree(self, obj: QtCore.QObject) -> None: + spin = self._spinbox_for_object(obj) + if spin is not None: + self._configure_spinbox(spin) + return + if isinstance(obj, QtWidgets.QWidget): + for child in obj.findChildren(QtWidgets.QAbstractSpinBox): + self._configure_spinbox(child) + + def _configure_spinbox(self, spin: QtWidgets.QAbstractSpinBox) -> None: + if bool(spin.property(self._CONFIGURED_PROP)): + return + spin.setProperty(self._CONFIGURED_PROP, True) + spin.setButtonSymbols(QtWidgets.QAbstractSpinBox.ButtonSymbols.NoButtons) + spin.setKeyboardTracking(False) + spin.setAccelerated(True) + spin.setCursor(QtCore.Qt.CursorShape.SizeHorCursor) + line_edit = spin.lineEdit() + if line_edit is not None: + line_edit.setCursor(QtCore.Qt.CursorShape.SizeHorCursor) + line_edit.setTextMargins(1, 0, 1, 0) + if isinstance(spin, QtWidgets.QDoubleSpinBox): + try: + spin.setStepType(QtWidgets.QAbstractSpinBox.StepType.AdaptiveDecimalStepType) + except Exception: + pass + tip = spin.toolTip().strip() + scrub_tip = "Drag left/right to adjust. Type for an exact value. Shift = faster, Ctrl = finer." 
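+        # Idempotence guard: never stack the hint onto a tooltip that already
+        # contains it.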
+ if scrub_tip not in tip: + spin.setToolTip(f"{tip}\n{scrub_tip}" if tip else scrub_tip) + + def _spinbox_for_object(self, obj: QtCore.QObject) -> Optional[QtWidgets.QAbstractSpinBox]: + if isinstance(obj, QtWidgets.QAbstractSpinBox): + return obj + parent = obj.parent() if isinstance(obj, QtCore.QObject) else None + while parent is not None: + if isinstance(parent, QtWidgets.QAbstractSpinBox): + return parent + parent = parent.parent() + return None + + def _steps_from_drag(self, dx: int, event: QtCore.QEvent) -> int: + pixels_per_step = 12.0 + try: + mods = event.modifiers() + except Exception: + mods = QtCore.Qt.KeyboardModifier.NoModifier + if mods & QtCore.Qt.KeyboardModifier.ShiftModifier: + pixels_per_step = 5.0 + elif mods & QtCore.Qt.KeyboardModifier.ControlModifier: + pixels_per_step = 28.0 + return int(dx / pixels_per_step) + + def _left_button_event(self, event: QtCore.QEvent) -> bool: + try: + return event.button() == QtCore.Qt.MouseButton.LeftButton + except Exception: + return False + + def _left_button_held(self, event: QtCore.QEvent) -> bool: + try: + return bool(event.buttons() & QtCore.Qt.MouseButton.LeftButton) + except Exception: + return False + + def _set_override_cursor(self, cursor: QtCore.Qt.CursorShape) -> None: + if self._override_cursor: + return + QtWidgets.QApplication.setOverrideCursor(QtGui.QCursor(cursor)) + self._override_cursor = True + + def _restore_override_cursor(self) -> None: + if not self._override_cursor: + return + try: + QtWidgets.QApplication.restoreOverrideCursor() + except Exception: + pass + self._override_cursor = False + + +def install_spinbox_scrubbers(app: QtWidgets.QApplication) -> SpinBoxScrubber: + existing = getattr(app, "_pyber_spinbox_scrubber", None) + if isinstance(existing, SpinBoxScrubber): + return existing + scrubber = SpinBoxScrubber(app) + app.installEventFilter(scrubber) + setattr(app, "_pyber_spinbox_scrubber", scrubber) + return scrubber diff --git a/pyBer/styles.py b/pyBer/styles.py index 16ac299..430cd73 100644 --- a/pyBer/styles.py +++ b/pyBer/styles.py @@ -180,6 +180,30 @@ def _paint_paw(p, r, c): p.drawEllipse(QtCore.QPoint(cx + dx, cy - rh // 4), max(2, rw // 10), max(2, rh // 8)) +def _paint_temporal(p, r, c): + """Temporal modeling icon — sine wave over a grid with a regression line.""" + from PySide6 import QtCore, QtGui + import math + p.setPen(_pen(c, 1.6)); p.setBrush(QtCore.Qt.BrushStyle.NoBrush) + # Horizontal axis + cy = r.top() + int(r.height() * 0.6) + p.drawLine(r.left(), cy, r.right(), cy) + # Sine-like curve + path = QtGui.QPainterPath() + path.moveTo(r.left(), cy) + w = r.width() + for i in range(w + 1): + x = r.left() + i + y = cy - math.sin(i / w * 2.5 * math.pi) * (r.height() * 0.35) + path.lineTo(x, y) + p.setPen(_pen(c, 2.0)) + p.drawPath(path) + # Regression trend line (dashed) + p.setPen(_pen(QtGui.QColor(c).lighter(140), 1.4)) + p.drawLine(r.left() + 2, cy + int(r.height() * 0.15), + r.right() - 2, cy - int(r.height() * 0.25)) + + APP_QSS = r""" @@ -683,3 +707,58 @@ def app_qss(theme_mode: object) -> str: if mode in {"light", "white", "l", "w"}: return APP_QSS_LIGHT return APP_QSS + + +def apply_app_palette(app, theme_mode: object) -> None: + """Force a consistent Fusion palette before applying app QSS. + + Some Windows/native Qt styles partially ignore dark QSS for menus, popup + views, disabled controls, or newly-created widgets. Fusion plus an explicit + palette keeps the app theme independent from the host OS theme. 
+ """ + from PySide6 import QtGui, QtWidgets + + if app is None: + return + + mode = str(theme_mode or "").strip().lower() + light = mode in {"light", "white", "l", "w"} + try: + QtWidgets.QApplication.setStyle("Fusion") + except Exception: + pass + + palette = QtGui.QPalette() + if light: + colors = { + QtGui.QPalette.ColorRole.Window: "#f6f8fc", + QtGui.QPalette.ColorRole.WindowText: "#1f2a37", + QtGui.QPalette.ColorRole.Base: "#ffffff", + QtGui.QPalette.ColorRole.AlternateBase: "#edf1f7", + QtGui.QPalette.ColorRole.ToolTipBase: "#ffffff", + QtGui.QPalette.ColorRole.ToolTipText: "#1f2a37", + QtGui.QPalette.ColorRole.Text: "#1f2a37", + QtGui.QPalette.ColorRole.Button: "#e7ecf4", + QtGui.QPalette.ColorRole.ButtonText: "#1f2a37", + QtGui.QPalette.ColorRole.BrightText: "#ffffff", + QtGui.QPalette.ColorRole.Highlight: "#378ef0", + QtGui.QPalette.ColorRole.HighlightedText: "#ffffff", + } + else: + colors = { + QtGui.QPalette.ColorRole.Window: "#1f2229", + QtGui.QPalette.ColorRole.WindowText: "#f3f5f8", + QtGui.QPalette.ColorRole.Base: "#1b2029", + QtGui.QPalette.ColorRole.AlternateBase: "#262b35", + QtGui.QPalette.ColorRole.ToolTipBase: "#262b35", + QtGui.QPalette.ColorRole.ToolTipText: "#f3f5f8", + QtGui.QPalette.ColorRole.Text: "#f3f5f8", + QtGui.QPalette.ColorRole.Button: "#2b303b", + QtGui.QPalette.ColorRole.ButtonText: "#f3f5f8", + QtGui.QPalette.ColorRole.BrightText: "#ffffff", + QtGui.QPalette.ColorRole.Highlight: "#378ef0", + QtGui.QPalette.ColorRole.HighlightedText: "#ffffff", + } + for role, color in colors.items(): + palette.setColor(role, QtGui.QColor(color)) + app.setPalette(palette) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py new file mode 100644 index 0000000..bc1c0ba --- /dev/null +++ b/pyBer/temporal_modeling.py @@ -0,0 +1,1446 @@ +# temporal_modeling.py +""" +Temporal Modeling module for pyBer post-processing. + +Provides two backends: + 1. ContinuousGLM – design-matrix approach with temporal basis functions + (raised-cosine, B-spline, FIR) and ridge/lasso regression. + 2. TrialFLMM – Functional Linear Mixed Model via the R *fastFMM* package + (Loewinger et al., 2024), called through rpy2. + +The TemporalModelingWidget is a PySide6 panel that is embedded in the +PostProcessingPanel side-rail/dock system. +""" +from __future__ import annotations + +import logging +import os +import traceback +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from PySide6 import QtCore, QtWidgets, QtGui +import pyqtgraph as pg + +_LOG = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional dependency detection +# --------------------------------------------------------------------------- + +def _check_rpy2() -> bool: + """Return True if rpy2 can talk to R and fastFMM is loadable.""" + try: + _init_r() + return True + except Exception: + return False + + +def _init_r(): + """Initialise rpy2 + R + fastFMM. Idempotent.""" + global _R_READY + if _R_READY: + return + # On Windows, R is often not on PATH. Try the standard install location. 
+    r_home = os.environ.get("R_HOME", "")
+    if not r_home:
+        candidate = "C:/Program Files/R"
+        if os.path.isdir(candidate):
+            subs = sorted(os.listdir(candidate), reverse=True)
+            if subs:
+                r_home = os.path.join(candidate, subs[0])
+                os.environ["R_HOME"] = r_home
+    if r_home:
+        bin_x64 = os.path.join(r_home, "bin", "x64")
+        if os.path.isdir(bin_x64):
+            try:
+                os.add_dll_directory(bin_x64)
+            except (OSError, AttributeError):
+                pass
+            # Ensure PATH includes R so child processes find R.dll
+            cur_path = os.environ.get("PATH", "")
+            if bin_x64 not in cur_path:
+                os.environ["PATH"] = bin_x64 + os.pathsep + cur_path
+
+    r_libs_user = os.environ.get("R_LIBS_USER", "")
+    if not r_libs_user:
+        candidate_lib = os.path.expanduser("~/R/win-library")
+        if os.path.isdir(candidate_lib):
+            subs = sorted(os.listdir(candidate_lib), reverse=True)
+            if subs:
+                os.environ["R_LIBS_USER"] = os.path.join(candidate_lib, subs[0])
+
+    import rpy2.robjects as ro  # noqa: F811
+    from rpy2.robjects import r as R  # noqa: F811
+
+    # Set .libPaths so R can find user-installed packages
+    user_lib = os.environ.get("R_LIBS_USER", "")
+    if user_lib and os.path.isdir(user_lib):
+        sys_lib = os.path.join(r_home, "library") if r_home else ""
+        paths = [p for p in (user_lib, sys_lib) if p and os.path.isdir(p)]
+        R(f'.libPaths(c({", ".join(repr(p.replace(chr(92), "/")) for p in paths)}))')
+
+    R("library(fastFMM)")
+    _R_READY = True
+
+
+_R_READY = False
+
+
+# ============================================================================
+# 1. Continuous GLM backend
+# ============================================================================
+
+def _raised_cosine_basis(n_basis: int, n_samples: int,
+                         peak_range: Tuple[float, float] = (0.0, 1.0)) -> np.ndarray:
+    """Create a raised-cosine basis set (n_samples x n_basis)."""
+    peaks = np.linspace(peak_range[0], peak_range[1], n_basis)
+    width = (peak_range[1] - peak_range[0]) / max(n_basis - 1, 1) * 1.5
+    t = np.linspace(0, 1, n_samples)
+    B = np.zeros((n_samples, n_basis))
+    for i, pk in enumerate(peaks):
+        phi = np.clip((t - pk) / width * np.pi, -np.pi, np.pi)
+        B[:, i] = 0.5 * (1 + np.cos(phi))
+    return B
+
+
+def _bspline_basis(n_basis: int, n_samples: int, degree: int = 3) -> np.ndarray:
+    """Create a B-spline basis (n_samples x n_basis) using scipy."""
+    from scipy.interpolate import BSpline
+    t_eval = np.linspace(0, 1, n_samples)
+    # A clamped spline with (degree + 1) repeated boundary knots needs
+    # n_basis - degree - 1 interior knots so that
+    # len(knots) - degree - 1 == n_basis; otherwise scipy's BSpline rejects
+    # the knot/coefficient combination. Clamp the degree for tiny bases.
+    degree = min(degree, max(n_basis - 1, 1))
+    n_internal = max(n_basis - degree - 1, 0)
+    internal_knots = np.linspace(0, 1, n_internal + 2)[1:-1]
+    knots = np.concatenate([np.zeros(degree + 1), internal_knots, np.ones(degree + 1)])
+    B = np.zeros((n_samples, n_basis))
+    for i in range(n_basis):
+        coeffs = np.zeros(n_basis)
+        coeffs[i] = 1.0
+        spl = BSpline(knots, coeffs, degree, extrapolate=False)
+        vals = spl(t_eval)
+        vals[np.isnan(vals)] = 0.0
+        B[:, i] = vals
+    return B
+
+
+def _fir_basis(n_basis: int, n_samples: int) -> np.ndarray:
+    """FIR (identity / boxcar) basis — one column per time bin."""
+    step = max(1, n_samples // n_basis)
+    B = np.zeros((n_samples, n_basis))
+    for i in range(n_basis):
+        lo = i * step
+        hi = min(lo + step, n_samples)
+        B[lo:hi, i] = 1.0
+    return B
+
+
+@dataclass
+class GLMResult:
+    """Result container for a continuous GLM fit."""
+    predictor_names: List[str]
+    kernels: Dict[str, np.ndarray]  # predictor -> (n_kernel_samples,)
+    kernel_tvec: np.ndarray  # time vector for kernel x-axis
+    y_pred: np.ndarray  # predicted trace
+    y_actual: np.ndarray  # actual trace
+    residuals: np.ndarray
+    r2: float
+    coefficients: np.ndarray  # raw beta vector
+    design_matrix: np.ndarray
+
+
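+# Basis usage sketch (shapes only; the numbers are illustrative): a 4 s kernel
+# window sampled at 25 Hz gives kernel_len = 100, so
+#
+#     B = _raised_cosine_basis(n_basis=8, n_samples=100)  # B.shape == (100, 8)
+#     kernel = B @ w                                      # w.shape == (8,)
+#
+# Any estimated kernel is therefore a smooth weighted sum of the basis
+# columns, which is exactly what ContinuousGLM.fit() reconstructs per
+# predictor from its fitted betas.
+
+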
+class ContinuousGLM: + """Build a design matrix from event times and fit a linear model.""" + + BASIS_TYPES = ("raised_cosine", "bspline", "fir") + REGULARIZATION = ("ridge", "lasso", "ols") + + def __init__(self): + self._result: Optional[GLMResult] = None + + @staticmethod + def build_design_matrix( + time: np.ndarray, + predictors: Dict[str, np.ndarray], + kernel_window: Tuple[float, float], + n_basis: int = 8, + basis_type: str = "raised_cosine", + ) -> Tuple[np.ndarray, List[str], int]: + """ + Build a (T x P) design matrix from event times. + + Parameters + ---------- + time : (T,) array — continuous time vector + predictors : dict mapping predictor name -> 1-D array of event times + kernel_window : (pre, post) in seconds relative to event + n_basis : number of temporal basis functions + basis_type : 'raised_cosine', 'bspline', or 'fir' + + Returns + ------- + X : (T, n_predictors * n_basis) design matrix + col_names : column labels + n_basis : basis count (for later kernel extraction) + """ + dt = np.median(np.diff(time)) + pre_samp = int(round(abs(kernel_window[0]) / dt)) + post_samp = int(round(abs(kernel_window[1]) / dt)) + kernel_len = pre_samp + post_samp + + if basis_type == "bspline": + B = _bspline_basis(n_basis, kernel_len) + elif basis_type == "fir": + B = _fir_basis(n_basis, kernel_len) + else: + B = _raised_cosine_basis(n_basis, kernel_len) + + T = len(time) + col_names: List[str] = [] + X_parts: List[np.ndarray] = [] + + for pred_name, ev_times in predictors.items(): + ev_times = np.asarray(ev_times, float) + ev_times = ev_times[np.isfinite(ev_times)] + + # Convert event times to sample indices + ev_idx = np.searchsorted(time, ev_times) + ev_idx = ev_idx[(ev_idx >= 0) & (ev_idx < T)] + + # Build impulse vector + impulse = np.zeros(T, float) + for idx in ev_idx: + impulse[idx] = 1.0 + + # Convolve impulse with each basis function + part = np.zeros((T, n_basis), float) + for b in range(n_basis): + # Pad the basis to align with pre_samp offset + kernel = np.zeros(kernel_len) + kernel[:] = B[:, b] + conv = np.convolve(impulse, kernel, mode="full")[:T] + # Shift so that the kernel starts at -pre_samp + if pre_samp > 0: + part[:, b] = np.roll(conv, -pre_samp) + part[-pre_samp:, b] = 0.0 + else: + part[:, b] = conv + + X_parts.append(part) + for b in range(n_basis): + col_names.append(f"{pred_name}_b{b}") + + X = np.hstack(X_parts) if X_parts else np.zeros((T, 0)) + # Add intercept + X = np.column_stack([np.ones(T), X]) + col_names.insert(0, "intercept") + return X, col_names, n_basis + + def fit( + self, + time: np.ndarray, + signal: np.ndarray, + predictors: Dict[str, np.ndarray], + kernel_window: Tuple[float, float] = (-1.0, 3.0), + n_basis: int = 8, + basis_type: str = "raised_cosine", + regularization: str = "ridge", + alpha: float = 1.0, + ) -> GLMResult: + """Fit the GLM and return the result.""" + time = np.asarray(time, float) + signal = np.asarray(signal, float) + + X, col_names, n_b = self.build_design_matrix( + time, predictors, kernel_window, n_basis, basis_type, + ) + + # Mask NaN samples + valid = np.isfinite(signal) + Xv = X[valid] + yv = signal[valid] + + # Fit + if regularization == "lasso": + try: + from sklearn.linear_model import Lasso as _Lasso + model = _Lasso(alpha=alpha, max_iter=5000, fit_intercept=False) + model.fit(Xv, yv) + beta = model.coef_ + except ImportError: + _LOG.warning("sklearn not available; falling back to ridge") + beta = self._ridge_fit(Xv, yv, alpha) + elif regularization == "ridge": + beta = self._ridge_fit(Xv, yv, alpha) + else: # 
ols + beta, *_ = np.linalg.lstsq(Xv, yv, rcond=None) + + y_pred = X @ beta + residuals = signal - y_pred + ss_res = np.nansum(residuals ** 2) + ss_tot = np.nansum((signal - np.nanmean(signal)) ** 2) + r2 = 1.0 - ss_res / max(ss_tot, 1e-12) + + # Extract kernels + dt = np.median(np.diff(time)) + pre_samp = int(round(abs(kernel_window[0]) / dt)) + post_samp = int(round(abs(kernel_window[1]) / dt)) + kernel_len = pre_samp + post_samp + + if basis_type == "bspline": + B = _bspline_basis(n_b, kernel_len) + elif basis_type == "fir": + B = _fir_basis(n_b, kernel_len) + else: + B = _raised_cosine_basis(n_b, kernel_len) + + kernel_tvec = np.linspace(kernel_window[0], kernel_window[1], kernel_len) + kernels: Dict[str, np.ndarray] = {} + idx = 1 # skip intercept + for pred_name in predictors: + w = beta[idx:idx + n_b] + kernels[pred_name] = B @ w + idx += n_b + + self._result = GLMResult( + predictor_names=list(predictors.keys()), + kernels=kernels, + kernel_tvec=kernel_tvec, + y_pred=y_pred, + y_actual=signal, + residuals=residuals, + r2=r2, + coefficients=beta, + design_matrix=X, + ) + return self._result + + @staticmethod + def _ridge_fit(X: np.ndarray, y: np.ndarray, alpha: float) -> np.ndarray: + n = X.shape[1] + I = np.eye(n) + I[0, 0] = 0 # don't regularize intercept + return np.linalg.solve(X.T @ X + alpha * I, X.T @ y) + + +# ============================================================================ +# 2. Trial-level FLMM backend (via R fastFMM) +# ============================================================================ + +@dataclass +class FLMMResult: + """Result container for a trial-level FLMM fit.""" + tvec: np.ndarray # peri-event time vector + coefficients: Dict[str, np.ndarray] # term -> (n_time,) coefficient curve + ci_lower: Dict[str, np.ndarray] # term -> lower 95 % CI + ci_upper: Dict[str, np.ndarray] # term -> upper 95 % CI + joint_ci_lower: Dict[str, np.ndarray] + joint_ci_upper: Dict[str, np.ndarray] + residuals: Optional[np.ndarray] = None + aic: Optional[float] = None + summary_text: str = "" + + +class TrialFLMM: + """ + Functional Linear Mixed Model using R's fastFMM package. + + Wraps fastFMM::fui() via rpy2. The user provides a trial-level data + matrix (n_trials x n_timepoints) plus a design dataframe (n_trials rows) + with fixed/random predictors. The backend constructs the long-form + data and calls fui(). + """ + + def __init__(self): + self._result: Optional[FLMMResult] = None + self._available: Optional[bool] = None + + @property + def available(self) -> bool: + if self._available is None: + self._available = _check_rpy2() + return self._available + + def fit( + self, + mat: np.ndarray, + tvec: np.ndarray, + design: Dict[str, np.ndarray], + formula_fixed: str = "Y.obs ~ group", + random_effects: str = "~1", + group_var: str = "subject", + parallel: bool = False, + nknots_min: Optional[int] = None, + num_boots: int = 0, + ) -> FLMMResult: + """ + Fit a functional LMM via fastFMM::fui(). + + Parameters + ---------- + mat : (n_trials, n_timepoints) — the Y matrix (z-scored PSTH rows) + tvec : (n_timepoints,) — peri-event time + design : dict of predictor_name -> (n_trials,) arrays + Must include the grouping variable. 
+ formula_fixed : R formula string for fixed effects + random_effects : R formula string for random effects + group_var : column name for the grouping/random-effects variable + parallel : whether to use parallelisation in fui() + nknots_min : minimum number of knots for penalised splines + num_boots : number of bootstrap iterations (0 = analytic only) + + Returns + ------- + FLMMResult + """ + _init_r() + import rpy2.robjects as ro + from rpy2.robjects import r as R, pandas2ri, numpy2ri + from rpy2.robjects.packages import importr + + n_trials, n_time = mat.shape + + # Build long-form dataframe in R + # Columns: Y.obs, .index (timepoint), .obs (trial id), + design vars + Y_long = mat.ravel(order="C") # trial-major + index_long = np.tile(np.arange(n_time), n_trials) + obs_long = np.repeat(np.arange(n_trials), n_time) + + r_df_vars = { + "Y.obs": ro.FloatVector(Y_long), + ".index": ro.IntVector(index_long.astype(int)), + ".obs": ro.IntVector(obs_long.astype(int)), + } + + for col_name, col_vals in design.items(): + col_vals = np.asarray(col_vals) + repeated = np.repeat(col_vals, n_time) + if np.issubdtype(col_vals.dtype, np.floating): + r_df_vars[col_name] = ro.FloatVector(repeated) + elif np.issubdtype(col_vals.dtype, np.integer): + r_df_vars[col_name] = ro.IntVector(repeated.astype(int)) + else: + r_df_vars[col_name] = ro.StrVector(repeated.astype(str)) + + r_df = ro.DataFrame(r_df_vars) + + # Call fui() + fastFMM = importr("fastFMM") + kwargs = { + "formula": ro.Formula(formula_fixed), + "data": r_df, + "id": ro.StrVector([group_var]), + "G": ro.Formula(random_effects), + "parallel": ro.BoolVector([parallel]), + } + if nknots_min is not None: + kwargs["nknots_min"] = ro.IntVector([nknots_min]) + if num_boots > 0: + kwargs["num_boots"] = ro.IntVector([num_boots]) + + _LOG.info("Calling fastFMM::fui() with formula=%s, %d trials, %d timepoints", + formula_fixed, n_trials, n_time) + + fui_result = fastFMM.fui(**kwargs) + + # Parse the result. 
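+        # (Defensive parsing on purpose: the element names below reflect the
+        # fastFMM return contract as used here. A malformed betaHat block
+        # raises the RuntimeError further down, while missing joint CIs and
+        # AIC fall back to pointwise CIs / None instead of aborting.)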
+ # fui() returns a list with elements: + # $betaHat — matrix (n_terms x n_time) of coefficient estimates + # $betaHat.LB / $betaHat.UB — pointwise 95% CI + # $betaHat.LB.joint / $betaHat.UB.joint — joint 95% CI + # $AIC + coefficients: Dict[str, np.ndarray] = {} + ci_lower: Dict[str, np.ndarray] = {} + ci_upper: Dict[str, np.ndarray] = {} + joint_ci_lower: Dict[str, np.ndarray] = {} + joint_ci_upper: Dict[str, np.ndarray] = {} + + try: + beta_hat = np.array(R('as.matrix')(fui_result.rx2("betaHat"))) + beta_lb = np.array(R('as.matrix')(fui_result.rx2("betaHat.LB"))) + beta_ub = np.array(R('as.matrix')(fui_result.rx2("betaHat.UB"))) + + # Term names from rownames + try: + term_names = list(R('rownames')(fui_result.rx2("betaHat"))) + except Exception: + term_names = [f"term_{i}" for i in range(beta_hat.shape[0])] + + for i, name in enumerate(term_names): + coefficients[name] = beta_hat[i, :] + ci_lower[name] = beta_lb[i, :] + ci_upper[name] = beta_ub[i, :] + + # Joint CIs (may not always be present) + try: + jlb = np.array(R('as.matrix')(fui_result.rx2("betaHat.LB.joint"))) + jub = np.array(R('as.matrix')(fui_result.rx2("betaHat.UB.joint"))) + for i, name in enumerate(term_names): + joint_ci_lower[name] = jlb[i, :] + joint_ci_upper[name] = jub[i, :] + except Exception: + joint_ci_lower = {k: v.copy() for k, v in ci_lower.items()} + joint_ci_upper = {k: v.copy() for k, v in ci_upper.items()} + + aic_val = None + try: + aic_val = float(np.array(fui_result.rx2("AIC"))[0]) + except Exception: + pass + + summary_parts = [f"FLMM fit: {len(term_names)} terms, {n_trials} trials, {n_time} timepoints"] + if aic_val is not None: + summary_parts.append(f"AIC = {aic_val:.1f}") + for name in term_names: + summary_parts.append(f" {name}: mean coef = {np.nanmean(coefficients[name]):.4f}") + summary_text = "\n".join(summary_parts) + + except Exception as exc: + _LOG.error("Failed to parse fui() result: %s", exc) + raise RuntimeError(f"fastFMM::fui() result parsing failed: {exc}") from exc + + self._result = FLMMResult( + tvec=tvec, + coefficients=coefficients, + ci_lower=ci_lower, + ci_upper=ci_upper, + joint_ci_lower=joint_ci_lower, + joint_ci_upper=joint_ci_upper, + aic=aic_val, + summary_text=summary_text, + ) + return self._result + + +# ============================================================================ +# 3. 
PySide6 Widget +# ============================================================================ + +_TEMPORAL_QSS = """ +TemporalModelingWidget { + background: #111821; + color: #d7e0ee; +} +QFrame#temporalHeader { + background: #101b2b; + border: 1px solid #263a52; + border-radius: 8px; +} +QFrame#temporalNav { + background: #0f1a28; + border: 1px solid #263a52; + border-radius: 8px; +} +QFrame#temporalControls { + background: #111d2c; + border: 1px solid #263a52; + border-radius: 8px; +} +QFrame#temporalWorkspace { + background: #111821; + border: 1px solid #263a52; + border-radius: 8px; +} +QLabel { + color: #d7e0ee; +} +QLabel[class="muted"] { + color: #9bacc3; +} +QLabel[class="title"] { + color: #eef4ff; + font-size: 15pt; + font-weight: 800; +} +QGroupBox { + color: #d7e0ee; + font-weight: 700; + border: 1px solid #29405c; + border-radius: 7px; + margin-top: 10px; + padding: 14px 10px 10px 10px; + background: #142033; +} +QGroupBox::title { + subcontrol-origin: margin; + subcontrol-position: top left; + left: 10px; + padding: 0 6px; + color: #dce8f8; +} +QComboBox, QSpinBox, QDoubleSpinBox, QLineEdit { + color: #e9f0fb; + background: #0f1724; + border: 1px solid #314963; + border-radius: 6px; + padding: 5px 8px; + min-height: 24px; +} +QComboBox::drop-down { + border: 0; + width: 24px; +} +QListWidget, QTextEdit { + color: #e6edf8; + background: #0d1420; + border: 1px solid #314963; + border-radius: 6px; + selection-background-color: #2d78c4; +} +QPushButton { + color: #eef4ff; + background: #17263a; + border: 1px solid #34506c; + border-radius: 7px; + padding: 6px 12px; + font-weight: 700; +} +QPushButton:hover { + background: #203450; +} +QPushButton[class="primary"] { + background: #2d8cff; + border: 1px solid #4ba0ff; +} +QToolButton { + color: #cdd8e8; + background: #101b2b; + border: 1px solid #29405c; + border-radius: 8px; + padding: 9px 8px; + font-weight: 700; +} +QToolButton:checked { + color: #ffffff; + background: #1f6db1; + border: 1px solid #35a4e8; +} +QTabWidget::pane { + border: 1px solid #263a52; + border-radius: 6px; + top: -1px; +} +QTabBar::tab { + color: #b9c8dc; + background: #152237; + border: 1px solid #263a52; + padding: 7px 14px; + border-top-left-radius: 6px; + border-top-right-radius: 6px; +} +QTabBar::tab:selected { + color: #ffffff; + background: #1f6db1; + border-color: #35a4e8; +} +""" +_SECTION_QSS = _TEMPORAL_QSS + + +class TemporalModelingWidget(QtWidgets.QWidget): + """ + PySide6 panel for Temporal Modeling (GLM / FLMM). + Embeddable in the PostProcessingPanel dock system. 
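+
+    Host wiring sketch (host-side names are illustrative):
+
+        widget = TemporalModelingWidget(parent=panel)
+        widget.statusMessage.connect(panel.show_status)
+        widget.set_data(processed_trials, psth_mat=mat, psth_tvec=tvec,
+                        event_times=events, file_ids=ids)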
+ """ + + statusMessage = QtCore.Signal(str, int) + + def __init__(self, parent: Optional[QtWidgets.QWidget] = None): + super().__init__(parent) + self._glm = ContinuousGLM() + self._flmm = TrialFLMM() + self._glm_result: Optional[GLMResult] = None + self._flmm_result: Optional[FLMMResult] = None + + # Data references (set by host panel) + self._processed_trials = [] + self._psth_mat: Optional[np.ndarray] = None + self._psth_tvec: Optional[np.ndarray] = None + self._event_times: Optional[np.ndarray] = None + self._file_ids: List[str] = [] + + self._build_compact_ui() + self._connect_signals() + + # ------------------------------------------------------------------ + # UI construction + # ------------------------------------------------------------------ + + def _build_ui(self): + root = QtWidgets.QVBoxLayout(self) + root.setContentsMargins(6, 6, 6, 6) + root.setSpacing(8) + + # ---- Model selector ---- + grp_model = QtWidgets.QGroupBox("Model type") + grp_model.setStyleSheet(_SECTION_QSS) + ml = QtWidgets.QVBoxLayout(grp_model) + ml.setSpacing(4) + + row_type = QtWidgets.QHBoxLayout() + self.combo_model_type = QtWidgets.QComboBox() + self.combo_model_type.addItems(["Continuous GLM", "Trial-level FLMM (fastFMM)"]) + row_type.addWidget(QtWidgets.QLabel("Type:")) + row_type.addWidget(self.combo_model_type, 1) + ml.addLayout(row_type) + + self.lbl_flmm_status = QtWidgets.QLabel("") + self.lbl_flmm_status.setProperty("class", "hint") + self.lbl_flmm_status.setWordWrap(True) + ml.addWidget(self.lbl_flmm_status) + root.addWidget(grp_model) + + # ---- GLM settings ---- + self.grp_glm = QtWidgets.QGroupBox("GLM settings") + self.grp_glm.setStyleSheet(_SECTION_QSS) + gl = QtWidgets.QFormLayout(self.grp_glm) + gl.setSpacing(4) + + self.combo_basis = QtWidgets.QComboBox() + self.combo_basis.addItems(["Raised cosine", "B-spline", "FIR"]) + gl.addRow("Basis:", self.combo_basis) + + self.spin_n_basis = QtWidgets.QSpinBox() + self.spin_n_basis.setRange(2, 50) + self.spin_n_basis.setValue(8) + gl.addRow("# basis:", self.spin_n_basis) + + self.combo_reg = QtWidgets.QComboBox() + self.combo_reg.addItems(["Ridge", "Lasso", "OLS"]) + gl.addRow("Regularization:", self.combo_reg) + + self.spin_alpha = QtWidgets.QDoubleSpinBox() + self.spin_alpha.setRange(0.001, 1000.0) + self.spin_alpha.setValue(1.0) + self.spin_alpha.setDecimals(3) + gl.addRow("Alpha (λ):", self.spin_alpha) + + self.spin_kernel_pre = QtWidgets.QDoubleSpinBox() + self.spin_kernel_pre.setRange(-30.0, 0.0) + self.spin_kernel_pre.setValue(-1.0) + self.spin_kernel_pre.setDecimals(1) + self.spin_kernel_pre.setSuffix(" s") + gl.addRow("Kernel pre:", self.spin_kernel_pre) + + self.spin_kernel_post = QtWidgets.QDoubleSpinBox() + self.spin_kernel_post.setRange(0.1, 60.0) + self.spin_kernel_post.setValue(3.0) + self.spin_kernel_post.setDecimals(1) + self.spin_kernel_post.setSuffix(" s") + gl.addRow("Kernel post:", self.spin_kernel_post) + + root.addWidget(self.grp_glm) + + # ---- FLMM settings ---- + self.grp_flmm = QtWidgets.QGroupBox("FLMM settings") + self.grp_flmm.setStyleSheet(_SECTION_QSS) + fl = QtWidgets.QFormLayout(self.grp_flmm) + fl.setSpacing(4) + + self.edit_formula = QtWidgets.QLineEdit("Y.obs ~ group") + self.edit_formula.setPlaceholderText("e.g. Y.obs ~ group + condition") + fl.addRow("Fixed formula:", self.edit_formula) + + self.edit_random = QtWidgets.QLineEdit("~1") + self.edit_random.setPlaceholderText("e.g. 
~1 or ~time") + fl.addRow("Random:", self.edit_random) + + self.edit_group_var = QtWidgets.QLineEdit("subject") + fl.addRow("Group var:", self.edit_group_var) + + self.spin_nknots = QtWidgets.QSpinBox() + self.spin_nknots.setRange(0, 100) + self.spin_nknots.setValue(0) + self.spin_nknots.setSpecialValueText("auto") + fl.addRow("Min knots:", self.spin_nknots) + + self.spin_boots = QtWidgets.QSpinBox() + self.spin_boots.setRange(0, 5000) + self.spin_boots.setValue(0) + self.spin_boots.setSpecialValueText("analytic") + fl.addRow("Bootstrap iter:", self.spin_boots) + + root.addWidget(self.grp_flmm) + + # ---- Predictor builder ---- + self.grp_predictors = QtWidgets.QGroupBox("Predictors") + self.grp_predictors.setStyleSheet(_SECTION_QSS) + pl = QtWidgets.QVBoxLayout(self.grp_predictors) + pl.setSpacing(4) + + self.list_predictors = QtWidgets.QListWidget() + self.list_predictors.setMaximumHeight(100) + pl.addWidget(self.list_predictors) + + pred_btn_row = QtWidgets.QHBoxLayout() + self.btn_add_predictor = QtWidgets.QPushButton("+ Add") + self.btn_add_predictor.setProperty("class", "compactSmall") + self.btn_remove_predictor = QtWidgets.QPushButton("- Remove") + self.btn_remove_predictor.setProperty("class", "compactSmall") + pred_btn_row.addWidget(self.btn_add_predictor) + pred_btn_row.addWidget(self.btn_remove_predictor) + pred_btn_row.addStretch(1) + pl.addLayout(pred_btn_row) + + self.lbl_predictor_hint = QtWidgets.QLabel( + "Predictors are auto-populated from DIO / behavior events when PSTH is computed." + ) + self.lbl_predictor_hint.setProperty("class", "hint") + self.lbl_predictor_hint.setWordWrap(True) + pl.addWidget(self.lbl_predictor_hint) + root.addWidget(self.grp_predictors) + + # ---- Fit button ---- + btn_row = QtWidgets.QHBoxLayout() + self.btn_fit = QtWidgets.QPushButton("Fit model") + self.btn_fit.setProperty("class", "compactPrimarySmall") + btn_row.addWidget(self.btn_fit) + btn_row.addStretch(1) + root.addLayout(btn_row) + + # ---- Results / summary ---- + self.grp_results = QtWidgets.QGroupBox("Model summary") + self.grp_results.setStyleSheet(_SECTION_QSS) + rl = QtWidgets.QVBoxLayout(self.grp_results) + rl.setSpacing(2) + self.txt_summary = QtWidgets.QTextEdit() + self.txt_summary.setReadOnly(True) + self.txt_summary.setMaximumHeight(140) + self.txt_summary.setStyleSheet("background: #1b2029; border: 1px solid #3a4050; border-radius: 4px;") + rl.addWidget(self.txt_summary) + root.addWidget(self.grp_results) + + # ---- Plot area ---- + self.grp_plots = QtWidgets.QGroupBox("Plots") + self.grp_plots.setStyleSheet(_SECTION_QSS) + plot_lay = QtWidgets.QVBoxLayout(self.grp_plots) + plot_lay.setSpacing(4) + + self.plot_kernel = pg.PlotWidget(title="Estimated kernels") + self.plot_kernel.setMinimumHeight(160) + self.plot_kernel.showGrid(x=True, y=True, alpha=0.25) + self.plot_kernel.addLegend(offset=(10, 10)) + plot_lay.addWidget(self.plot_kernel) + + self.plot_coeff = pg.PlotWidget(title="FLMM coefficient curves") + self.plot_coeff.setMinimumHeight(160) + self.plot_coeff.showGrid(x=True, y=True, alpha=0.25) + self.plot_coeff.addLegend(offset=(10, 10)) + plot_lay.addWidget(self.plot_coeff) + + root.addWidget(self.grp_plots) + root.addStretch(1) + + # Initial visibility + self._on_model_type_changed(0) + + def _build_compact_ui(self): + self.setStyleSheet(_TEMPORAL_QSS) + + root = QtWidgets.QVBoxLayout(self) + root.setContentsMargins(8, 8, 8, 8) + root.setSpacing(10) + + header = QtWidgets.QFrame() + header.setObjectName("temporalHeader") + h = QtWidgets.QHBoxLayout(header) + 
h.setContentsMargins(14, 10, 14, 10) + h.setSpacing(10) + + badge = QtWidgets.QLabel("T") + badge.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + badge.setFixedSize(34, 34) + badge.setStyleSheet( + "background: #2d8cff; color: white; border-radius: 17px; " + "font-weight: 800; font-size: 14pt;" + ) + h.addWidget(badge) + + title_col = QtWidgets.QVBoxLayout() + title_col.setContentsMargins(0, 0, 0, 0) + title_col.setSpacing(1) + title = QtWidgets.QLabel("Temporal Modeling") + title.setProperty("class", "title") + subtitle = QtWidgets.QLabel("Continuous GLM and trial-level FLMM analysis") + subtitle.setProperty("class", "muted") + title_col.addWidget(title) + title_col.addWidget(subtitle) + h.addLayout(title_col, 1) + + self.combo_model_type = QtWidgets.QComboBox() + self.combo_model_type.addItems(["Continuous GLM", "Trial-level FLMM (fastFMM)"]) + self.combo_model_type.setMinimumWidth(230) + h.addWidget(self.combo_model_type) + + self.btn_fit = QtWidgets.QPushButton("Fit model") + self.btn_fit.setProperty("class", "primary") + self.btn_fit.setMinimumWidth(120) + h.addWidget(self.btn_fit) + root.addWidget(header) + + split = QtWidgets.QSplitter(QtCore.Qt.Orientation.Horizontal) + split.setChildrenCollapsible(False) + split.setHandleWidth(6) + root.addWidget(split, 1) + + left = QtWidgets.QFrame() + left.setObjectName("temporalControls") + left_lay = QtWidgets.QHBoxLayout(left) + left_lay.setContentsMargins(8, 8, 8, 8) + left_lay.setSpacing(10) + + nav = QtWidgets.QFrame() + nav.setObjectName("temporalNav") + nav_lay = QtWidgets.QVBoxLayout(nav) + nav_lay.setContentsMargins(8, 8, 8, 8) + nav_lay.setSpacing(8) + self.btn_nav_model = self._make_nav_button("Model") + self.btn_nav_predictors = self._make_nav_button("Predictors") + self.btn_nav_fit = self._make_nav_button("Fit") + for btn in (self.btn_nav_model, self.btn_nav_predictors, self.btn_nav_fit): + nav_lay.addWidget(btn) + nav_lay.addStretch(1) + left_lay.addWidget(nav) + + self.stack_controls = QtWidgets.QStackedWidget() + left_lay.addWidget(self.stack_controls, 1) + split.addWidget(left) + + workspace = QtWidgets.QFrame() + workspace.setObjectName("temporalWorkspace") + workspace_lay = QtWidgets.QVBoxLayout(workspace) + workspace_lay.setContentsMargins(10, 10, 10, 10) + workspace_lay.setSpacing(8) + self.tabs_workspace = QtWidgets.QTabWidget() + self.tabs_workspace.setDocumentMode(True) + workspace_lay.addWidget(self.tabs_workspace, 1) + split.addWidget(workspace) + split.setStretchFactor(0, 0) + split.setStretchFactor(1, 1) + split.setSizes([430, 1100]) + + self._build_model_page() + self._build_predictor_page() + self._build_fit_page() + self._build_workspace_pages() + + self.btn_nav_model.setChecked(True) + self.btn_nav_model.clicked.connect(lambda: self._select_control_page(0)) + self.btn_nav_predictors.clicked.connect(lambda: self._select_control_page(1)) + self.btn_nav_fit.clicked.connect(lambda: self._select_control_page(2)) + self.btn_fit_side.clicked.connect(self._on_fit_clicked) + self._on_model_type_changed(0) + + def _build_model_page(self): + page = QtWidgets.QWidget() + lay = QtWidgets.QVBoxLayout(page) + lay.setContentsMargins(0, 0, 0, 0) + lay.setSpacing(8) + + self.grp_glm = QtWidgets.QGroupBox("GLM Settings") + gl = QtWidgets.QFormLayout(self.grp_glm) + gl.setContentsMargins(12, 18, 12, 12) + gl.setHorizontalSpacing(10) + gl.setVerticalSpacing(8) + + self.combo_basis = QtWidgets.QComboBox() + self.combo_basis.addItems(["Raised cosine", "B-spline", "FIR"]) + gl.addRow("Basis", self.combo_basis) + + self.spin_n_basis 
= QtWidgets.QSpinBox() + self.spin_n_basis.setRange(2, 50) + self.spin_n_basis.setValue(8) + gl.addRow("Basis count", self.spin_n_basis) + + self.combo_reg = QtWidgets.QComboBox() + self.combo_reg.addItems(["Ridge", "Lasso", "OLS"]) + gl.addRow("Regularization", self.combo_reg) + + self.spin_alpha = QtWidgets.QDoubleSpinBox() + self.spin_alpha.setRange(0.001, 1000.0) + self.spin_alpha.setValue(1.0) + self.spin_alpha.setDecimals(3) + gl.addRow("Alpha", self.spin_alpha) + + self.spin_kernel_pre = QtWidgets.QDoubleSpinBox() + self.spin_kernel_pre.setRange(-30.0, 0.0) + self.spin_kernel_pre.setValue(-1.0) + self.spin_kernel_pre.setDecimals(1) + self.spin_kernel_pre.setSuffix(" s") + gl.addRow("Kernel pre", self.spin_kernel_pre) + + self.spin_kernel_post = QtWidgets.QDoubleSpinBox() + self.spin_kernel_post.setRange(0.1, 60.0) + self.spin_kernel_post.setValue(3.0) + self.spin_kernel_post.setDecimals(1) + self.spin_kernel_post.setSuffix(" s") + gl.addRow("Kernel post", self.spin_kernel_post) + lay.addWidget(self.grp_glm) + + self.grp_flmm = QtWidgets.QGroupBox("FLMM Settings") + fl = QtWidgets.QFormLayout(self.grp_flmm) + fl.setContentsMargins(12, 18, 12, 12) + fl.setHorizontalSpacing(10) + fl.setVerticalSpacing(8) + + self.lbl_flmm_status = QtWidgets.QLabel("") + self.lbl_flmm_status.setProperty("class", "muted") + self.lbl_flmm_status.setWordWrap(True) + fl.addRow("Backend", self.lbl_flmm_status) + + self.edit_formula = QtWidgets.QLineEdit("Y.obs ~ group") + self.edit_formula.setPlaceholderText("e.g. Y.obs ~ group + condition") + fl.addRow("Fixed formula", self.edit_formula) + self.edit_random = QtWidgets.QLineEdit("~1") + self.edit_random.setPlaceholderText("e.g. ~1 or ~time") + fl.addRow("Random", self.edit_random) + self.edit_group_var = QtWidgets.QLineEdit("subject") + fl.addRow("Group var", self.edit_group_var) + self.spin_nknots = QtWidgets.QSpinBox() + self.spin_nknots.setRange(0, 100) + self.spin_nknots.setValue(0) + self.spin_nknots.setSpecialValueText("auto") + fl.addRow("Min knots", self.spin_nknots) + self.spin_boots = QtWidgets.QSpinBox() + self.spin_boots.setRange(0, 5000) + self.spin_boots.setValue(0) + self.spin_boots.setSpecialValueText("analytic") + fl.addRow("Bootstrap iter", self.spin_boots) + lay.addWidget(self.grp_flmm) + lay.addStretch(1) + self.stack_controls.addWidget(page) + + def _build_predictor_page(self): + page = QtWidgets.QWidget() + lay = QtWidgets.QVBoxLayout(page) + lay.setContentsMargins(0, 0, 0, 0) + lay.setSpacing(8) + + self.grp_predictors = QtWidgets.QGroupBox("Predictors") + pl = QtWidgets.QVBoxLayout(self.grp_predictors) + pl.setContentsMargins(12, 18, 12, 12) + pl.setSpacing(8) + + self.list_predictors = QtWidgets.QListWidget() + self.list_predictors.setMinimumHeight(220) + pl.addWidget(self.list_predictors) + + row = QtWidgets.QHBoxLayout() + self.btn_add_predictor = QtWidgets.QPushButton("+ Add") + self.btn_remove_predictor = QtWidgets.QPushButton("- Remove") + row.addWidget(self.btn_add_predictor) + row.addWidget(self.btn_remove_predictor) + row.addStretch(1) + pl.addLayout(row) + + self.lbl_predictor_hint = QtWidgets.QLabel( + "Predictors are populated from DIO or behavior events when PSTH is computed." 
+ ) + self.lbl_predictor_hint.setProperty("class", "muted") + self.lbl_predictor_hint.setWordWrap(True) + pl.addWidget(self.lbl_predictor_hint) + lay.addWidget(self.grp_predictors, 1) + self.stack_controls.addWidget(page) + + def _build_fit_page(self): + page = QtWidgets.QWidget() + lay = QtWidgets.QVBoxLayout(page) + lay.setContentsMargins(0, 0, 0, 0) + lay.setSpacing(8) + + grp = QtWidgets.QGroupBox("Fit Control") + gl = QtWidgets.QVBoxLayout(grp) + gl.setContentsMargins(12, 18, 12, 12) + gl.setSpacing(10) + self.lbl_data_status = QtWidgets.QLabel("No PSTH data has been pushed yet.") + self.lbl_data_status.setProperty("class", "muted") + self.lbl_data_status.setWordWrap(True) + gl.addWidget(self.lbl_data_status) + self.btn_fit_side = QtWidgets.QPushButton("Fit model") + self.btn_fit_side.setProperty("class", "primary") + gl.addWidget(self.btn_fit_side) + gl.addStretch(1) + lay.addWidget(grp, 1) + self.stack_controls.addWidget(page) + + def _build_workspace_pages(self): + summary_page = QtWidgets.QWidget() + summary_lay = QtWidgets.QVBoxLayout(summary_page) + summary_lay.setContentsMargins(10, 10, 10, 10) + self.txt_summary = QtWidgets.QTextEdit() + self.txt_summary.setReadOnly(True) + summary_lay.addWidget(self.txt_summary, 1) + self.tabs_workspace.addTab(summary_page, "Summary") + + kernel_page = QtWidgets.QWidget() + kernel_lay = QtWidgets.QVBoxLayout(kernel_page) + kernel_lay.setContentsMargins(10, 10, 10, 10) + self.plot_kernel = pg.PlotWidget(title="Estimated kernels") + self._style_plot(self.plot_kernel) + kernel_lay.addWidget(self.plot_kernel, 1) + self.tabs_workspace.addTab(kernel_page, "Kernels") + + prediction_page = QtWidgets.QWidget() + prediction_lay = QtWidgets.QVBoxLayout(prediction_page) + prediction_lay.setContentsMargins(10, 10, 10, 10) + self.plot_prediction = pg.PlotWidget(title="Actual vs predicted") + self._style_plot(self.plot_prediction) + prediction_lay.addWidget(self.plot_prediction, 1) + self.tabs_workspace.addTab(prediction_page, "Prediction") + + residual_page = QtWidgets.QWidget() + residual_lay = QtWidgets.QVBoxLayout(residual_page) + residual_lay.setContentsMargins(10, 10, 10, 10) + self.plot_residuals = pg.PlotWidget(title="Residuals") + self._style_plot(self.plot_residuals) + residual_lay.addWidget(self.plot_residuals, 1) + self.tabs_workspace.addTab(residual_page, "Residuals") + + flmm_page = QtWidgets.QWidget() + flmm_lay = QtWidgets.QVBoxLayout(flmm_page) + flmm_lay.setContentsMargins(10, 10, 10, 10) + self.plot_coeff = pg.PlotWidget(title="FLMM coefficient curves") + self._style_plot(self.plot_coeff) + flmm_lay.addWidget(self.plot_coeff, 1) + self.tabs_workspace.addTab(flmm_page, "FLMM") + + def _make_nav_button(self, text: str) -> QtWidgets.QToolButton: + btn = QtWidgets.QToolButton() + btn.setText(text) + btn.setCheckable(True) + btn.setMinimumWidth(86) + btn.setMinimumHeight(54) + btn.setToolButtonStyle(QtCore.Qt.ToolButtonStyle.ToolButtonTextOnly) + return btn + + def _select_control_page(self, index: int) -> None: + self.stack_controls.setCurrentIndex(index) + buttons = (self.btn_nav_model, self.btn_nav_predictors, self.btn_nav_fit) + for i, btn in enumerate(buttons): + btn.setChecked(i == index) + + def _style_plot(self, plot: pg.PlotWidget) -> None: + plot.setMinimumHeight(360) + plot.setBackground("#05080d") + plot.showGrid(x=True, y=True, alpha=0.22) + plot.addLegend(offset=(12, 12)) + pi = plot.getPlotItem() + pi.getAxis("bottom").setPen(pg.mkPen("#516179")) + pi.getAxis("left").setPen(pg.mkPen("#516179")) + 
pi.getAxis("bottom").setTextPen(pg.mkPen("#c5d2e3")) + pi.getAxis("left").setTextPen(pg.mkPen("#c5d2e3")) + pi.titleLabel.item.setDefaultTextColor(QtGui.QColor("#d7e0ee")) + + # ------------------------------------------------------------------ + # Signal wiring + # ------------------------------------------------------------------ + + def _connect_signals(self): + self.combo_model_type.currentIndexChanged.connect(self._on_model_type_changed) + self.btn_fit.clicked.connect(self._on_fit_clicked) + self.btn_add_predictor.clicked.connect(self._on_add_predictor) + self.btn_remove_predictor.clicked.connect(self._on_remove_predictor) + + # ------------------------------------------------------------------ + # Public API — called by PostProcessingPanel + # ------------------------------------------------------------------ + + def set_data( + self, + processed_trials, + psth_mat: Optional[np.ndarray] = None, + psth_tvec: Optional[np.ndarray] = None, + event_times: Optional[np.ndarray] = None, + file_ids: Optional[List[str]] = None, + per_file_mats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None, + ): + """Push data from the host panel into this widget.""" + self._processed_trials = processed_trials or [] + self._psth_mat = psth_mat + self._psth_tvec = psth_tvec + self._event_times = event_times + self._file_ids = file_ids or [] + self._per_file_mats = per_file_mats or {} + + # Auto-populate predictors for GLM + if self.list_predictors.count() == 0 and event_times is not None and len(event_times): + self.list_predictors.addItem("events") + n_trials = len(self._processed_trials) + psth_shape = tuple(np.shape(psth_mat)) if psth_mat is not None else None + bits = [f"Processed recordings: {n_trials}"] + if psth_shape: + bits.append(f"PSTH matrix: {psth_shape[0]} x {psth_shape[1]}") + if event_times is not None: + bits.append(f"Events: {len(event_times)}") + self.lbl_data_status.setText("\n".join(bits)) + + # ------------------------------------------------------------------ + # Slots + # ------------------------------------------------------------------ + + def _on_model_type_changed(self, index: int): + is_glm = (index == 0) + self.grp_glm.setVisible(is_glm) + self.grp_flmm.setVisible(not is_glm) + self.plot_kernel.setVisible(is_glm) + self.plot_coeff.setVisible(not is_glm) + if is_glm: + self.lbl_flmm_status.setText("") + if hasattr(self, "tabs_workspace"): + self.tabs_workspace.setCurrentWidget(self.plot_kernel.parentWidget()) + elif hasattr(self, "tabs_workspace"): + self.tabs_workspace.setCurrentWidget(self.plot_coeff.parentWidget()) + + # Check FLMM availability + if not is_glm: + if self._flmm.available: + self.lbl_flmm_status.setText("R + fastFMM detected.") + self.lbl_flmm_status.setStyleSheet("color: #6bdb74;") + else: + self.lbl_flmm_status.setText( + "R or fastFMM not found. Install R and the fastFMM package, " + "then install rpy2 (pip install rpy2)." 
+                )
+                self.lbl_flmm_status.setStyleSheet("color: #f5a97f;")
+
+    def _on_add_predictor(self):
+        name, ok = QtWidgets.QInputDialog.getText(
+            self, "Add predictor", "Predictor name (must match a column in design):"
+        )
+        if ok and name.strip():
+            self.list_predictors.addItem(name.strip())
+
+    def _on_remove_predictor(self):
+        sel = self.list_predictors.currentRow()
+        if sel >= 0:
+            self.list_predictors.takeItem(sel)
+
+    def _on_fit_clicked(self):
+        model_idx = self.combo_model_type.currentIndex()
+        try:
+            if model_idx == 0:
+                self._fit_glm()
+            else:
+                self._fit_flmm()
+        except Exception as exc:
+            _LOG.error("Temporal modeling fit failed: %s\n%s", exc, traceback.format_exc())
+            self.txt_summary.setPlainText(f"Error: {exc}")
+            self.statusMessage.emit(f"Temporal model fit failed: {exc}", 8000)
+
+    # ------------------------------------------------------------------
+    # GLM fit
+    # ------------------------------------------------------------------
+
+    def _fit_glm(self):
+        if not self._processed_trials:
+            self.statusMessage.emit("No processed data — run preprocessing first.", 5000)
+            return
+
+        proc = self._processed_trials[0]
+        time = np.asarray(proc.time, float)
+        signal = np.asarray(proc.output, float) if proc.output is not None else None
+        if signal is None or signal.size == 0:
+            self.statusMessage.emit("No output signal available.", 5000)
+            return
+
+        # Build predictors from the list widget
+        predictors: Dict[str, np.ndarray] = {}
+        for i in range(self.list_predictors.count()):
+            pred_name = self.list_predictors.item(i).text()
+            if pred_name == "events" and self._event_times is not None:
+                predictors[pred_name] = self._event_times
+            elif proc.dio is not None and pred_name.lower() in ("dio", "digital"):
+                # Derive event times from DIO rising edges (0 -> 1 only); a plain
+                # boolean diff would flag falling edges as events too.
+                dio = np.asarray(proc.dio, float)
+                high = (dio > 0.5).astype(np.int8)
+                edges = np.where(np.diff(high) == 1)[0] + 1
+                if edges.size > 0:
+                    predictors[pred_name] = time[edges]
+            else:
+                # Try to find among triggers
+                if pred_name in (proc.triggers or {}):
+                    trig = np.asarray(proc.triggers[pred_name], float)
+                    high = (trig > 0.5).astype(np.int8)
+                    edges = np.where(np.diff(high) == 1)[0] + 1
+                    if edges.size > 0:
+                        predictors[pred_name] = time[edges]
+
+        if not predictors:
+            self.statusMessage.emit("No valid predictors found. 
Add event-based predictors.", 5000) + return + + basis_map = {"Raised cosine": "raised_cosine", "B-spline": "bspline", "FIR": "fir"} + reg_map = {"Ridge": "ridge", "Lasso": "lasso", "OLS": "ols"} + + kernel_win = (self.spin_kernel_pre.value(), self.spin_kernel_post.value()) + result = self._glm.fit( + time, signal, predictors, + kernel_window=kernel_win, + n_basis=self.spin_n_basis.value(), + basis_type=basis_map.get(self.combo_basis.currentText(), "raised_cosine"), + regularization=reg_map.get(self.combo_reg.currentText(), "ridge"), + alpha=self.spin_alpha.value(), + ) + self._glm_result = result + + # Summary + lines = [ + f"Continuous GLM — R² = {result.r2:.4f}", + f"Predictors: {', '.join(result.predictor_names)}", + f"Basis: {self.combo_basis.currentText()}, n={self.spin_n_basis.value()}", + f"Regularization: {self.combo_reg.currentText()}, α={self.spin_alpha.value():.3f}", + ] + self.txt_summary.setPlainText("\n".join(lines)) + + # Plot kernels + self._plot_glm_kernels(result) + self._plot_glm_fit(result) + if hasattr(self, "tabs_workspace"): + self.tabs_workspace.setCurrentWidget(self.plot_kernel.parentWidget()) + self.statusMessage.emit(f"GLM fit complete — R² = {result.r2:.4f}", 5000) + + def _plot_glm_kernels(self, result: GLMResult): + pw = self.plot_kernel + pw.clear() + try: + pw.getPlotItem().legend.clear() + except Exception: + pass + colors = ["#4b9df8", "#f5a97f", "#6bdb74", "#ee99a0", "#c6a0f6", + "#f5e0dc", "#89dceb", "#fab387"] + for i, (name, kernel) in enumerate(result.kernels.items()): + color = colors[i % len(colors)] + pw.plot(result.kernel_tvec, kernel, pen=pg.mkPen(color, width=2), name=name) + pw.setLabel("bottom", "Time", units="s") + pw.setLabel("left", "Kernel weight") + # Zero line + pw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + pw.addLine(x=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + + def _plot_glm_fit(self, result: GLMResult): + pw = self.plot_prediction + pw.clear() + try: + pw.getPlotItem().legend.clear() + except Exception: + pass + x = np.arange(result.y_actual.size) + pw.plot(x, result.y_actual, pen=pg.mkPen("#4b9df8", width=1.2), name="actual") + pw.plot(x, result.y_pred, pen=pg.mkPen("#f5a97f", width=1.4), name="predicted") + pw.setLabel("bottom", "Sample") + pw.setLabel("left", "Signal") + + rw = self.plot_residuals + rw.clear() + rw.plot(x, result.residuals, pen=pg.mkPen("#ee99a0", width=1.1), name="residual") + rw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + rw.setLabel("bottom", "Sample") + rw.setLabel("left", "Residual") + + # ------------------------------------------------------------------ + # FLMM fit + # ------------------------------------------------------------------ + + def _fit_flmm(self): + if not self._flmm.available: + self.statusMessage.emit( + "R + fastFMM not available. 
Please install R, rpy2, and the fastFMM R package.", 8000 + ) + return + + if self._psth_mat is None or self._psth_tvec is None: + self.statusMessage.emit("No PSTH matrix — compute PSTH first.", 5000) + return + + mat = self._psth_mat + tvec = self._psth_tvec + if mat.ndim != 2 or mat.shape[0] < 2: + self.statusMessage.emit("Need at least 2 trials for FLMM.", 5000) + return + + n_trials = mat.shape[0] + + # Build the design dict from the predictor list and file_ids + design: Dict[str, np.ndarray] = {} + group_var = self.edit_group_var.text().strip() or "subject" + + # Default: use file IDs as subject labels if available + if self._file_ids and len(self._file_ids) > 0: + # In group mode, each row = one animal; in individual, each row = one trial + if len(self._file_ids) == n_trials: + design[group_var] = np.array(self._file_ids) + else: + # Per-trial: assign subject based on which file the trial came from + design[group_var] = np.array([f"subj_{i}" for i in range(n_trials)]) + else: + design[group_var] = np.array([f"subj_{i}" for i in range(n_trials)]) + + # Add any custom predictors from the list + for i in range(self.list_predictors.count()): + pred_name = self.list_predictors.item(i).text() + if pred_name == group_var or pred_name == "events": + continue + # Placeholder: user must supply these via design extensions + if pred_name not in design: + design[pred_name] = np.zeros(n_trials, float) + + formula = self.edit_formula.text().strip() or "Y.obs ~ 1" + random_eff = self.edit_random.text().strip() or "~1" + nknots = self.spin_nknots.value() if self.spin_nknots.value() > 0 else None + num_boots = self.spin_boots.value() + + self.statusMessage.emit("Fitting FLMM via fastFMM — this may take a while...", 0) + QtWidgets.QApplication.processEvents() + + result = self._flmm.fit( + mat, tvec, design, + formula_fixed=formula, + random_effects=random_eff, + group_var=group_var, + nknots_min=nknots, + num_boots=num_boots, + ) + self._flmm_result = result + + self.txt_summary.setPlainText(result.summary_text) + self._plot_flmm_coefficients(result) + self.statusMessage.emit("FLMM fit complete.", 5000) + + def _plot_flmm_coefficients(self, result: FLMMResult): + pw = self.plot_coeff + pw.clear() + try: + pw.getPlotItem().legend.clear() + except Exception: + pass + + colors = ["#4b9df8", "#f5a97f", "#6bdb74", "#ee99a0", "#c6a0f6", + "#f5e0dc", "#89dceb", "#fab387"] + tvec = result.tvec + + for i, (name, coeff) in enumerate(result.coefficients.items()): + color = colors[i % len(colors)] + pen = pg.mkPen(color, width=2) + pw.plot(tvec, coeff, pen=pen, name=name) + + # Joint CI as filled region + if name in result.joint_ci_lower and name in result.joint_ci_upper: + ci_lo = result.joint_ci_lower[name] + ci_hi = result.joint_ci_upper[name] + fill_color = QtGui.QColor(color) + fill_color.setAlpha(40) + fill = pg.FillBetweenItem( + pg.PlotDataItem(tvec, ci_lo), + pg.PlotDataItem(tvec, ci_hi), + brush=fill_color, + ) + pw.addItem(fill) + + # Pointwise CI as dashed lines + if name in result.ci_lower and name in result.ci_upper: + dash_pen = pg.mkPen(color, width=1, style=QtCore.Qt.PenStyle.DashLine) + pw.plot(tvec, result.ci_lower[name], pen=dash_pen) + pw.plot(tvec, result.ci_upper[name], pen=dash_pen) + + pw.setLabel("bottom", "Time", units="s") + pw.setLabel("left", "Coefficient") + pw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + pw.addLine(x=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) From 7d911f9f8a2260a44ab834d60711bb6403319fcb Mon Sep 17 
00:00:00 2001 From: andrianj Date: Thu, 7 May 2026 18:23:57 +0200 Subject: [PATCH 02/14] Add undo redo and stabilize postprocessing plots --- pyBer/gui_postprocessing.py | 348 ++++++++++++++++++++++++++++++++++-- pyBer/gui_preprocessing.py | 32 ++++ pyBer/main.py | 188 ++++++++++++++++++- 3 files changed, 557 insertions(+), 11 deletions(-) diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py index ed14494..0a68d7c 100644 --- a/pyBer/gui_postprocessing.py +++ b/pyBer/gui_postprocessing.py @@ -4,6 +4,7 @@ import os import re import json +import copy import logging from pathlib import Path from dataclasses import dataclass @@ -525,6 +526,7 @@ def __init__(self, parent=None) -> None: "heatmap_cmap": "viridis", "heatmap_min": None, "heatmap_max": None, + "heatmap_levels_manual": False, } self._section_popups: Dict[str, QtWidgets.QDockWidget] = {} self._section_scroll_hosts: Dict[str, QtWidgets.QScrollArea] = {} @@ -559,6 +561,13 @@ def __init__(self, parent=None) -> None: self._project_dirty: bool = False self._autosave_restoring: bool = False self._project_recovered_from_autosave: bool = False + self._suppress_heatmap_level_store: bool = False + self._history_undo: List[Dict[str, object]] = [] + self._history_redo: List[Dict[str, object]] = [] + self._history_current: Optional[Dict[str, object]] = None + self._history_key: str = "" + self._history_restoring: bool = False + self._history_limit: int = 60 try: self._build_ui() self._restore_settings() @@ -572,6 +581,7 @@ def __init__(self, parent=None) -> None: if app is not None: app.aboutToQuit.connect(self._on_about_to_quit) QtCore.QTimer.singleShot(0, self._restore_project_autosave_if_needed) + QtCore.QTimer.singleShot(0, self._reset_history_snapshot) def _build_ui(self) -> None: root = QtWidgets.QVBoxLayout(self) @@ -967,6 +977,9 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): self.btn_load_cfg = QtWidgets.QPushButton("Load config") self.btn_load_cfg.setProperty("class", "compactSmall") self.btn_load_cfg.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) + self.btn_reset_cfg = QtWidgets.QPushButton("Reset defaults") + self.btn_reset_cfg.setProperty("class", "compactSmall") + self.btn_reset_cfg.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) self.btn_new_project = QtWidgets.QPushButton("New project") self.btn_new_project.setProperty("class", "compactSmall") self.btn_new_project.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) @@ -1336,6 +1349,7 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): export_layout.addWidget(self.btn_export_img) export_layout.addWidget(self.btn_save_cfg) export_layout.addWidget(self.btn_load_cfg) + export_layout.addWidget(self.btn_reset_cfg) export_layout.addWidget(self.btn_new_project) export_layout.addWidget(self.btn_save_project) export_layout.addWidget(self.btn_load_project) @@ -1371,6 +1385,12 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): self.btn_action_compute.setProperty("class", "compactPrimarySmall") self.btn_action_export = QtWidgets.QPushButton("Export") self.btn_action_export.setProperty("class", "compactPrimarySmall") + self.btn_action_undo = QtWidgets.QPushButton("Undo") + self.btn_action_undo.setProperty("class", "compactSmall") + self.btn_action_undo.setToolTip("Undo last postprocessing setting or view action") + self.btn_action_redo = QtWidgets.QPushButton("Redo") + self.btn_action_redo.setProperty("class", "compactSmall") + 
self.btn_action_redo.setToolTip("Redo last undone postprocessing setting or view action") self.btn_action_hide = QtWidgets.QPushButton("Hide Panels") self.btn_action_hide.setProperty("class", "compactSmall") @@ -1441,8 +1461,10 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): self.btn_action_export.setText("Run Export") self.btn_action_hide.setText("Hide drawer") for b in (self.btn_action_load, self.btn_action_compute, - self.btn_action_export, self.btn_style): + self.btn_action_export, self.btn_action_undo, + self.btn_action_redo, self.btn_style): tb_layout.addWidget(b) + self._update_history_buttons() tb_layout.addStretch(1) tb_layout.addWidget(self.btn_action_hide) # action_row is no longer used; left in scope so any later reference @@ -1660,6 +1682,10 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): ) self.plot_metrics.addItem(self.metrics_err_pre) self.plot_metrics.addItem(self.metrics_err_post) + self.metrics_p_text = pg.TextItem("", color=(230, 236, 246), anchor=(0.5, 1.0)) + self.metrics_p_text.setZValue(20) + self.metrics_p_text.setVisible(False) + self.plot_metrics.addItem(self.metrics_p_text) self.plot_metrics.setXRange(-0.5, 1.5, padding=0) self.plot_metrics.getAxis("bottom").setTicks([[(0, "pre"), (1, "post")]]) @@ -1866,6 +1892,8 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): self.act_open_plot_style.triggered.connect(self._open_style_dialog) self.btn_action_compute.clicked.connect(self._compute_psth) self.btn_action_export.clicked.connect(self._export_results) + self.btn_action_undo.clicked.connect(self._undo_post_action) + self.btn_action_redo.clicked.connect(self._redo_post_action) self.btn_action_hide.clicked.connect(self._hide_all_section_popups) for key, btn in self._section_buttons.items(): btn.toggled.connect(lambda checked, section_key=key: self._toggle_section_popup(section_key, checked)) @@ -1897,6 +1925,7 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): self.btn_style.clicked.connect(self._open_style_dialog) self.btn_save_cfg.clicked.connect(self._save_config_file) self.btn_load_cfg.clicked.connect(self._load_config_file) + self.btn_reset_cfg.clicked.connect(self._reset_config_defaults) self.btn_new_project.clicked.connect(self._new_project) self.btn_save_project.clicked.connect(self._save_project_file) self.btn_load_project.clicked.connect(self._load_project_file) @@ -1917,7 +1946,9 @@ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b): self.cb_peak_norm_prominence.toggled.connect(lambda _checked=False: self._save_settings()) self.tab_sources.currentChanged.connect(self._refresh_signal_file_combo) self.tab_visual_mode.currentChanged.connect(self._on_visual_mode_changed) + self.tab_visual_mode.currentChanged.connect(self._queue_settings_save) self.combo_individual_file.currentIndexChanged.connect(self._on_individual_file_changed) + self.combo_individual_file.currentIndexChanged.connect(self._queue_settings_save) self.combo_align.currentIndexChanged.connect(self._update_align_ui) self.combo_behavior_file_type.currentIndexChanged.connect(self._update_align_ui) @@ -3684,9 +3715,112 @@ def _update_peak_auto_mad_enabled(self, _checked: object = None, *, queue: bool if queue: self._queue_settings_save() + def _history_snapshot(self) -> Dict[str, object]: + state = self._collect_settings() + try: + state["visual_mode"] = int(self.tab_visual_mode.currentIndex()) + except Exception: + state["visual_mode"] = 0 + try: + state["individual_file"] = self.combo_individual_file.currentText().strip() + except Exception: + state["individual_file"] = "" + return 
copy.deepcopy(state) + + def _history_state_key(self, state: Dict[str, object]) -> str: + try: + return json.dumps(state, sort_keys=True, separators=(",", ":"), default=str) + except Exception: + return repr(state) + + def _update_history_buttons(self) -> None: + for button, enabled in ( + (getattr(self, "btn_action_undo", None), bool(self._history_undo)), + (getattr(self, "btn_action_redo", None), bool(self._history_redo)), + ): + if button is not None: + button.setEnabled(enabled) + + def _reset_history_snapshot(self) -> None: + if not hasattr(self, "combo_align"): + return + self._history_undo.clear() + self._history_redo.clear() + self._history_current = self._history_snapshot() + self._history_key = self._history_state_key(self._history_current) + self._update_history_buttons() + + def _record_history_change(self) -> None: + if self._is_restoring_settings or self._history_restoring: + return + state = self._history_snapshot() + key = self._history_state_key(state) + if self._history_current is None: + self._history_current = copy.deepcopy(state) + self._history_key = key + self._update_history_buttons() + return + if key == self._history_key: + return + self._history_undo.append(copy.deepcopy(self._history_current)) + if len(self._history_undo) > self._history_limit: + self._history_undo = self._history_undo[-self._history_limit:] + self._history_redo.clear() + self._history_current = copy.deepcopy(state) + self._history_key = key + self._update_history_buttons() + + def _restore_history_state(self, state: Dict[str, object]) -> None: + was_restoring = self._is_restoring_settings + self._history_restoring = True + self._is_restoring_settings = True + try: + self._apply_settings(copy.deepcopy(state)) + if "visual_mode" in state: + idx = int(state.get("visual_mode") or 0) + if 0 <= idx < self.tab_visual_mode.count(): + self.tab_visual_mode.setCurrentIndex(idx) + if "individual_file" in state: + file_id = str(state.get("individual_file") or "").strip() + if file_id: + combo_idx = self.combo_individual_file.findText(file_id) + if combo_idx >= 0: + self.combo_individual_file.setCurrentIndex(combo_idx) + self._rerender_visual_from_cache() + self._update_trace_preview() + self._save_settings() + finally: + self._is_restoring_settings = was_restoring + self._history_restoring = False + + def _undo_post_action(self) -> None: + if not self._history_undo: + return + current = self._history_snapshot() + previous = self._history_undo.pop() + self._history_redo.append(copy.deepcopy(current)) + self._restore_history_state(previous) + self._history_current = copy.deepcopy(previous) + self._history_key = self._history_state_key(previous) + self._update_history_buttons() + self.statusUpdate.emit("Undid postprocessing action.", 2500) + + def _redo_post_action(self) -> None: + if not self._history_redo: + return + current = self._history_snapshot() + next_state = self._history_redo.pop() + self._history_undo.append(copy.deepcopy(current)) + self._restore_history_state(next_state) + self._history_current = copy.deepcopy(next_state) + self._history_key = self._history_state_key(next_state) + self._update_history_buttons() + self.statusUpdate.emit("Redid postprocessing action.", 2500) + def _queue_settings_save(self, *_args: object) -> None: if self._is_restoring_settings: return + self._record_history_change() if not self._autosave_restoring: self._project_dirty = True timer = getattr(self, "_settings_save_timer", None) @@ -6418,7 +6552,17 @@ def _compute_psth(self) -> None: def _render_heatmap(self, mat: 
np.ndarray, tvec: np.ndarray, labels: Optional[List[str]] = None) -> None: if mat.size == 0: - self.img.setImage(np.zeros((1, 1))) + self._suppress_heatmap_level_store = True + try: + self.img.setImage(np.zeros((1, 1)), autoLevels=False) + self.img.setLevels([0.0, 1.0]) + if hasattr(self, "heat_lut") and getattr(self.heat_lut, "item", None) is not None: + try: + self.heat_lut.item.setLevels(0.0, 1.0) + except TypeError: + self.heat_lut.item.setLevels((0.0, 1.0)) + finally: + self._suppress_heatmap_level_store = False self.heat_zero_line.setVisible(False) return @@ -6433,12 +6577,38 @@ def _render_heatmap(self, mat: np.ndarray, tvec: np.ndarray, labels: Optional[Li self.heat_lut.item.gradient.setColorMap(cmap) except Exception: pass + finite = img[np.isfinite(img)] + if finite.size: + lo = float(np.nanmin(finite)) + hi = float(np.nanmax(finite)) + else: + lo, hi = 0.0, 1.0 + manual_levels = bool(self._style.get("heatmap_levels_manual", False)) + if not manual_levels: + self._style["heatmap_min"] = None + self._style["heatmap_max"] = None hmin = self._style.get("heatmap_min", None) hmax = self._style.get("heatmap_max", None) - # set image (auto-level) - self.img.setImage(img, autoLevels=True) - if hmin is not None and hmax is not None: - self.img.setLevels([float(hmin), float(hmax)]) + if manual_levels and hmin is not None and hmax is not None: + try: + lo = float(hmin) + hi = float(hmax) + except Exception: + pass + if (not np.isfinite(lo)) or (not np.isfinite(hi)) or hi <= lo: + center = lo if np.isfinite(lo) else 0.0 + lo, hi = center - 0.5, center + 0.5 + self._suppress_heatmap_level_store = True + try: + self.img.setImage(img, autoLevels=False) + self.img.setLevels([lo, hi]) + if hasattr(self, "heat_lut") and getattr(self.heat_lut, "item", None) is not None: + try: + self.heat_lut.item.setLevels(lo, hi) + except TypeError: + self.heat_lut.item.setLevels((lo, hi)) + finally: + self._suppress_heatmap_level_store = False # Map image to time axis using a rect (avoids scale() incompatibilities) x0 = float(tvec[0]) if tvec.size else 0.0 @@ -6553,6 +6723,7 @@ def _render_metrics(self, mat: np.ndarray, tvec: np.ndarray) -> None: self.metrics_scatter_post.setData([], []) self._set_error_bar(self.metrics_err_pre, 0.0, 0.0, 0.0) self._set_error_bar(self.metrics_err_post, 1.0, 0.0, 0.0) + self.metrics_p_text.setVisible(False) self._last_metrics = None return metric = self.combo_metric.currentText() @@ -6577,6 +6748,7 @@ def _window_vals(a: float, b: float) -> np.ndarray: self.metrics_scatter_post.setData([], []) self._set_error_bar(self.metrics_err_pre, 0.0, 0.0, 0.0) self._set_error_bar(self.metrics_err_post, 1.0, 0.0, 0.0) + self.metrics_p_text.setVisible(False) self._last_metrics = None return @@ -6605,10 +6777,24 @@ def _metric_vals(win: np.ndarray, duration: float) -> np.ndarray: pair_mask = np.isfinite(pre_vals_all) & np.isfinite(post_vals_all) pre_pair = pre_vals_all[pair_mask] post_pair = post_vals_all[pair_mask] + p_value = np.nan + n_pair = int(min(pre_pair.size, post_pair.size)) if pre_pair.size and post_pair.size: - n_pair = int(min(pre_pair.size, post_pair.size)) pre_pair = pre_pair[:n_pair] post_pair = post_pair[:n_pair] + if n_pair >= 2: + diffs = post_pair - pre_pair + diffs = diffs[np.isfinite(diffs)] + if diffs.size >= 2: + if float(np.nanstd(diffs, ddof=1)) == 0.0: + p_value = 1.0 if float(np.nanmean(diffs)) == 0.0 else 0.0 + else: + try: + from scipy import stats + result = stats.ttest_rel(pre_pair, post_pair, nan_policy="omit") + p_value = float(result.pvalue) + except 
Exception: + p_value = np.nan # Build segmented polyline: (0, pre_i) -> (1, post_i), NaN separator. x_line = np.empty(n_pair * 3, dtype=float) y_line = np.empty(n_pair * 3, dtype=float) @@ -6666,6 +6852,17 @@ def _metric_vals(win: np.ndarray, duration: float) -> np.ndarray: if ymin == ymax: ymax = ymin + 1.0 self.plot_metrics.setYRange(ymin, ymax, padding=0.2) + span = float(ymax - ymin) if np.isfinite(ymax - ymin) and ymax != ymin else 1.0 + if n_pair >= 2: + if np.isfinite(p_value): + p_label = "paired p < 1e-4" if p_value < 1e-4 else f"paired p = {p_value:.4g}" + else: + p_label = "paired p = n/a" + self.metrics_p_text.setText(p_label) + self.metrics_p_text.setPos(0.5, ymax + 0.14 * span) + self.metrics_p_text.setVisible(True) + else: + self.metrics_p_text.setVisible(False) self._last_metrics = { "pre": pre_mean, "post": post_mean, @@ -6673,6 +6870,8 @@ def _metric_vals(win: np.ndarray, duration: float) -> np.ndarray: "post_sem": post_sem, "pre_n": float(pre_n), "post_n": float(post_n), + "paired_n": float(n_pair), + "paired_p": float(p_value) if np.isfinite(p_value) else np.nan, "metric": metric, } @@ -6951,7 +7150,7 @@ def _apply_plot_style(self) -> None: pass def _on_heatmap_levels_changed(self) -> None: - if self._is_restoring_settings: + if self._is_restoring_settings or self._suppress_heatmap_level_store: return if not hasattr(self, "heat_lut") or getattr(self.heat_lut, "item", None) is None: return @@ -6966,6 +7165,7 @@ def _on_heatmap_levels_changed(self) -> None: if np.isfinite(lo) and np.isfinite(hi) and hi > lo: self._style["heatmap_min"] = lo self._style["heatmap_max"] = hi + self._style["heatmap_levels_manual"] = True self._queue_settings_save() def _open_style_dialog(self) -> None: @@ -6974,6 +7174,7 @@ def _open_style_dialog(self) -> None: return self._style = dlg.get_style() self._apply_plot_style() + self._record_history_change() self._render_heatmap(self._last_mat if self._last_mat is not None else np.zeros((1, 1)), self._last_tvec if self._last_tvec is not None else np.array([0.0, 1.0])) self._render_spatial_heatmap( self._last_spatial_occupancy_map, @@ -7704,6 +7905,7 @@ def _reset_project_state(self) -> None: def reset_for_new_preprocessing_project(self) -> None: self._reset_project_state() + self._reset_history_snapshot() self.statusUpdate.emit("Cleared postprocessing project state.", 5000) def _new_project(self) -> None: @@ -7711,6 +7913,7 @@ def _new_project(self) -> None: return self._reset_project_state() + self._reset_history_snapshot() self.statusUpdate.emit("Started a new postprocessing project.", 5000) def _import_project_source_paths(self, recent_paths: Dict[str, object]) -> bool: @@ -7834,8 +8037,116 @@ def _load_project_from_path(self, path: str, from_autosave: bool = False) -> boo self.statusUpdate.emit("Recovered autosaved postprocessing project.", 5000) else: self.statusUpdate.emit(f"Project loaded: {os.path.basename(path)}", 5000) + self._reset_history_snapshot() return True + def _default_settings_payload(self) -> Dict[str, object]: + return { + "align": "Behavior (CSV/XLSX)", + "dio_channel": self.combo_dio.currentText(), + "dio_polarity": "Event high (0->1)", + "dio_align": "Align to onset", + "behavior_file_type": "Binary states (time + 0/1 columns)", + "behavior_time_fps": 30.0, + "behavior": self.combo_behavior_name.currentText(), + "behavior_align": self.combo_behavior_align.currentText(), + "behavior_from": self.combo_behavior_from.currentText(), + "behavior_to": self.combo_behavior_to.currentText(), + "transition_gap": 1.0, + "window_pre": 
2.0, + "window_post": 5.0, + "baseline_start": -1.0, + "baseline_end": 0.0, + "resample": 50.0, + "smooth": 0.0, + "filter_enabled": True, + "event_start": 1, + "event_end": 0, + "group_window_s": 0.0, + "dur_min": 0.0, + "dur_max": 0.0, + "metrics_enabled": True, + "metric": "AUC", + "metric_pre0": -1.0, + "metric_pre1": 0.0, + "metric_post0": 0.0, + "metric_post1": 1.0, + "global_metrics_enabled": True, + "global_start": 0.0, + "global_end": 0.0, + "global_amp": True, + "global_freq": True, + "view_layout": "Standard", + "visual_mode": 0, + "individual_file": self.combo_individual_file.currentText().strip(), + "signal_source": "Use processed output trace (loaded file)", + "signal_scope": "Per file", + "signal_file": self.combo_signal_file.currentText(), + "signal_method": "SciPy find_peaks", + "signal_prominence": 0.5, + "signal_auto_mad": False, + "signal_mad_multiplier": 5.0, + "signal_height": 0.0, + "signal_distance": 0.5, + "signal_smooth": 0.0, + "signal_baseline": "Use trace as-is", + "signal_baseline_window": 10.0, + "signal_norm_prominence": False, + "signal_rate_bin": 60.0, + "signal_auc_window": 0.5, + "signal_overlay": True, + "signal_noise_overlay": False, + "behavior_analysis_name": self.combo_behavior_analysis.currentText(), + "behavior_analysis_bin": 30.0, + "behavior_analysis_aligned": False, + "spatial_x": self.combo_spatial_x.currentText(), + "spatial_y": self.combo_spatial_y.currentText(), + "spatial_bins_x": 64, + "spatial_bins_y": 64, + "spatial_weight": "Occupancy (samples)", + "spatial_clip": True, + "spatial_clip_low": 1.0, + "spatial_clip_high": 99.0, + "spatial_time_filter": False, + "spatial_time_min": 0.0, + "spatial_time_max": 0.0, + "spatial_smooth": 0.0, + "spatial_activity_mode": "Mean z-score/bin (occupancy normalized)", + "spatial_activity_norm": True, + "spatial_log": False, + "spatial_invert_y": False, + "style": { + "trace": (90, 190, 255), + "behavior": (220, 180, 80), + "avg": (90, 190, 255), + "sem_edge": (152, 201, 143), + "sem_fill": (188, 230, 178, 96), + "plot_bg": (248, 250, 255) if self._app_theme_mode == "light" else (36, 42, 52), + "grid_enabled": True, + "grid_alpha": 0.25, + "heatmap_cmap": "viridis", + "heatmap_min": None, + "heatmap_max": None, + "heatmap_levels_manual": False, + }, + } + + def _reset_config_defaults(self) -> None: + previous_restoring = self._is_restoring_settings + previous_history_restoring = self._history_restoring + self._is_restoring_settings = True + self._history_restoring = True + try: + self._apply_settings(self._default_settings_payload()) + finally: + self._is_restoring_settings = previous_restoring + self._history_restoring = previous_history_restoring + self._record_history_change() + self._rerender_visual_from_cache() + self._compute_spatial_heatmap() + self._save_settings() + self.statusUpdate.emit("Reset postprocessing parameters to defaults.", 3000) + def _save_config_file(self) -> None: start_dir = self._settings.value("postprocess_last_dir", os.getcwd(), type=str) path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "Save config", os.path.join(start_dir, "postprocess_config.json"), "JSON (*.json)") @@ -7903,6 +8214,8 @@ def _collect_settings(self) -> Dict[str, object]: "global_amp": self.cb_global_amp.isChecked(), "global_freq": self.cb_global_freq.isChecked(), "view_layout": self.combo_view_layout.currentText(), + "visual_mode": int(self.tab_visual_mode.currentIndex()), + "individual_file": self.combo_individual_file.currentText().strip(), "signal_source": self.combo_signal_source.currentText(), 
"signal_scope": self.combo_signal_scope.currentText(), "signal_file": self.combo_signal_file.currentText(), @@ -8008,6 +8321,19 @@ def _set_combo(combo: QtWidgets.QComboBox, val: object) -> None: if "global_freq" in data: self.cb_global_freq.setChecked(bool(data["global_freq"])) _set_combo(self.combo_view_layout, data.get("view_layout")) + if "visual_mode" in data: + try: + idx = int(data.get("visual_mode") or 0) + if 0 <= idx < self.tab_visual_mode.count(): + self.tab_visual_mode.setCurrentIndex(idx) + except Exception: + pass + if "individual_file" in data: + file_id = str(data.get("individual_file") or "").strip() + if file_id: + idx = self.combo_individual_file.findText(file_id) + if idx >= 0: + self.combo_individual_file.setCurrentIndex(idx) _set_combo(self.combo_signal_source, data.get("signal_source")) _set_combo(self.combo_signal_scope, data.get("signal_scope")) self._refresh_signal_file_combo() @@ -9796,8 +10122,10 @@ def _pick_color(self, key: str, with_alpha: bool = False) -> None: def get_style(self) -> Dict[str, object]: self._style["heatmap_cmap"] = self.combo_cmap.currentText() - self._style["heatmap_min"] = float(self.spin_hmin.value()) if self.spin_hmin.value() != 0.0 else None - self._style["heatmap_max"] = float(self.spin_hmax.value()) if self.spin_hmax.value() != 0.0 else None + manual_levels = self.spin_hmin.value() != 0.0 or self.spin_hmax.value() != 0.0 + self._style["heatmap_min"] = float(self.spin_hmin.value()) if manual_levels else None + self._style["heatmap_max"] = float(self.spin_hmax.value()) if manual_levels else None + self._style["heatmap_levels_manual"] = bool(manual_levels) self._style["grid_enabled"] = bool(self.cb_grid.isChecked()) self._style["grid_alpha"] = float(self.spin_grid_alpha.value()) self._style.setdefault("plot_bg", (36, 42, 52)) diff --git a/pyBer/gui_preprocessing.py b/pyBer/gui_preprocessing.py index 01614f3..aa891cd 100644 --- a/pyBer/gui_preprocessing.py +++ b/pyBer/gui_preprocessing.py @@ -1735,6 +1735,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: self.btn_advanced = QtWidgets.QPushButton("Cutting / Sectioning") self.btn_save_config = QtWidgets.QPushButton("Save config") self.btn_load_config = QtWidgets.QPushButton("Load config") + self.btn_reset_defaults = QtWidgets.QPushButton("Reset defaults") for b in ( self.btn_artifacts_panel, self.btn_qc, @@ -1744,12 +1745,14 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: self.btn_advanced, self.btn_save_config, self.btn_load_config, + self.btn_reset_defaults, ): b.setProperty("class", "compactSmall") b.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Fixed) self.btn_export.setProperty("class", "compactPrimarySmall") self.btn_save_config.clicked.connect(self._save_config) self.btn_load_config.clicked.connect(self._load_config) + self.btn_reset_defaults.clicked.connect(self._reset_defaults) self.btn_artifacts_panel.clicked.connect(self.artifactsRequested.emit) self.btn_qc.clicked.connect(self.qcRequested.emit) self.btn_qc_batch.clicked.connect(self.batchQcRequested.emit) @@ -1827,6 +1830,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: qc_grid.addWidget(self.btn_qc_batch, 2, 1) qc_grid.addWidget(self.btn_metadata, 3, 0) qc_grid.addWidget(self.btn_save_config, 3, 1) + qc_grid.addWidget(self.btn_reset_defaults, 4, 0) qc_grid.addWidget(self.btn_load_config, 4, 1) qc_grid.setColumnStretch(0, 1) qc_grid.setColumnStretch(1, 1) @@ -2298,6 +2302,19 @@ def _load_config(self) -> None: except Exception as e: QtWidgets.QMessageBox.warning(self, "Error", f"Failed to load 
config: {e}") + def _reset_defaults(self) -> None: + """Restore processing parameters and export toggles to application defaults.""" + self.set_params(ProcessingParams()) + self.cb_artifact.setChecked(True) + self.cb_filtering.setChecked(True) + self.cb_show_artifact_overlay.setChecked(True) + self.set_export_selection(ExportSelection()) + self.set_export_output_modes([self.combo_output.currentText()], follow_current=True) + self.chk_auto_export.setChecked(False) + self._update_auto_export_controls() + self._update_section_summaries() + self.paramsChanged.emit() + def _lambda_value(self) -> float: x = float(self.spin_lam_x.value()) y = int(self.spin_lam_y.value()) @@ -2485,6 +2502,8 @@ class PlotDashboard(QtWidgets.QWidget): manualRegionFromSelectorRequested = QtCore.Signal() manualRegionFromDragRequested = QtCore.Signal(float, float) clearManualRegionsRequested = QtCore.Signal() + undoRequested = QtCore.Signal() + redoRequested = QtCore.Signal() showArtifactsRequested = QtCore.Signal() boxSelectionCleared = QtCore.Signal() boxSelectionContextRequested = QtCore.Signal() @@ -2531,6 +2550,8 @@ def _build_ui(self) -> None: tools = QtWidgets.QHBoxLayout() self.btn_add_region = QtWidgets.QPushButton("Add from selector") self.btn_clear_regions = QtWidgets.QPushButton("Clear manual") + self.btn_undo = QtWidgets.QPushButton("Undo") + self.btn_redo = QtWidgets.QPushButton("Redo") self.btn_artifacts = QtWidgets.QPushButton("Artifacts") self.btn_box_select = QtWidgets.QPushButton("Box select") self.btn_box_select.setCheckable(True) @@ -2540,6 +2561,8 @@ def _build_ui(self) -> None: for b in ( self.btn_add_region, self.btn_clear_regions, + self.btn_undo, + self.btn_redo, self.btn_artifacts, self.btn_box_select, self.btn_thresholds, @@ -2547,6 +2570,8 @@ def _build_ui(self) -> None: b.setProperty("class", "compactSmall") tools.addWidget(self.btn_add_region) tools.addWidget(self.btn_clear_regions) + tools.addWidget(self.btn_undo) + tools.addWidget(self.btn_redo) tools.addWidget(self.btn_artifacts) tools.addWidget(self.btn_box_select) tools.addWidget(self.btn_thresholds) @@ -2614,6 +2639,8 @@ def _build_ui(self) -> None: self.btn_add_region.clicked.connect(self.manualRegionFromSelectorRequested.emit) self.btn_clear_regions.clicked.connect(self.clearManualRegionsRequested.emit) + self.btn_undo.clicked.connect(self.undoRequested.emit) + self.btn_redo.clicked.connect(self.redoRequested.emit) self.btn_artifacts.clicked.connect(self.showArtifactsRequested.emit) self.btn_box_select.toggled.connect(self._toggle_box_select) self.btn_thresholds.toggled.connect(self._on_thresholds_toggled) @@ -2629,8 +2656,13 @@ def _build_ui(self) -> None: self._sync_artifact_threshold_curves_visibility() self._toggle_box_select(False) + self.set_history_available(False, False) self.set_plot_appearance(self._plot_background_mode, self._plot_grid_visible) + def set_history_available(self, can_undo: bool, can_redo: bool) -> None: + self.btn_undo.setEnabled(bool(can_undo)) + self.btn_redo.setEnabled(bool(can_redo)) + def _normalize_plot_background_mode(self, value: object) -> str: mode = str(value or "").strip().lower() if mode in {"white", "light", "w"}: diff --git a/pyBer/main.py b/pyBer/main.py index 18ab29e..c009218 100644 --- a/pyBer/main.py +++ b/pyBer/main.py @@ -15,7 +15,7 @@ import json import logging import sys -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple _DLL_DIR_HANDLES = [] @@ -554,6 +554,12 @@ def __init__(self) -> None: 
self._pending_main_tab_index: Optional[int] = None self._force_fixed_dock_layouts: bool = bool(_FORCE_FIXED_DOCK_LAYOUTS) self._app_theme_mode: str = "dark" + self._pre_history_undo: List[Dict[str, Any]] = [] + self._pre_history_redo: List[Dict[str, Any]] = [] + self._pre_history_current: Optional[Dict[str, Any]] = None + self._pre_history_key: str = "" + self._pre_history_restoring: bool = False + self._pre_history_limit: int = 60 # Worker infra (stable) self._pool = QtCore.QThreadPool.globalInstance() @@ -577,6 +583,7 @@ def __init__(self) -> None: self._build_ui() self._restore_settings() self._panel_layout_persistence_ready = True + self._reset_pre_history_snapshot() # Enforce: preprocessing drawer is hidden until the user # explicitly clicks a rail section button (overrides any saved state). self._force_hide_pre_drawer_initially() @@ -933,6 +940,8 @@ def _build_ui(self) -> None: self.plots.manualRegionFromSelectorRequested.connect(self._add_manual_region_from_selector) self.plots.manualRegionFromDragRequested.connect(self._add_manual_region_from_drag) self.plots.clearManualRegionsRequested.connect(self._clear_manual_regions_current) + self.plots.undoRequested.connect(self._undo_pre_action) + self.plots.redoRequested.connect(self._redo_pre_action) self.plots.showArtifactsRequested.connect(self._toggle_artifacts_panel) self.plots.boxSelectionCleared.connect(self._cancel_box_select_request) self.plots.boxSelectionContextRequested.connect(self._show_box_selection_context_menu) @@ -1361,10 +1370,12 @@ def _build_config_actions_widget(self) -> QtWidgets.QWidget: self.param_panel.btn_metadata.setProperty("class", "blueSecondarySmall") self.param_panel.btn_save_config.setProperty("class", "blueSecondarySmall") self.param_panel.btn_load_config.setProperty("class", "blueSecondarySmall") + self.param_panel.btn_reset_defaults.setProperty("class", "blueSecondarySmall") for btn in ( self.param_panel.btn_metadata, self.param_panel.btn_save_config, self.param_panel.btn_load_config, + self.param_panel.btn_reset_defaults, ): btn.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed) btn.setMinimumWidth(90) @@ -2389,6 +2400,166 @@ def _import_preprocessing_ui_state_from_config(self, ui_state: Dict[str, object] self._save_panel_config_json() self._save_settings() + def _pre_history_clone(self, state: Dict[str, Any]) -> Dict[str, Any]: + try: + return json.loads(json.dumps(state)) + except Exception: + return dict(state) + + def _pre_history_state_key(self, state: Dict[str, Any]) -> str: + try: + return json.dumps(state, sort_keys=True, separators=(",", ":")) + except Exception: + return repr(state) + + def _pre_history_snapshot(self) -> Dict[str, Any]: + try: + params = self.param_panel.get_params().to_dict() + except Exception: + params = {} + start_s, end_s = self._time_window_bounds() + return { + "params": params, + "artifact_overlay_visible": bool(self.param_panel.artifact_overlay_visible()), + "artifact_thresholds_visible": bool(self.plots.artifact_thresholds_visible()), + "plot_background": self.plots.plot_background_mode(), + "plot_grid": bool(self.plots.plot_grid_visible()), + "export_selection": self.param_panel.export_selection().to_dict(), + "export_channel_names": self.param_panel.export_channel_names(), + "export_trigger_names": self.param_panel.export_trigger_names(), + "export_output_modes": self.param_panel.export_output_modes(), + "auto_export_to_source_dir": bool(self.param_panel.auto_export_enabled()), + "time_window": {"start_s": start_s, "end_s": 
end_s}, + "manual_regions": self._keyed_regions_to_project(self._manual_regions_by_key), + "manual_exclude_regions": self._keyed_regions_to_project(self._manual_exclude_by_key), + "cutout_regions": self._keyed_regions_to_project(self._cutout_regions_by_key), + "sections": self._sections_to_project(), + } + + def _update_pre_history_buttons(self) -> None: + try: + self.plots.set_history_available(bool(self._pre_history_undo), bool(self._pre_history_redo)) + except Exception: + pass + + def _reset_pre_history_snapshot(self) -> None: + self._pre_history_undo.clear() + self._pre_history_redo.clear() + self._pre_history_current = self._pre_history_snapshot() + self._pre_history_key = self._pre_history_state_key(self._pre_history_current) + self._update_pre_history_buttons() + + def _record_pre_history_change(self) -> None: + if self._pre_history_restoring: + return + state = self._pre_history_snapshot() + key = self._pre_history_state_key(state) + if self._pre_history_current is None: + self._pre_history_current = self._pre_history_clone(state) + self._pre_history_key = key + self._update_pre_history_buttons() + return + if key == self._pre_history_key: + return + self._pre_history_undo.append(self._pre_history_clone(self._pre_history_current)) + if len(self._pre_history_undo) > self._pre_history_limit: + self._pre_history_undo = self._pre_history_undo[-self._pre_history_limit:] + self._pre_history_redo.clear() + self._pre_history_current = self._pre_history_clone(state) + self._pre_history_key = key + self._update_pre_history_buttons() + + def _refresh_artifact_panel_for_current(self) -> None: + key = self._current_key() + if not key: + self.artifact_panel.set_auto_regions([]) + self.artifact_panel.set_regions([]) + return + start_s, end_s = self._time_window_bounds() + manual_win = self._clip_regions_to_window(self._manual_regions_by_key.get(key, []), start_s, end_s) + ignore_win = self._clip_regions_to_window(self._manual_exclude_by_key.get(key, []), start_s, end_s) + auto_win = self._clip_regions_to_window(self._auto_regions_by_key.get(key, []), start_s, end_s) + checked_auto = [r for r in auto_win if not any(self._regions_match(r, ig) for ig in ignore_win)] + self.artifact_panel.set_auto_regions(auto_win, checked_regions=checked_auto) + self.artifact_panel.set_regions(manual_win) + + def _restore_pre_history_state(self, state: Dict[str, Any]) -> None: + self._pre_history_restoring = True + try: + params = state.get("params") + if isinstance(params, dict): + self.param_panel.set_params(ProcessingParams.from_dict(params)) + if "artifact_overlay_visible" in state: + visible = bool(state.get("artifact_overlay_visible")) + self.param_panel.set_artifact_overlay_visible(visible) + self.plots.set_artifact_overlay_visible(visible) + if "artifact_thresholds_visible" in state: + self.plots.set_artifact_thresholds_visible(bool(state.get("artifact_thresholds_visible"))) + if "export_selection" in state: + self.param_panel.set_export_selection(ExportSelection.from_dict(state.get("export_selection"))) + if "export_output_modes" in state: + self.param_panel.set_export_output_modes(list(state.get("export_output_modes") or []), follow_current=False) + if "export_channel_names" in state: + self.param_panel.set_export_channel_names(list(state.get("export_channel_names") or [])) + if "export_trigger_names" in state: + self.param_panel.set_export_trigger_names(list(state.get("export_trigger_names") or [])) + if "auto_export_to_source_dir" in state: + 
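# --- editor's note (illustrative, not part of the patch) -----------------------
# The restore path here re-applies widget state while Qt signals are blocked,
# so programmatic setText() calls cannot re-enter _record_pre_history_change().
# Minimal sketch of that guard, assuming a QLineEdit-like `edit`:
def set_text_silently(edit, text: str) -> None:
    edit.blockSignals(True)   # suppress textChanged / editingFinished
    try:
        edit.setText(text)
    finally:
        edit.blockSignals(False)  # always unblock, even if setText raises
# --- end editor's note ----------------------------------------------------------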
self.param_panel.set_auto_export_enabled(_to_bool(state.get("auto_export_to_source_dir"), False)) + self._apply_pre_plot_style( + state.get("plot_background", self.plots.plot_background_mode()), + state.get("plot_grid", self.plots.plot_grid_visible()), + persist=False, + ) + self._manual_regions_by_key = self._project_to_keyed_regions(state.get("manual_regions")) + self._manual_exclude_by_key = self._project_to_keyed_regions(state.get("manual_exclude_regions")) + self._cutout_regions_by_key = self._project_to_keyed_regions(state.get("cutout_regions")) + self._sections_by_key = self._project_to_sections(state.get("sections")) + self._pending_box_region_by_key.clear() + tw = state.get("time_window") if isinstance(state.get("time_window"), dict) else {} + for ed, value in ( + (self.file_panel.edit_time_start, tw.get("start_s")), + (self.file_panel.edit_time_end, tw.get("end_s")), + ): + ed.blockSignals(True) + try: + ed.setText("" if value is None else f"{float(value):.6g}") + except Exception: + ed.setText("") + finally: + ed.blockSignals(False) + self._last_processed.clear() + self._refresh_artifact_panel_for_current() + self._update_export_summary_label() + self._update_raw_plot(preserve_view=True) + self._trigger_preview(preserve_view=True) + self._save_settings() + finally: + self._pre_history_restoring = False + + def _undo_pre_action(self) -> None: + if not self._pre_history_undo: + return + current = self._pre_history_snapshot() + previous = self._pre_history_undo.pop() + self._pre_history_redo.append(self._pre_history_clone(current)) + self._restore_pre_history_state(previous) + self._pre_history_current = self._pre_history_clone(previous) + self._pre_history_key = self._pre_history_state_key(previous) + self._update_pre_history_buttons() + self._show_status_message("Undid preprocessing action.", 2500) + + def _redo_pre_action(self) -> None: + if not self._pre_history_redo: + return + current = self._pre_history_snapshot() + next_state = self._pre_history_redo.pop() + self._pre_history_undo.append(self._pre_history_clone(current)) + self._restore_pre_history_state(next_state) + self._pre_history_current = self._pre_history_clone(next_state) + self._pre_history_key = self._pre_history_state_key(next_state) + self._update_pre_history_buttons() + self._show_status_message("Redid preprocessing action.", 2500) + def _sync_section_button_states_from_docks(self) -> None: if self._use_pg_dockarea_pre_layout: self._last_opened_section = None @@ -2966,6 +3137,7 @@ def _new_preprocessing_project(self) -> None: self.post_tab.reset_for_new_preprocessing_project() except Exception: pass + self._reset_pre_history_snapshot() self._show_status_message("Started a new preprocessing project.", 5000) def _keyed_regions_to_project( @@ -3224,6 +3396,7 @@ def _load_preprocessing_project_from_path(self, path: str) -> None: "Open project", "Some linked input files are missing and were skipped:\n" + "\n".join(missing_paths[:12]), ) + self._reset_pre_history_snapshot() self._show_status_message(f"Preprocessing project loaded: {os.path.basename(path)}", 5000) def _restore_file_selection(self, selected_paths: List[str], current_path: Optional[str]) -> None: @@ -4296,10 +4469,12 @@ def _on_main_tab_changed(self, index: int) -> None: def _on_artifact_overlay_toggled(self, visible: bool) -> None: self.plots.set_artifact_overlay_visible(bool(visible)) + self._record_pre_history_change() self._save_settings() def _on_artifact_thresholds_toggled(self, visible: bool) -> None: 
self.plots.set_artifact_thresholds_visible(bool(visible)) + self._record_pre_history_change() self._save_settings() def _normalize_app_theme_mode(self, value: object) -> str: @@ -4387,6 +4562,7 @@ def _on_pre_plot_style_changed(self, *_args) -> None: self.act_plot_grid.isChecked() if hasattr(self, "act_plot_grid") else True, persist=True, ) + self._record_pre_history_change() def _auto_range_for_processed(self, processed: ProcessedTrial) -> None: try: @@ -4559,6 +4735,7 @@ def _on_time_window_changed(self) -> None: checked_auto = [r for r in auto_win if not any(self._regions_match(r, ig) for ig in ignore_win)] self.artifact_panel.set_auto_regions(auto_win, checked_regions=checked_auto) self.artifact_panel.set_regions(manual_win) + self._record_pre_history_change() self._update_raw_plot() self._trigger_preview() self._update_plot_status() @@ -4961,6 +5138,8 @@ def _artifact_param_signature(self, params: ProcessingParams) -> Tuple[object, . ) def _on_params_changed(self) -> None: + if self._pre_history_restoring: + return try: params = self.param_panel.get_params() except Exception: @@ -4984,6 +5163,7 @@ def _on_params_changed(self) -> None: self._update_raw_plot(preserve_view=True) except Exception: pass + self._record_pre_history_change() self._trigger_preview(preserve_view=True) def _trigger_preview(self, preserve_view: bool = False) -> None: @@ -5141,6 +5321,7 @@ def _add_manual_region_from_selector(self) -> None: self._manual_regions_by_key[key] = regs start_s, end_s = self._time_window_bounds() self.artifact_panel.set_regions(self._clip_regions_to_window(regs, start_s, end_s)) + self._record_pre_history_change() self._trigger_preview(preserve_view=True) def _add_manual_region_from_drag(self, t0: float, t1: float) -> None: @@ -5168,6 +5349,7 @@ def _clear_manual_regions_current(self) -> None: self._manual_exclude_by_key[key] = [] self._pending_box_region_by_key.pop(key, None) self.artifact_panel.set_regions([]) + self._record_pre_history_change() self._trigger_preview(preserve_view=True) def _request_box_select(self, callback: Callable[[float, float], None]) -> None: @@ -5227,6 +5409,7 @@ def _assign_pending_box_to_artifact(self) -> None: self._manual_regions_by_key[key] = regs start_s, end_s = self._time_window_bounds() self.artifact_panel.set_regions(self._clip_regions_to_window(regs, start_s, end_s)) + self._record_pre_history_change() self._trigger_preview(preserve_view=True) def _assign_pending_box_to_cut(self) -> None: @@ -5239,6 +5422,7 @@ def _assign_pending_box_to_cut(self) -> None: regs.sort(key=lambda x: x[0]) self._cutout_regions_by_key[key] = regs self._last_processed.clear() + self._record_pre_history_change() self._update_raw_plot() self._trigger_preview() @@ -5255,6 +5439,7 @@ def _assign_pending_box_to_section(self) -> None: }) sections.sort(key=lambda sec: float(sec.get("start", 0.0))) self._sections_by_key[key] = sections + self._record_pre_history_change() self._show_status_message(f"Section added: {region[0]:.3f}s to {region[1]:.3f}s") def _show_box_selection_context_menu(self) -> None: @@ -5295,6 +5480,7 @@ def _contains(target: Tuple[float, float], arr: List[Tuple[float, float]]) -> bo prev_ignore = self._manual_exclude_by_key.get(key, []) self._manual_regions_by_key[key] = self._merge_regions_with_window(prev_manual, manual_add, start_s, end_s) self._manual_exclude_by_key[key] = self._merge_regions_with_window(prev_ignore, manual_ignore, start_s, end_s) + self._record_pre_history_change() self._trigger_preview(preserve_view=True) def _toggle_artifacts_panel(self) 
-> None: From e5f78dc1ee44d64b1420a8cfbf0da9d8ff62f440 Mon Sep 17 00:00:00 2001 From: andrianj Date: Thu, 7 May 2026 18:51:36 +0200 Subject: [PATCH 03/14] Improve temporal GLM behavior predictors --- pyBer/gui_postprocessing.py | 40 ++- pyBer/temporal_modeling.py | 651 ++++++++++++++++++++++++++++++++++-- 2 files changed, 641 insertions(+), 50 deletions(-) diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py index 0a68d7c..3cc04ba 100644 --- a/pyBer/gui_postprocessing.py +++ b/pyBer/gui_postprocessing.py @@ -2452,6 +2452,7 @@ def _setup_section_popups(self) -> None: dock = QtWidgets.QDockWidget(title, host) dock.setObjectName(f"post.{key}.dock") + dock.setWindowModality(QtCore.Qt.WindowModality.NonModal) dock.setAllowedAreas(QtCore.Qt.DockWidgetArea.AllDockWidgetAreas) dock.setMinimumWidth(320) dock.setFeatures( @@ -2649,7 +2650,8 @@ def _toggle_section_popup(self, key: str, checked: bool) -> None: self._section_popup_initialized.add(key) dock.show() dock.raise_() - dock.activateWindow() + if key != "temporal": + dock.activateWindow() self._last_opened_section = key else: dock.hide() @@ -3283,6 +3285,24 @@ def _load_behavior_paths(self, paths: List[str], replace: bool) -> None: self._project_dirty = True self._update_data_availability() self._update_status_strip() + self._sync_temporal_modeling_context() + + def _sync_temporal_modeling_context(self) -> None: + if not hasattr(self, "section_temporal"): + return + try: + self.section_temporal.set_data( + processed_trials=self._processed, + psth_mat=self._last_mat, + psth_tvec=self._last_tvec, + event_times=self._last_events, + file_ids=self._all_file_ids, + per_file_mats=self._per_file_mats, + behavior_sources=self._behavior_sources, + event_rows=self._last_event_rows, + ) + except Exception: + _LOG.debug("Could not sync temporal modeling context", exc_info=True) def _load_processed_paths(self, paths: List[str], replace: bool) -> None: loaded: List[ProcessedTrial] = [] @@ -4191,6 +4211,7 @@ def _refresh_behavior_list(self) -> None: self._refresh_spatial_columns() self._compute_spatial_heatmap() self._update_data_availability() + self._sync_temporal_modeling_context() return behavior_names: set[str] = set() for info in self._behavior_sources.values(): @@ -4213,6 +4234,7 @@ def _refresh_behavior_list(self) -> None: self._compute_spatial_heatmap() self._update_data_availability() self._update_status_strip() + self._sync_temporal_modeling_context() def _guess_spatial_column(self, columns: List[str], axis: str) -> Optional[str]: if not columns: @@ -6386,6 +6408,7 @@ def _compute_psth(self) -> None: if hasattr(self, "lbl_global_metrics"): self.lbl_global_metrics.setText("Global metrics: -") self._update_status_strip() + self._sync_temporal_modeling_context() return # Update trace preview each time (also updates event lines) @@ -6474,6 +6497,7 @@ def _compute_psth(self) -> None: self._last_event_rows = [] self._last_durations = np.array([], float) self._update_status_strip() + self._sync_temporal_modeling_context() return mat_events = np.vstack(mats) @@ -6487,6 +6511,7 @@ def _compute_psth(self) -> None: self._last_event_rows = [] self._last_durations = np.array([], float) self._update_status_strip() + self._sync_temporal_modeling_context() return mat_display = self._group_mat display_labels = self._group_labels @@ -6534,18 +6559,7 @@ def _compute_psth(self) -> None: self._update_metric_regions() self._update_status_strip() self._save_settings() - # Feed data to temporal modeling widget - try: - self.section_temporal.set_data( - 
processed_trials=self._processed, - psth_mat=mat_display, - psth_tvec=tvec, - event_times=self._last_events, - file_ids=self._all_file_ids, - per_file_mats=self._per_file_mats, - ) - except Exception: - pass + self._sync_temporal_modeling_context() except Exception as e: self.statusUpdate.emit(f"Postprocessing error: {e}", 5000) self._update_status_strip() diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index bc1c0ba..0dc5e68 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -15,6 +15,7 @@ import logging import os +import re import traceback from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -87,6 +88,8 @@ def _init_r(): _R_READY = False +_BEHAVIOR_PARSE_BINARY = "binary_columns" +_BEHAVIOR_PARSE_TIMESTAMPS = "timestamp_columns" # ============================================================================ @@ -141,6 +144,7 @@ class GLMResult: predictor_names: List[str] kernels: Dict[str, np.ndarray] # predictor -> (n_kernel_samples,) kernel_tvec: np.ndarray # time vector for kernel x-axis + time: np.ndarray # time vector used for the fitted trace y_pred: np.ndarray # predicted trace y_actual: np.ndarray # actual trace residuals: np.ndarray @@ -158,14 +162,79 @@ class ContinuousGLM: def __init__(self): self._result: Optional[GLMResult] = None + @staticmethod + def _predictor_vector(time: np.ndarray, predictor: Any) -> np.ndarray: + """Convert an event or continuous predictor spec into a model-time vector.""" + t = np.asarray(time, float) + T = int(t.size) + vec = np.zeros(T, float) + if T == 0: + return vec + + if isinstance(predictor, dict): + kind = str(predictor.get("kind", "events")).strip().lower() + if kind in {"vector", "sampled"}: + values = np.asarray(predictor.get("values", []), float) + if values.size != T: + return vec + vec = values.astype(float, copy=True) + vec[~np.isfinite(vec)] = 0.0 + return vec + if kind == "continuous": + pt = np.asarray(predictor.get("time", []), float) + values = np.asarray(predictor.get("values", []), float) + if pt.size != values.size: + n = min(pt.size, values.size) + pt = pt[:n] + values = values[:n] + m = np.isfinite(pt) & np.isfinite(values) + pt = pt[m] + values = values[m] + if pt.size < 2 or values.size < 2: + return vec + order = np.argsort(pt) + pt = pt[order] + values = values[order] + keep = np.concatenate([[True], np.diff(pt) > 0]) + pt = pt[keep] + values = values[keep] + if pt.size < 2: + return vec + interp = np.interp(t, pt, values, left=np.nan, right=np.nan) + finite = np.isfinite(interp) + if not np.any(finite): + return vec + centered = interp.astype(float) + mean = float(np.nanmean(centered[finite])) + std = float(np.nanstd(centered[finite])) + if np.isfinite(std) and std > 1e-12: + centered = (centered - mean) / std + else: + centered = centered - mean + centered[~np.isfinite(centered)] = 0.0 + return centered + ev_times = predictor.get("events", predictor.get("times", [])) + else: + ev_times = predictor + + ev_times = np.asarray(ev_times, float) + ev_times = ev_times[np.isfinite(ev_times)] + if ev_times.size == 0: + return vec + ev_idx = np.searchsorted(t, ev_times) + ev_idx = ev_idx[(ev_idx >= 0) & (ev_idx < T)] + for idx in ev_idx: + vec[int(idx)] += 1.0 + return vec + @staticmethod def build_design_matrix( time: np.ndarray, - predictors: Dict[str, np.ndarray], + predictors: Dict[str, Any], kernel_window: Tuple[float, float], n_basis: int = 8, basis_type: str = "raised_cosine", - ) -> Tuple[np.ndarray, List[str], int]: + ) -> 
Tuple[np.ndarray, List[str], int, List[str]]: """ Build a (T x P) design matrix from event times. @@ -183,10 +252,15 @@ def build_design_matrix( col_names : column labels n_basis : basis count (for later kernel extraction) """ - dt = np.median(np.diff(time)) + time = np.asarray(time, float) + if time.size < 3: + raise ValueError("Need at least 3 time samples for GLM fitting.") + dt = float(np.nanmedian(np.diff(time))) + if not np.isfinite(dt) or dt <= 0: + raise ValueError("GLM time vector must be strictly increasing.") pre_samp = int(round(abs(kernel_window[0]) / dt)) post_samp = int(round(abs(kernel_window[1]) / dt)) - kernel_len = pre_samp + post_samp + kernel_len = max(2, pre_samp + post_samp) if basis_type == "bspline": B = _bspline_basis(n_basis, kernel_len) @@ -198,19 +272,16 @@ def build_design_matrix( T = len(time) col_names: List[str] = [] X_parts: List[np.ndarray] = [] + used_predictors: List[str] = [] - for pred_name, ev_times in predictors.items(): - ev_times = np.asarray(ev_times, float) - ev_times = ev_times[np.isfinite(ev_times)] - - # Convert event times to sample indices - ev_idx = np.searchsorted(time, ev_times) - ev_idx = ev_idx[(ev_idx >= 0) & (ev_idx < T)] - - # Build impulse vector - impulse = np.zeros(T, float) - for idx in ev_idx: - impulse[idx] = 1.0 + for pred_name, pred_spec in predictors.items(): + input_vec = ContinuousGLM._predictor_vector(time, pred_spec) + if input_vec.size != T or not np.any(np.isfinite(input_vec)): + continue + input_vec = np.asarray(input_vec, float) + input_vec[~np.isfinite(input_vec)] = 0.0 + if not np.any(np.abs(input_vec) > 1e-12): + continue # Convolve impulse with each basis function part = np.zeros((T, n_basis), float) @@ -218,7 +289,7 @@ def build_design_matrix( # Pad the basis to align with pre_samp offset kernel = np.zeros(kernel_len) kernel[:] = B[:, b] - conv = np.convolve(impulse, kernel, mode="full")[:T] + conv = np.convolve(input_vec, kernel, mode="full")[:T] # Shift so that the kernel starts at -pre_samp if pre_samp > 0: part[:, b] = np.roll(conv, -pre_samp) @@ -227,6 +298,7 @@ def build_design_matrix( part[:, b] = conv X_parts.append(part) + used_predictors.append(pred_name) for b in range(n_basis): col_names.append(f"{pred_name}_b{b}") @@ -234,13 +306,13 @@ def build_design_matrix( # Add intercept X = np.column_stack([np.ones(T), X]) col_names.insert(0, "intercept") - return X, col_names, n_basis + return X, col_names, n_basis, used_predictors def fit( self, time: np.ndarray, signal: np.ndarray, - predictors: Dict[str, np.ndarray], + predictors: Dict[str, Any], kernel_window: Tuple[float, float] = (-1.0, 3.0), n_basis: int = 8, basis_type: str = "raised_cosine", @@ -251,9 +323,11 @@ def fit( time = np.asarray(time, float) signal = np.asarray(signal, float) - X, col_names, n_b = self.build_design_matrix( + X, col_names, n_b, used_predictors = self.build_design_matrix( time, predictors, kernel_window, n_basis, basis_type, ) + if not used_predictors: + raise ValueError("No usable predictors after alignment to the fitted trace.") # Mask NaN samples valid = np.isfinite(signal) @@ -285,7 +359,7 @@ def fit( dt = np.median(np.diff(time)) pre_samp = int(round(abs(kernel_window[0]) / dt)) post_samp = int(round(abs(kernel_window[1]) / dt)) - kernel_len = pre_samp + post_samp + kernel_len = max(2, pre_samp + post_samp) if basis_type == "bspline": B = _bspline_basis(n_b, kernel_len) @@ -297,15 +371,16 @@ def fit( kernel_tvec = np.linspace(kernel_window[0], kernel_window[1], kernel_len) kernels: Dict[str, np.ndarray] = {} idx = 1 # 
skip intercept - for pred_name in predictors: + for pred_name in used_predictors: w = beta[idx:idx + n_b] kernels[pred_name] = B @ w idx += n_b self._result = GLMResult( - predictor_names=list(predictors.keys()), + predictor_names=list(used_predictors), kernels=kernels, kernel_tvec=kernel_tvec, + time=time, y_pred=y_pred, y_actual=signal, residuals=residuals, @@ -658,6 +733,10 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self._psth_tvec: Optional[np.ndarray] = None self._event_times: Optional[np.ndarray] = None self._file_ids: List[str] = [] + self._per_file_mats: Dict[str, Tuple[np.ndarray, np.ndarray]] = {} + self._behavior_sources: Dict[str, Dict[str, Any]] = {} + self._event_rows: List[Dict[str, object]] = [] + self._predictor_catalog: Dict[str, Dict[str, Any]] = {} self._build_compact_ui() self._connect_signals() @@ -768,6 +847,10 @@ def _build_ui(self): pl = QtWidgets.QVBoxLayout(self.grp_predictors) pl.setSpacing(4) + self.combo_available_predictors = QtWidgets.QComboBox() + self.combo_available_predictors.setToolTip("Choose from loaded event, behavior, and numeric behavior columns.") + pl.addWidget(self.combo_available_predictors) + self.list_predictors = QtWidgets.QListWidget() self.list_predictors.setMaximumHeight(100) pl.addWidget(self.list_predictors) @@ -783,7 +866,7 @@ def _build_ui(self): pl.addLayout(pred_btn_row) self.lbl_predictor_hint = QtWidgets.QLabel( - "Predictors are auto-populated from DIO / behavior events when PSTH is computed." + "Choose predictors from loaded DIO events, behavior states, behavior onsets, or numeric behavior columns." ) self.lbl_predictor_hint.setProperty("class", "hint") self.lbl_predictor_hint.setWordWrap(True) @@ -1021,6 +1104,10 @@ def _build_predictor_page(self): pl.setContentsMargins(12, 18, 12, 12) pl.setSpacing(8) + self.combo_available_predictors = QtWidgets.QComboBox() + self.combo_available_predictors.setToolTip("Choose from loaded event, behavior, and numeric behavior columns.") + pl.addWidget(self.combo_available_predictors) + self.list_predictors = QtWidgets.QListWidget() self.list_predictors.setMinimumHeight(220) pl.addWidget(self.list_predictors) @@ -1034,7 +1121,7 @@ def _build_predictor_page(self): pl.addLayout(row) self.lbl_predictor_hint = QtWidgets.QLabel( - "Predictors are populated from DIO or behavior events when PSTH is computed." + "Predictors are populated from loaded DIO events, behavior states, behavior onsets, and numeric behavior columns." 
) self.lbl_predictor_hint.setProperty("class", "muted") self.lbl_predictor_hint.setWordWrap(True) @@ -1153,6 +1240,8 @@ def set_data( event_times: Optional[np.ndarray] = None, file_ids: Optional[List[str]] = None, per_file_mats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None, + behavior_sources: Optional[Dict[str, Dict[str, Any]]] = None, + event_rows: Optional[List[Dict[str, object]]] = None, ): """Push data from the host panel into this widget.""" self._processed_trials = processed_trials or [] @@ -1161,10 +1250,10 @@ def set_data( self._event_times = event_times self._file_ids = file_ids or [] self._per_file_mats = per_file_mats or {} + self._behavior_sources = dict(behavior_sources or {}) + self._event_rows = list(event_rows or []) + self._refresh_predictor_catalog() - # Auto-populate predictors for GLM - if self.list_predictors.count() == 0 and event_times is not None and len(event_times): - self.list_predictors.addItem("events") n_trials = len(self._processed_trials) psth_shape = tuple(np.shape(psth_mat)) if psth_mat is not None else None bits = [f"Processed recordings: {n_trials}"] @@ -1172,8 +1261,436 @@ def set_data( bits.append(f"PSTH matrix: {psth_shape[0]} x {psth_shape[1]}") if event_times is not None: bits.append(f"Events: {len(event_times)}") + if self._behavior_sources: + bits.append(f"Behavior files: {len(self._behavior_sources)}") + if self._predictor_catalog: + bits.append(f"Available predictors: {len(self._predictor_catalog)}") self.lbl_data_status.setText("\n".join(bits)) + # ------------------------------------------------------------------ + # Predictor catalog and extraction + # ------------------------------------------------------------------ + + @staticmethod + def _proc_file_id(proc: Any, fallback: str = "import") -> str: + path = str(getattr(proc, "path", "") or "").strip() + if not path: + return fallback + stem = os.path.splitext(os.path.basename(path))[0] + return stem or fallback + + @staticmethod + def _clean_id(value: object) -> str: + text = str(value or "").strip().lower() + text = re.sub(r"_ain0*[0-9]+$", "", text) + text = re.sub(r"[^a-z0-9]+", "", text) + return text + + def _ids_match(self, a: object, b: object) -> bool: + aa = self._clean_id(a) + bb = self._clean_id(b) + return bool(aa and bb and aa == bb) + + def _predictor_label(self, key: str) -> str: + entry = self._predictor_catalog.get(str(key), {}) + label = str(entry.get("label", "") or "").strip() + if label: + return label + if key == "events": + return "PSTH alignment events" + if key == "dio": + return "DIO rising edges" + if key.startswith("trigger::"): + return f"Trigger: {key.split('::', 1)[1]}" + if key.startswith("behavior_event::"): + return f"Behavior onset: {key.split('::', 1)[1]}" + if key.startswith("behavior_state::"): + return f"Behavior state: {key.split('::', 1)[1]}" + if key.startswith("behavior_cont::"): + return f"Numeric column: {key.split('::', 1)[1]}" + return str(key) + + def _selected_predictor_keys(self) -> List[str]: + keys: List[str] = [] + for i in range(self.list_predictors.count()): + item = self.list_predictors.item(i) + key = item.data(QtCore.Qt.ItemDataRole.UserRole) + if not isinstance(key, str) or not key.strip(): + key = item.text().strip() + for cat_key, entry in self._predictor_catalog.items(): + if key == cat_key or key == str(entry.get("label", "")): + key = cat_key + break + key = str(key).strip() + if key and key not in keys: + keys.append(key) + return keys + + def _add_predictor_item(self, key: str) -> bool: + key = str(key or 
"").strip() + if not key: + return False + for existing in self._selected_predictor_keys(): + if existing == key: + return False + item = QtWidgets.QListWidgetItem(self._predictor_label(key)) + item.setData(QtCore.Qt.ItemDataRole.UserRole, key) + self.list_predictors.addItem(item) + return True + + def _refresh_predictor_combo(self) -> None: + if not hasattr(self, "combo_available_predictors"): + return + selected = self.combo_available_predictors.currentData(QtCore.Qt.ItemDataRole.UserRole) + self.combo_available_predictors.blockSignals(True) + self.combo_available_predictors.clear() + for key, entry in self._predictor_catalog.items(): + self.combo_available_predictors.addItem(str(entry.get("label", key)), key) + if not self._predictor_catalog: + self.combo_available_predictors.addItem("No predictors available yet", "") + if isinstance(selected, str) and selected: + idx = self.combo_available_predictors.findData(selected, QtCore.Qt.ItemDataRole.UserRole) + if idx >= 0: + self.combo_available_predictors.setCurrentIndex(idx) + self.combo_available_predictors.blockSignals(False) + + def _refresh_predictor_catalog(self) -> None: + catalog: Dict[str, Dict[str, Any]] = {} + if self._event_times is not None and len(np.asarray(self._event_times, float)): + catalog["events"] = {"kind": "event", "label": "PSTH alignment events"} + if self._event_rows: + catalog["events"] = {"kind": "event", "label": "PSTH alignment events"} + + trigger_names: set[str] = set() + has_dio = False + for proc in self._processed_trials: + if getattr(proc, "dio", None) is not None: + has_dio = True + triggers = getattr(proc, "triggers", None) or {} + if isinstance(triggers, dict): + trigger_names.update(str(k) for k in triggers.keys() if str(k).strip()) + if has_dio: + catalog["dio"] = {"kind": "event", "label": "DIO rising edges"} + for name in sorted(trigger_names): + catalog[f"trigger::{name}"] = {"kind": "event", "name": name, "label": f"Trigger: {name}"} + + behavior_names: set[str] = set() + state_names: set[str] = set() + continuous_names: set[str] = set() + for info in self._behavior_sources.values(): + if not isinstance(info, dict): + continue + kind = str(info.get("kind", _BEHAVIOR_PARSE_BINARY)) + behaviors = info.get("behaviors") or {} + for name in behaviors.keys(): + clean = str(name).strip() + if not clean: + continue + behavior_names.add(clean) + if kind != _BEHAVIOR_PARSE_TIMESTAMPS: + state_names.add(clean) + trajectory = info.get("trajectory") or {} + for name in trajectory.keys(): + clean = str(name).strip() + if clean: + continuous_names.add(clean) + for name in sorted(behavior_names): + catalog[f"behavior_event::{name}"] = { + "kind": "behavior_event", + "name": name, + "label": f"Behavior onset: {name}", + } + for name in sorted(state_names): + catalog[f"behavior_state::{name}"] = { + "kind": "continuous", + "name": name, + "label": f"Behavior state: {name}", + } + for name in sorted(continuous_names): + catalog[f"behavior_cont::{name}"] = { + "kind": "continuous", + "name": name, + "label": f"Numeric column: {name}", + } + + previous_keys = self._selected_predictor_keys() if hasattr(self, "list_predictors") else [] + self._predictor_catalog = catalog + self._refresh_predictor_combo() + if hasattr(self, "list_predictors"): + self.list_predictors.clear() + for key in previous_keys: + if key in catalog: + self._add_predictor_item(key) + if self.list_predictors.count() == 0 and "events" in catalog: + self._add_predictor_item("events") + + def _behavior_source_for_proc(self, proc: Any) -> 
Optional[Dict[str, Any]]: + if not self._behavior_sources: + return None + file_id = self._proc_file_id(proc) + info = self._behavior_sources.get(file_id) + if info is not None: + return info + for key, val in self._behavior_sources.items(): + if self._ids_match(key, file_id): + return val + try: + idx = next(i for i, p in enumerate(self._processed_trials) if (p is proc) or (getattr(p, "path", "") == getattr(proc, "path", ""))) + except StopIteration: + idx = None + if idx is not None: + keys = list(self._behavior_sources.keys()) + if 0 <= idx < len(keys): + return self._behavior_sources.get(keys[idx]) + if len(self._behavior_sources) == 1: + return next(iter(self._behavior_sources.values())) + return None + + @staticmethod + def _event_vector(time: np.ndarray, events: np.ndarray) -> np.ndarray: + t = np.asarray(time, float) + out = np.zeros(t.size, float) + ev = np.asarray(events, float) + ev = ev[np.isfinite(ev)] + if t.size == 0 or ev.size == 0: + return out + idx = np.searchsorted(t, ev) + idx = idx[(idx >= 0) & (idx < t.size)] + for i in idx: + out[int(i)] += 1.0 + return out + + def _events_for_proc(self, proc: Any, time: np.ndarray) -> np.ndarray: + file_id = self._proc_file_id(proc) + if self._event_rows: + vals = [] + for row in self._event_rows: + if self._ids_match(row.get("file_id", ""), file_id): + try: + vals.append(float(row.get("event_time_sec", np.nan))) + except Exception: + pass + return np.asarray(vals, float) + if len(self._processed_trials) == 1 and self._event_times is not None: + return np.asarray(self._event_times, float) + return np.array([], float) + + @staticmethod + def _rising_edges_from_signal(time: np.ndarray, signal: np.ndarray) -> np.ndarray: + t = np.asarray(time, float) + x = np.asarray(signal, float) + if t.size < 2 or x.size != t.size: + return np.array([], float) + b = x > 0.5 + idx = np.where((~b[:-1]) & (b[1:]))[0] + 1 + return t[idx] + + @staticmethod + def _behavior_onsets(info: Dict[str, Any], name: str) -> np.ndarray: + behaviors = info.get("behaviors") or {} + if name not in behaviors: + return np.array([], float) + kind = str(info.get("kind", _BEHAVIOR_PARSE_BINARY)) + if kind == _BEHAVIOR_PARSE_TIMESTAMPS: + ev = np.asarray(behaviors[name], float) + ev = ev[np.isfinite(ev)] + return np.sort(np.unique(ev)) + t = np.asarray(info.get("time", np.array([], float)), float) + x = np.asarray(behaviors[name], float) + if t.size < 1 or x.size != t.size: + return np.array([], float) + b = x > 0.5 + on = np.where((~b[:-1]) & (b[1:]))[0] + 1 if b.size > 1 else np.array([], int) + if b.size and bool(b[0]): + on = np.concatenate([[0], on]) + return t[on] + + @staticmethod + def _interp_to_time(target_time: np.ndarray, source_time: np.ndarray, values: np.ndarray) -> np.ndarray: + target_time = np.asarray(target_time, float) + source_time = np.asarray(source_time, float) + values = np.asarray(values, float) + out = np.zeros(target_time.size, float) + if source_time.size != values.size: + n = min(source_time.size, values.size) + source_time = source_time[:n] + values = values[:n] + m = np.isfinite(source_time) & np.isfinite(values) + source_time = source_time[m] + values = values[m] + if source_time.size == target_time.size and np.allclose(source_time, target_time, equal_nan=False): + out = values.astype(float, copy=True) + out[~np.isfinite(out)] = 0.0 + return out + if source_time.size < 2: + if values.size == target_time.size: + out = values.astype(float, copy=True) + out[~np.isfinite(out)] = 0.0 + return out + order = np.argsort(source_time) + source_time = 
source_time[order] + values = values[order] + keep = np.concatenate([[True], np.diff(source_time) > 0]) + source_time = source_time[keep] + values = values[keep] + if source_time.size < 2: + return out + interp = np.interp(target_time, source_time, values, left=np.nan, right=np.nan) + interp[~np.isfinite(interp)] = 0.0 + return interp + + def _predictor_vector_for_proc(self, key: str, proc: Any, time: np.ndarray) -> Tuple[np.ndarray, str]: + entry = self._predictor_catalog.get(key, {}) + if key == "events": + return self._event_vector(time, self._events_for_proc(proc, time)), "event" + if key == "dio": + dio = getattr(proc, "dio", None) + if dio is None: + return np.zeros(time.size, float), "event" + return self._event_vector(time, self._rising_edges_from_signal(time, dio)), "event" + if key.startswith("trigger::"): + name = key.split("::", 1)[1] + triggers = getattr(proc, "triggers", None) or {} + sig = triggers.get(name) if isinstance(triggers, dict) else None + if sig is None: + return np.zeros(time.size, float), "event" + return self._event_vector(time, self._rising_edges_from_signal(time, sig)), "event" + + info = self._behavior_source_for_proc(proc) + if not isinstance(info, dict): + return np.zeros(time.size, float), str(entry.get("kind", "event")) + name = str(entry.get("name", "") or key.split("::")[-1]) + if key.startswith("behavior_event::"): + return self._event_vector(time, self._behavior_onsets(info, name)), "event" + if key.startswith("behavior_state::"): + behaviors = info.get("behaviors") or {} + values = behaviors.get(name) + source_time = np.asarray(info.get("time", np.array([], float)), float) + if values is None or source_time.size == 0: + return np.zeros(time.size, float), "continuous" + return self._interp_to_time(time, source_time, values), "continuous" + if key.startswith("behavior_cont::"): + trajectory = info.get("trajectory") or {} + values = trajectory.get(name) + source_time = np.asarray(info.get("trajectory_time", np.array([], float)), float) + if values is None or source_time.size == 0: + return np.zeros(time.size, float), "continuous" + return self._interp_to_time(time, source_time, values), "continuous" + return np.zeros(time.size, float), str(entry.get("kind", "event")) + + def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: + selected = self._selected_predictor_keys() + if not selected: + return {"error": "Choose at least one predictor before fitting."} + if not self._processed_trials: + return {"error": "No processed recordings are loaded."} + + kernel_span = abs(float(self.spin_kernel_pre.value())) + abs(float(self.spin_kernel_post.value())) + segments: List[Tuple[str, np.ndarray, np.ndarray, Any]] = [] + dropped_records: List[str] = [] + for idx, proc in enumerate(self._processed_trials): + t = np.asarray(getattr(proc, "time", np.array([], float)), float) + y_raw = getattr(proc, "output", None) + y = np.asarray(y_raw, float) if y_raw is not None else np.array([], float) + file_id = self._proc_file_id(proc, fallback=f"file_{idx + 1}") + if t.size < 3 or y.size != t.size: + dropped_records.append(file_id) + continue + m = np.isfinite(t) + t = t[m] + y = y[m] + if t.size < 3: + dropped_records.append(file_id) + continue + order = np.argsort(t) + t = t[order] + y = y[order] + keep = np.concatenate([[True], np.diff(t) > 0]) + t = t[keep] + y = y[keep] + if t.size < 3: + dropped_records.append(file_id) + continue + segments.append((file_id, t, y, proc)) + + if not segments: + return {"error": "No recordings have usable time and output 
traces."} + + time_parts: List[np.ndarray] = [] + signal_parts: List[np.ndarray] = [] + vec_parts: Dict[str, List[np.ndarray]] = {key: [] for key in selected} + pred_types: Dict[str, str] = {} + used_records: List[str] = [] + offset = 0.0 + for seg_idx, (file_id, t, y, proc) in enumerate(segments): + dt = float(np.nanmedian(np.diff(t))) + if not np.isfinite(dt) or dt <= 0: + dropped_records.append(file_id) + continue + t_shift = (t - float(t[0])) + offset + time_parts.append(t_shift) + signal_parts.append(y.astype(float, copy=True)) + used_records.append(file_id) + for key in selected: + vec, ptype = self._predictor_vector_for_proc(key, proc, t) + vec = np.asarray(vec, float) + if vec.size != t.size: + vec = np.zeros(t.size, float) + vec[~np.isfinite(vec)] = 0.0 + vec_parts[key].append(vec) + pred_types[key] = ptype + + if seg_idx < len(segments) - 1: + pad_n = max(2, int(np.ceil((kernel_span + dt) / dt))) + pad_t = t_shift[-1] + dt * np.arange(1, pad_n + 1, dtype=float) + time_parts.append(pad_t) + signal_parts.append(np.full(pad_n, np.nan, float)) + for key in selected: + vec_parts[key].append(np.zeros(pad_n, float)) + offset = float(pad_t[-1] + dt) + + if not time_parts: + return {"error": "No recordings could be aligned for GLM fitting."} + time = np.concatenate(time_parts) + signal = np.concatenate(signal_parts) + valid_signal = np.isfinite(signal) + predictors: Dict[str, Dict[str, Any]] = {} + dropped_predictors: List[str] = [] + for key in selected: + vec = np.concatenate(vec_parts.get(key, [np.zeros(time.size, float)])) + vec = vec.astype(float, copy=True) + vec[~np.isfinite(vec)] = 0.0 + if pred_types.get(key) == "continuous": + finite = valid_signal & np.isfinite(vec) + vals = vec[finite] + if vals.size: + mean = float(np.nanmean(vals)) + std = float(np.nanstd(vals)) + if np.isfinite(std) and std > 1e-12: + vec[finite] = (vec[finite] - mean) / std + else: + vec[finite] = vec[finite] - mean + vec[~valid_signal] = 0.0 + if not np.any(np.abs(vec[valid_signal]) > 1e-12): + dropped_predictors.append(self._predictor_label(key)) + continue + predictors[key] = {"kind": "vector", "values": vec} + + if not predictors: + return { + "error": "The selected predictors contain no usable events or variation for the loaded recordings.", + "dropped_predictors": dropped_predictors, + } + return { + "time": time, + "signal": signal, + "predictors": predictors, + "used_records": used_records, + "dropped_records": dropped_records, + "dropped_predictors": dropped_predictors, + "valid_samples": int(np.sum(valid_signal)), + } + # ------------------------------------------------------------------ # Slots # ------------------------------------------------------------------ @@ -1204,11 +1721,15 @@ def _on_model_type_changed(self, index: int): self.lbl_flmm_status.setStyleSheet("color: #f5a97f;") def _on_add_predictor(self): - name, ok = QtWidgets.QInputDialog.getText( - self, "Add predictor", "Predictor name (must match a column in design):" - ) - if ok and name.strip(): - self.list_predictors.addItem(name.strip()) + key = "" + if hasattr(self, "combo_available_predictors"): + data = self.combo_available_predictors.currentData(QtCore.Qt.ItemDataRole.UserRole) + key = str(data or "").strip() + if not key: + self.statusMessage.emit("No predictor is available yet. 
Load or compute behavior/events first.", 5000) + return + if self._add_predictor_item(key): + self.statusMessage.emit(f"Added predictor: {self._predictor_label(key)}", 3000) def _on_remove_predictor(self): sel = self.list_predictors.currentRow() @@ -1231,7 +1752,63 @@ def _on_fit_clicked(self): # GLM fit # ------------------------------------------------------------------ + def _fit_glm_catalog(self) -> None: + dataset = self._build_glm_dataset_from_selected_predictors() + if "error" in dataset: + msg = str(dataset.get("error", "Could not build GLM dataset.")) + dropped = dataset.get("dropped_predictors", []) or [] + if dropped: + msg = f"{msg}\nDropped: {', '.join(str(v) for v in dropped)}" + self.txt_summary.setPlainText(msg) + self.statusMessage.emit(msg.splitlines()[0], 7000) + self._select_control_page(1) + return + + basis_map = {"Raised cosine": "raised_cosine", "B-spline": "bspline", "FIR": "fir"} + reg_map = {"Ridge": "ridge", "Lasso": "lasso", "OLS": "ols"} + kernel_win = (self.spin_kernel_pre.value(), self.spin_kernel_post.value()) + result = self._glm.fit( + np.asarray(dataset["time"], float), + np.asarray(dataset["signal"], float), + dataset["predictors"], + kernel_window=kernel_win, + n_basis=self.spin_n_basis.value(), + basis_type=basis_map.get(self.combo_basis.currentText(), "raised_cosine"), + regularization=reg_map.get(self.combo_reg.currentText(), "ridge"), + alpha=self.spin_alpha.value(), + ) + self._glm_result = result + + used_labels = [self._predictor_label(k) for k in result.predictor_names] + dropped_predictors = dataset.get("dropped_predictors", []) or [] + used_records = dataset.get("used_records", []) or [] + dropped_records = dataset.get("dropped_records", []) or [] + record_preview = ", ".join(str(v) for v in used_records[:6]) + if len(used_records) > 6: + record_preview += "..." 
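The dataset builder above stitches every usable recording into one long trace, inserting NaN gaps at least one kernel span wide between recordings so that convolved kernels cannot bleed across file boundaries. A minimal NumPy sketch of that gap-padding idea (the function and argument names here are illustrative, not pyBer API):

```python
import numpy as np

def pad_recordings(segments, kernel_span):
    """Concatenate (time, signal) recordings, separated by NaN gaps at least
    one kernel span wide. Illustrative sketch; assumes each time vector is
    sorted with >= 2 samples."""
    time_parts, sig_parts = [], []
    offset = 0.0
    for i, (t, y) in enumerate(segments):
        dt = float(np.median(np.diff(t)))
        t_shift = (t - t[0]) + offset
        time_parts.append(t_shift)
        sig_parts.append(np.asarray(y, float))
        if i < len(segments) - 1:
            pad_n = max(2, int(np.ceil((kernel_span + dt) / dt)))
            pad_t = t_shift[-1] + dt * np.arange(1, pad_n + 1)
            time_parts.append(pad_t)
            sig_parts.append(np.full(pad_n, np.nan))  # NaN response marks the gap
            offset = float(pad_t[-1] + dt)
    return np.concatenate(time_parts), np.concatenate(sig_parts)
```

Because the padded samples are NaN in the response but zero in every predictor column, the `valid_signal` mask used during fitting drops the gap samples while keeping the design matrix aligned sample-for-sample.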
+ lines = [ + f"Continuous GLM - R^2 = {result.r2:.4f}", + f"Recordings used: {len(used_records)} ({record_preview})", + f"Samples fit: {int(dataset.get('valid_samples', 0))}", + f"Predictors: {', '.join(used_labels)}", + f"Basis: {self.combo_basis.currentText()}, n={self.spin_n_basis.value()}", + f"Regularization: {self.combo_reg.currentText()}, alpha={self.spin_alpha.value():.3f}", + ] + if dropped_predictors: + lines.append(f"Dropped predictors: {', '.join(str(v) for v in dropped_predictors)}") + if dropped_records: + lines.append(f"Dropped recordings: {', '.join(str(v) for v in dropped_records)}") + self.txt_summary.setPlainText("\n".join(lines)) + + self._plot_glm_kernels(result) + self._plot_glm_fit(result) + if hasattr(self, "tabs_workspace"): + self.tabs_workspace.setCurrentWidget(self.plot_kernel.parentWidget()) + self.statusMessage.emit(f"GLM fit complete - R^2 = {result.r2:.4f}", 5000) + def _fit_glm(self): + self._fit_glm_catalog() + return if not self._processed_trials: self.statusMessage.emit("No processed data — run preprocessing first.", 5000) return @@ -1308,7 +1885,7 @@ def _plot_glm_kernels(self, result: GLMResult): "#f5e0dc", "#89dceb", "#fab387"] for i, (name, kernel) in enumerate(result.kernels.items()): color = colors[i % len(colors)] - pw.plot(result.kernel_tvec, kernel, pen=pg.mkPen(color, width=2), name=name) + pw.plot(result.kernel_tvec, kernel, pen=pg.mkPen(color, width=2), name=self._predictor_label(name)) pw.setLabel("bottom", "Time", units="s") pw.setLabel("left", "Kernel weight") # Zero line @@ -1322,17 +1899,17 @@ def _plot_glm_fit(self, result: GLMResult): pw.getPlotItem().legend.clear() except Exception: pass - x = np.arange(result.y_actual.size) + x = np.asarray(result.time, float) if result.time is not None else np.arange(result.y_actual.size) pw.plot(x, result.y_actual, pen=pg.mkPen("#4b9df8", width=1.2), name="actual") pw.plot(x, result.y_pred, pen=pg.mkPen("#f5a97f", width=1.4), name="predicted") - pw.setLabel("bottom", "Sample") + pw.setLabel("bottom", "Time", units="s") pw.setLabel("left", "Signal") rw = self.plot_residuals rw.clear() rw.plot(x, result.residuals, pen=pg.mkPen("#ee99a0", width=1.1), name="residual") rw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) - rw.setLabel("bottom", "Sample") + rw.setLabel("bottom", "Time", units="s") rw.setLabel("left", "Residual") # ------------------------------------------------------------------ From b92df4b691a7df81bbc6747806e37b2a10e78044 Mon Sep 17 00:00:00 2001 From: andrianj Date: Thu, 7 May 2026 19:02:26 +0200 Subject: [PATCH 04/14] Support animal group temporal modeling --- pyBer/gui_postprocessing.py | 6 + pyBer/temporal_modeling.py | 214 +++++++++++++++++++++++++++++++++++- 2 files changed, 215 insertions(+), 5 deletions(-) diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py index 3cc04ba..7cb5627 100644 --- a/pyBer/gui_postprocessing.py +++ b/pyBer/gui_postprocessing.py @@ -3300,6 +3300,11 @@ def _sync_temporal_modeling_context(self) -> None: per_file_mats=self._per_file_mats, behavior_sources=self._behavior_sources, event_rows=self._last_event_rows, + group_mat=self._group_mat, + group_tvec=self._group_tvec, + group_labels=self._group_labels, + visual_mode=int(self.tab_visual_mode.currentIndex()) if hasattr(self, "tab_visual_mode") else 0, + group_mode=bool(self.tab_sources.currentIndex() == 1) if hasattr(self, "tab_sources") else False, ) except Exception: _LOG.debug("Could not sync temporal modeling context", exc_info=True) @@ -3594,6 
+3599,7 @@ def _rerender_visual_from_cache(self) -> None: self.plot_avg.setTitle("Average across trials +/- SEM") # Always refresh the trace preview to match the selected file self._update_trace_preview() + self._sync_temporal_modeling_context() def _update_data_availability(self) -> None: has_processed = bool(self._processed) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index 0dc5e68..267c2f7 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -441,7 +441,7 @@ def fit( mat: np.ndarray, tvec: np.ndarray, design: Dict[str, np.ndarray], - formula_fixed: str = "Y.obs ~ group", + formula_fixed: str = "Y.obs ~ 1", random_effects: str = "~1", group_var: str = "subject", parallel: bool = False, @@ -737,6 +737,11 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self._behavior_sources: Dict[str, Dict[str, Any]] = {} self._event_rows: List[Dict[str, object]] = [] self._predictor_catalog: Dict[str, Dict[str, Any]] = {} + self._group_mat: Optional[np.ndarray] = None + self._group_tvec: Optional[np.ndarray] = None + self._group_labels: List[str] = [] + self._visual_mode: int = 0 + self._group_mode: bool = False self._build_compact_ui() self._connect_signals() @@ -816,8 +821,8 @@ def _build_ui(self): fl = QtWidgets.QFormLayout(self.grp_flmm) fl.setSpacing(4) - self.edit_formula = QtWidgets.QLineEdit("Y.obs ~ group") - self.edit_formula.setPlaceholderText("e.g. Y.obs ~ group + condition") + self.edit_formula = QtWidgets.QLineEdit("Y.obs ~ 1") + self.edit_formula.setPlaceholderText("Leave as Y.obs ~ 1 to auto-use selected predictors") fl.addRow("Fixed formula:", self.edit_formula) self.edit_random = QtWidgets.QLineEdit("~1") @@ -1071,8 +1076,8 @@ def _build_model_page(self): self.lbl_flmm_status.setWordWrap(True) fl.addRow("Backend", self.lbl_flmm_status) - self.edit_formula = QtWidgets.QLineEdit("Y.obs ~ group") - self.edit_formula.setPlaceholderText("e.g. Y.obs ~ group + condition") + self.edit_formula = QtWidgets.QLineEdit("Y.obs ~ 1") + self.edit_formula.setPlaceholderText("Leave as Y.obs ~ 1 to auto-use selected predictors") fl.addRow("Fixed formula", self.edit_formula) self.edit_random = QtWidgets.QLineEdit("~1") self.edit_random.setPlaceholderText("e.g. 
~1 or ~time") @@ -1242,6 +1247,11 @@ def set_data( per_file_mats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None, behavior_sources: Optional[Dict[str, Dict[str, Any]]] = None, event_rows: Optional[List[Dict[str, object]]] = None, + group_mat: Optional[np.ndarray] = None, + group_tvec: Optional[np.ndarray] = None, + group_labels: Optional[List[str]] = None, + visual_mode: int = 0, + group_mode: bool = False, ): """Push data from the host panel into this widget.""" self._processed_trials = processed_trials or [] @@ -1252,6 +1262,11 @@ def set_data( self._per_file_mats = per_file_mats or {} self._behavior_sources = dict(behavior_sources or {}) self._event_rows = list(event_rows or []) + self._group_mat = np.asarray(group_mat, float) if group_mat is not None else None + self._group_tvec = np.asarray(group_tvec, float) if group_tvec is not None else None + self._group_labels = list(group_labels or []) + self._visual_mode = int(visual_mode) if isinstance(visual_mode, (int, np.integer)) else 0 + self._group_mode = bool(group_mode) self._refresh_predictor_catalog() n_trials = len(self._processed_trials) @@ -1263,6 +1278,8 @@ def set_data( bits.append(f"Events: {len(event_times)}") if self._behavior_sources: bits.append(f"Behavior files: {len(self._behavior_sources)}") + if self._group_mat is not None and self._group_labels: + bits.append(f"Group animals: {len(self._group_labels)}") if self._predictor_catalog: bits.append(f"Available predictors: {len(self._predictor_catalog)}") self.lbl_data_status.setText("\n".join(bits)) @@ -1691,6 +1708,119 @@ def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: "valid_samples": int(np.sum(valid_signal)), } + def _proc_for_file_id(self, file_id: str) -> Optional[Any]: + for proc in self._processed_trials: + if self._ids_match(self._proc_file_id(proc), file_id): + return proc + try: + idx = self._group_labels.index(file_id) + except ValueError: + idx = -1 + if 0 <= idx < len(self._processed_trials): + return self._processed_trials[idx] + return None + + @staticmethod + def _safe_design_name(label: str, used: set[str]) -> str: + base = re.sub(r"[^0-9A-Za-z_]+", "_", str(label or "").strip()) + base = re.sub(r"_+", "_", base).strip("_") or "predictor" + if base[0].isdigit(): + base = f"pred_{base}" + name = base + i = 2 + while name in used: + name = f"{base}_{i}" + i += 1 + used.add(name) + return name + + def _flmm_matrix_and_labels(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[str], str]: + group_mat = np.asarray(self._group_mat, float) if self._group_mat is not None else None + group_tvec = np.asarray(self._group_tvec, float) if self._group_tvec is not None else None + if ( + group_mat is not None + and group_tvec is not None + and group_mat.ndim == 2 + and group_mat.shape[0] >= 2 + and len(self._group_labels) == group_mat.shape[0] + ): + return group_mat, group_tvec, list(self._group_labels), "animal" + + mat = np.asarray(self._psth_mat, float) if self._psth_mat is not None else None + tvec = np.asarray(self._psth_tvec, float) if self._psth_tvec is not None else None + if mat is None or tvec is None or mat.ndim != 2: + return None, None, [], "none" + if len(self._file_ids) == mat.shape[0]: + labels = list(self._file_ids) + scope = "animal" + else: + labels = [f"trial_{i + 1}" for i in range(mat.shape[0])] + scope = "trial" + return mat, tvec, labels, scope + + def _animal_covariate_value(self, key: str, file_id: str) -> float: + proc = self._proc_for_file_id(file_id) + if proc is None: + return np.nan + t = 
np.asarray(getattr(proc, "time", np.array([], float)), float) + if t.size < 2: + return np.nan + + if key in {"events", "dio"} or key.startswith("trigger::") or key.startswith("behavior_event::"): + vec, _ = self._predictor_vector_for_proc(key, proc, t) + return float(np.nansum(np.asarray(vec, float))) + + entry = self._predictor_catalog.get(key, {}) + name = str(entry.get("name", "") or key.split("::")[-1]) + info = self._behavior_source_for_proc(proc) + if not isinstance(info, dict): + return np.nan + if key.startswith("behavior_state::"): + behaviors = info.get("behaviors") or {} + values = np.asarray(behaviors.get(name, np.array([], float)), float) + finite = values[np.isfinite(values)] + if finite.size == 0: + return np.nan + return float(np.nanmean(finite > 0.5)) + if key.startswith("behavior_cont::"): + trajectory = info.get("trajectory") or {} + values = np.asarray(trajectory.get(name, np.array([], float)), float) + finite = values[np.isfinite(values)] + if finite.size == 0: + return np.nan + return float(np.nanmean(finite)) + return np.nan + + def _build_flmm_design( + self, + labels: List[str], + group_var: str, + ) -> Tuple[Dict[str, np.ndarray], List[str], List[str], Dict[str, str]]: + design: Dict[str, np.ndarray] = {group_var: np.asarray(labels, dtype=object)} + terms: List[str] = [] + dropped: List[str] = [] + term_labels: Dict[str, str] = {} + used_names = {group_var, "Y.obs", ".index", ".obs"} + + for key in self._selected_predictor_keys(): + values = np.asarray([self._animal_covariate_value(key, label) for label in labels], float) + finite = np.isfinite(values) + if np.sum(finite) < 2: + dropped.append(self._predictor_label(key)) + continue + mean = float(np.nanmean(values[finite])) + std = float(np.nanstd(values[finite])) + if not np.isfinite(std) or std <= 1e-12: + dropped.append(self._predictor_label(key)) + continue + values[~finite] = mean + values = (values - mean) / std + term = self._safe_design_name(self._predictor_label(key), used_names) + design[term] = values + terms.append(term) + term_labels[term] = self._predictor_label(key) + return design, terms, dropped, term_labels + # ------------------------------------------------------------------ # Slots # ------------------------------------------------------------------ @@ -1916,7 +2046,81 @@ def _plot_glm_fit(self, result: GLMResult): # FLMM fit # ------------------------------------------------------------------ + def _fit_flmm_group(self) -> None: + if not self._flmm.available: + self.statusMessage.emit( + "R + fastFMM not available. 
Please install R, rpy2, and the fastFMM R package.", 8000 + ) + return + + mat, tvec, row_labels, scope = self._flmm_matrix_and_labels() + if mat is None or tvec is None: + self.statusMessage.emit("No PSTH matrix - compute PSTH first.", 5000) + return + if mat.ndim != 2 or mat.shape[0] < 2: + self.statusMessage.emit("Need at least 2 animals/trials for FLMM.", 5000) + return + + n_rows = int(mat.shape[0]) + group_var = self.edit_group_var.text().strip() or "subject" + if not re.match(r"^[A-Za-z_][0-9A-Za-z_]*$", group_var): + group_var = "subject" + self.edit_group_var.setText(group_var) + if len(row_labels) != n_rows: + row_labels = [f"{scope}_{i + 1}" for i in range(n_rows)] + + if scope == "animal": + design, auto_terms, dropped_terms, term_labels = self._build_flmm_design(row_labels, group_var) + else: + design = {group_var: np.asarray(row_labels, dtype=object)} + auto_terms = [] + dropped_terms = [] + term_labels = {} + + requested_formula = self.edit_formula.text().strip() + if requested_formula in {"", "Y.obs ~ 1", "Y.obs ~ group"}: + formula = "Y.obs ~ " + " + ".join(auto_terms) if auto_terms else "Y.obs ~ 1" + if requested_formula != formula: + self.edit_formula.setText(formula) + else: + formula = requested_formula + random_eff = self.edit_random.text().strip() or "~1" + nknots = self.spin_nknots.value() if self.spin_nknots.value() > 0 else None + num_boots = self.spin_boots.value() + + self.statusMessage.emit("Fitting FLMM via fastFMM - this may take a while...", 0) + QtWidgets.QApplication.processEvents() + + result = self._flmm.fit( + mat, tvec, design, + formula_fixed=formula, + random_effects=random_eff, + group_var=group_var, + nknots_min=nknots, + num_boots=num_boots, + ) + self._flmm_result = result + + summary_lines = [ + result.summary_text, + "", + f"Scope: {'animal/group rows' if scope == 'animal' else 'trial rows'}", + f"Rows: {n_rows}", + f"ID variable: {group_var}", + f"Formula: {formula}", + ] + if auto_terms: + readable = [f"{term} = {term_labels.get(term, term)}" for term in auto_terms] + summary_lines.append("Animal covariates: " + "; ".join(readable)) + if dropped_terms: + summary_lines.append("Dropped covariates: " + ", ".join(dropped_terms)) + self.txt_summary.setPlainText("\n".join(summary_lines)) + self._plot_flmm_coefficients(result) + self.statusMessage.emit("FLMM fit complete.", 5000) + def _fit_flmm(self): + self._fit_flmm_group() + return if not self._flmm.available: self.statusMessage.emit( "R + fastFMM not available. 
Please install R, rpy2, and the fastFMM R package.", 8000 From 9b080332dd226183c51990cade607b9734e763b4 Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 11:57:32 +0200 Subject: [PATCH 05/14] Add configurable artifact handling modes --- pyBer/analysis_core.py | 109 ++++++++++++++++++++++++++++++++++--- pyBer/gui_preprocessing.py | 22 +++++++- pyBer/main.py | 1 + 3 files changed, 121 insertions(+), 11 deletions(-) diff --git a/pyBer/analysis_core.py b/pyBer/analysis_core.py index 7205ada..edea563 100644 --- a/pyBer/analysis_core.py +++ b/pyBer/analysis_core.py @@ -50,6 +50,13 @@ "Moving median", ] +ARTIFACT_HANDLING_MODES = [ + "Interpolate", + "Cut", + "Strong local low-pass", + "Do nothing", +] + # Output modes OUTPUT_MODES = [ # 1) dFF (non motion corrected) @@ -80,6 +87,7 @@ class ProcessingParams: # ------------------------- artifact_detection_enabled: bool = True artifact_mode: str = "Global MAD (dx)" # or "Adaptive MAD (windowed)" + artifact_handling: str = "Interpolate" mad_k: float = 8.0 adaptive_window_s: float = 5.0 artifact_pad_s: float = 0.25 @@ -670,6 +678,82 @@ def _lowpass_sos(x: np.ndarray, fs: float, cutoff: float, order: int) -> np.ndar return np.asarray(sosfiltfilt(sos, y), float) +def _normalize_artifact_handling(value: object) -> str: + text = str(value or "").strip().lower() + if text.startswith("cut"): + return "Cut" + if "low" in text and "pass" in text: + return "Strong local low-pass" + if text.startswith("do") or text in {"none", "nothing", "off", "ignore"}: + return "Do nothing" + return "Interpolate" + + +def _strong_local_lowpass_artifacts( + x: np.ndarray, + mask: np.ndarray, + fs: float, + base_cutoff_hz: float, + filter_order: int, +) -> np.ndarray: + y = np.asarray(x, float).copy() + m = np.asarray(mask, bool) + if y.size == 0 or m.size != y.size or not np.any(m): + return y + bridged = y.copy() + bridged[m] = np.nan + bridged = interpolate_nans(bridged) + if np.any(~np.isfinite(bridged)): + return y + + # Use a clearly stronger local cutoff than the main anti-aliasing filter. 
+ try: + base_cutoff = float(base_cutoff_hz) + except Exception: + base_cutoff = 12.0 + if not np.isfinite(base_cutoff) or base_cutoff <= 0: + base_cutoff = 12.0 + cutoff = min(2.0, max(0.05, 0.25 * base_cutoff)) + order = int(max(3, min(6, int(filter_order) + 1))) + try: + replacement = _lowpass_sos(bridged, fs, cutoff, order) + except Exception: + replacement = bridged + y[m] = replacement[m] + return y + + +def _apply_artifact_handling( + sig: np.ndarray, + ref: np.ndarray, + mask: np.ndarray, + fs: float, + params: ProcessingParams, +) -> Tuple[np.ndarray, np.ndarray, str]: + handling = _normalize_artifact_handling(getattr(params, "artifact_handling", "Interpolate")) + sig_corr = np.asarray(sig, float).copy() + ref_corr = np.asarray(ref, float).copy() + m = np.asarray(mask, bool) + if sig_corr.size == 0 or ref_corr.size == 0 or m.size != sig_corr.size or not np.any(m): + return sig_corr, ref_corr, handling + + if handling == "Do nothing" or handling == "Cut": + return sig_corr, ref_corr, handling + + if handling == "Strong local low-pass": + cutoff = float(getattr(params, "lowpass_hz", 12.0)) + order = int(getattr(params, "filter_order", 3)) + return ( + _strong_local_lowpass_artifacts(sig_corr, m, fs, cutoff, order), + _strong_local_lowpass_artifacts(ref_corr, m, fs, cutoff, order), + handling, + ) + + sig_corr[m] = np.nan + ref_corr[m] = np.nan + return interpolate_nans(sig_corr), interpolate_nans(ref_corr), handling + + def _window_samples_from_seconds( fs: float, window_s: float, @@ -1480,14 +1564,11 @@ def process_trial( mask = apply_manual_regions(t, mask, manual_regions_sec or []) # --------------------------------------------------------------------- - # 4) Mask artifacts (set NaN) then interpolate (keeps timebase intact) + # 4) Apply selected artifact handling. + # Interpolate is the historical default and keeps the timebase intact. + # Cut is applied after resampling so all processed arrays stay aligned. 
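For context on the handling modes dispatched above: "Interpolate" sets masked samples to NaN and bridges them linearly, which preserves the timebase that the later filtering and resampling stages assume. A standalone sketch of that bridging step (the pipeline's own `interpolate_nans` helper plays this role; this version is only illustrative):

```python
import numpy as np

def bridge_masked_samples(y, mask):
    """Replace masked samples by linear interpolation between their
    unmasked neighbors, keeping array length and timebase unchanged."""
    y = np.asarray(y, float).copy()
    mask = np.asarray(mask, bool)
    if not mask.any() or mask.all():
        return y  # nothing to bridge, or nothing to bridge from
    idx = np.arange(y.size)
    y[mask] = np.interp(idx[mask], idx[~mask], y[~mask])
    return y
```

"Cut", by contrast, leaves the samples in place at this stage and removes them only after resampling, so every derived array stays the same length until the final cut.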
# --------------------------------------------------------------------- - sig_corr = sig.copy() - ref_corr = ref.copy() - sig_corr[mask] = np.nan - ref_corr[mask] = np.nan - sig_corr = interpolate_nans(sig_corr) - ref_corr = interpolate_nans(ref_corr) + sig_corr, ref_corr, artifact_handling = _apply_artifact_handling(sig, ref, mask, fs, params) # --------------------------------------------------------------------- # 5) Low-pass filter before decimation (anti-aliasing) @@ -1510,6 +1591,16 @@ def process_trial( # Resample the envelope for display (same timebase as processed) _, hi2, lo2, _ = _resample_pair_to_target_fs(t, hi_raw, lo_raw, fs, target_fs) + if artifact_handling == "Cut" and np.any(mask): + mask2 = np.interp(t2, t, mask.astype(float), left=0.0, right=0.0) > 0.5 + keep = ~mask2 + if np.sum(keep) >= 3: + t2 = t2[keep] + sig2 = sig2[keep] + ref2 = ref2[keep] + hi2 = hi2[keep] + lo2 = lo2[keep] + # --------------------------------------------------------------------- # 7) A/D overlay (if present): interpolate and binarize # --------------------------------------------------------------------- @@ -1612,8 +1703,10 @@ def process_trial( out = dff_sig context_parts = [] + if np.any(mask): + context_parts.append(f"Artifacts: {artifact_handling}") if mode == "Raw signal (465)": - context_parts.append("Raw 465 after artifact interpolation, filtering, and resampling") + context_parts.append("Raw 465 after artifact handling, filtering, and resampling") else: baseline_desc = f"Baseline: {params.baseline_method} (lambda={float(params.baseline_lambda):.2e})" if mode in ( diff --git a/pyBer/gui_preprocessing.py b/pyBer/gui_preprocessing.py index aa891cd..d845f42 100644 --- a/pyBer/gui_preprocessing.py +++ b/pyBer/gui_preprocessing.py @@ -17,6 +17,7 @@ BASELINE_METHODS, REFERENCE_FIT_METHODS, SMOOTHING_METHODS, + ARTIFACT_HANDLING_MODES, ) @@ -1345,6 +1346,13 @@ def _build_help_texts(self) -> Dict[str, str]: "- Global MAD: single threshold for the full trace (fast, stable).\n" "- Adaptive MAD: threshold computed in sliding windows (handles drift)." ), + "artifact_handling": ( + "How detected/manual artifact windows affect processing:\n" + "- Interpolate: replace windows by linear interpolation before filtering.\n" + "- Cut: remove artifact samples from the final processed trace.\n" + "- Strong local low-pass: replace only artifact-window samples with a strongly smoothed trace.\n" + "- Do nothing: keep the samples unchanged; overlays still show detected artifacts." + ), "mad_k": ( "MAD threshold (k) scales the derivative threshold.\n" "Higher k = fewer artifacts flagged; lower k = more sensitive." 
@@ -1521,6 +1529,9 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: self.combo_artifact = QtWidgets.QComboBox() self.combo_artifact.addItems(["Global MAD (dx)", "Adaptive MAD (windowed)"]) _compact_combo(self.combo_artifact, min_chars=6) + self.combo_artifact_handling = QtWidgets.QComboBox() + self.combo_artifact_handling.addItems(ARTIFACT_HANDLING_MODES) + _compact_combo(self.combo_artifact_handling, min_chars=8) self.spin_mad = mk_dspin() self.spin_mad.setRange(1.0, 50.0) self.spin_mad.setValue(8.0) @@ -1537,6 +1548,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox: art_form.addRow(self.cb_artifact) art_form.addRow(self.cb_show_artifact_overlay) art_form.addRow(self._label_with_help("Method", "artifact_mode"), self.combo_artifact) + art_form.addRow(self._label_with_help("Handling", "artifact_handling"), self.combo_artifact_handling) art_form.addRow(self._label_with_help("MAD threshold (k)", "mad_k"), self.spin_mad) art_form.addRow(self._label_with_help("Adaptive window (s)", "adaptive_window_s"), self.spin_adapt_win) art_form.addRow(self._label_with_help("Artifact pad (s)", "artifact_pad_s"), self.spin_pad) @@ -1955,6 +1967,7 @@ def _fmt_num(self, value: float, decimals: int = 3) -> str: return text if text else "0" def _update_section_summaries(self) -> None: + handling = self.combo_artifact_handling.currentText() if self.cb_artifact.isChecked(): mode = self.combo_artifact.currentText() method = "Adaptive MAD" if mode.startswith("Adaptive") else "Global MAD" @@ -1962,15 +1975,15 @@ def _update_section_summaries(self) -> None: summary = ( f"{method}, k={self._fmt_num(self.spin_mad.value(), 2)}, " f"window={self._fmt_num(self.spin_adapt_win.value(), 2)}s, " - f"pad={self._fmt_num(self.spin_pad.value(), 2)}s" + f"pad={self._fmt_num(self.spin_pad.value(), 2)}s, {handling}" ) else: summary = ( f"{method}, k={self._fmt_num(self.spin_mad.value(), 2)}, " - f"pad={self._fmt_num(self.spin_pad.value(), 2)}s" + f"pad={self._fmt_num(self.spin_pad.value(), 2)}s, {handling}" ) else: - summary = "Off" + summary = f"Detection off, {handling}" self.card_artifacts.set_summary(summary) if self.cb_filtering.isChecked(): @@ -2012,6 +2025,7 @@ def emit_noargs(*_args) -> None: widgets = ( self.combo_artifact, + self.combo_artifact_handling, self.spin_mad, self.spin_adapt_win, self.spin_pad, @@ -2324,6 +2338,7 @@ def get_params(self) -> ProcessingParams: return ProcessingParams( artifact_detection_enabled=self.cb_artifact.isChecked(), artifact_mode=self.combo_artifact.currentText(), + artifact_handling=self.combo_artifact_handling.currentText(), mad_k=float(self.spin_mad.value()), adaptive_window_s=float(self.spin_adapt_win.value()), artifact_pad_s=float(self.spin_pad.value()), @@ -2359,6 +2374,7 @@ def set_params(self, params: ProcessingParams) -> None: return self.cb_artifact.setChecked(bool(getattr(params, "artifact_detection_enabled", True))) self.combo_artifact.setCurrentText(str(params.artifact_mode)) + self.combo_artifact_handling.setCurrentText(str(getattr(params, "artifact_handling", "Interpolate"))) self.spin_mad.setValue(float(params.mad_k)) self.spin_adapt_win.setValue(float(params.adaptive_window_s)) self.spin_pad.setValue(float(params.artifact_pad_s)) diff --git a/pyBer/main.py b/pyBer/main.py index c009218..cdfaa12 100644 --- a/pyBer/main.py +++ b/pyBer/main.py @@ -5132,6 +5132,7 @@ def _artifact_param_signature(self, params: ProcessingParams) -> Tuple[object, . 
return ( bool(getattr(params, "artifact_detection_enabled", True)), str(params.artifact_mode), + str(getattr(params, "artifact_handling", "Interpolate")), float(params.mad_k), float(params.adaptive_window_s), float(params.artifact_pad_s), From 018a51f063a4083fe42440bafc2b770fc509072e Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 13:09:30 +0200 Subject: [PATCH 06/14] Add temporal model stats and FLMM support --- environment.yml | 6 + pyBer.spec | 7 +- pyBer/temporal_modeling.py | 674 +++++++++++++++++++++++++++++++++---- 3 files changed, 622 insertions(+), 65 deletions(-) diff --git a/environment.yml b/environment.yml index bf2d4c2..ddefbec 100644 --- a/environment.yml +++ b/environment.yml @@ -26,6 +26,12 @@ dependencies: - pyqtgraph>=0.13 - matplotlib>=3.8 + # FLMM backend via R fastFMM/rpy2 + - rpy2>=3.6 + # Windows R/rpy2 startup uses sh/make when validating R's config. + - m2-base + - m2-make + # Quality-of-life / packaging - pip - setuptools diff --git a/pyBer.spec b/pyBer.spec index 4c50bed..adbfc2e 100644 --- a/pyBer.spec +++ b/pyBer.spec @@ -39,7 +39,12 @@ a = Analysis( 'freetype.dll', ]), datas=[('assets/pyBer_logo_big.png', 'assets'), ('assets/pyBer.ico', 'assets')], - hiddenimports=[], + hiddenimports=[ + 'rpy2.rinterface', + 'rpy2.rinterface_lib', + 'rpy2.robjects', + 'rpy2.robjects.packages', + ], hookspath=[], hooksconfig={}, runtime_hooks=[], diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index 267c2f7..de971fe 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -16,6 +16,7 @@ import logging import os import re +import sys import traceback from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -44,34 +45,70 @@ def _init_r(): global _R_READY if _R_READY: return - # On Windows, R is often not on PATH. Try the standard install location. - r_home = os.environ.get("R_HOME", "") - if not r_home: - candidate = "C:/Program Files/R" - if os.path.isdir(candidate): - subs = sorted(os.listdir(candidate), reverse=True) - if subs: - r_home = os.path.join(candidate, subs[0]) + + def _candidate_r_homes() -> List[str]: + roots: List[str] = [] + for prefix in (os.environ.get("CONDA_PREFIX", ""), sys.prefix, os.path.dirname(sys.executable)): + if not prefix: + continue + roots.extend([ + os.path.join(prefix, "Lib", "R"), + os.path.join(prefix, "lib", "R"), + os.path.join(prefix, "R"), + ]) + program_files = "C:/Program Files/R" + if os.path.isdir(program_files): + subs = sorted(os.listdir(program_files), reverse=True) + roots.extend(os.path.join(program_files, sub) for sub in subs) + return roots + + # On Windows, R is often not on PATH. Try active envs first, then Program Files. + r_home = os.environ.get("R_HOME", "").strip() + if not r_home or not os.path.isdir(r_home): + for candidate in _candidate_r_homes(): + if os.path.isdir(candidate): + r_home = candidate os.environ["R_HOME"] = r_home + break if r_home: - bin_x64 = os.path.join(r_home, "bin", "x64") - if os.path.isdir(bin_x64): + r_bins = [os.path.join(r_home, "bin", "x64"), os.path.join(r_home, "bin")] + for bin_dir in r_bins: + if not os.path.isdir(bin_dir): + continue try: - os.add_dll_directory(bin_x64) + os.add_dll_directory(bin_dir) except (OSError, AttributeError): pass - # Ensure PATH includes R so child processes find R.dll + # Ensure PATH includes R so child processes find R.dll. 
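Since `_init_r` only probes common install locations, it can save a debugging round-trip to confirm the R bridge works outside the GUI first. A minimal manual check, assuming rpy2 is installed; the `R_HOME` path below is a placeholder for your machine, not a pyBer default:

```python
import os

# Placeholder location -- point this at your actual R installation if
# R_HOME is not already set in the environment.
os.environ.setdefault("R_HOME", r"C:\Program Files\R\R-4.4.1")

import rpy2.robjects as ro
from rpy2.robjects.packages import isinstalled

print("R version:", ro.r("R.version.string")[0])
print("fastFMM installed:", isinstalled("fastFMM"))
```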
cur_path = os.environ.get("PATH", "") - if bin_x64 not in cur_path: - os.environ["PATH"] = bin_x64 + os.pathsep + cur_path + for bin_dir in reversed(r_bins): + if os.path.isdir(bin_dir) and bin_dir not in cur_path: + cur_path = bin_dir + os.pathsep + cur_path + os.environ["PATH"] = cur_path + + # rpy2 may call R's config.sh on Windows; make sh/make visible if the env has them. + tool_bins: List[str] = [] + for prefix in (os.environ.get("CONDA_PREFIX", ""), sys.prefix): + if prefix: + tool_bins.append(os.path.join(prefix, "Library", "usr", "bin")) + cur_path = os.environ.get("PATH", "") + for tool_bin in reversed(tool_bins): + if os.path.isdir(tool_bin) and tool_bin not in cur_path: + cur_path = tool_bin + os.pathsep + cur_path + os.environ["PATH"] = cur_path r_libs_user = os.environ.get("R_LIBS_USER", "") if not r_libs_user: - candidate_lib = os.path.expanduser("~/R/win-library") - if os.path.isdir(candidate_lib): - subs = sorted(os.listdir(candidate_lib), reverse=True) - if subs: - os.environ["R_LIBS_USER"] = os.path.join(candidate_lib, subs[0]) + candidate_roots = [ + os.path.join(os.environ.get("LOCALAPPDATA", ""), "R", "win-library"), + os.path.expanduser("~/R/win-library"), + ] + for candidate_lib in candidate_roots: + if os.path.isdir(candidate_lib): + subs = sorted(os.listdir(candidate_lib), reverse=True) + if subs: + os.environ["R_LIBS_USER"] = os.path.join(candidate_lib, subs[0]) + break import rpy2.robjects as ro # noqa: F811 from rpy2.robjects import r as R # noqa: F811 @@ -151,6 +188,8 @@ class GLMResult: r2: float coefficients: np.ndarray # raw beta vector design_matrix: np.ndarray + stats: Dict[str, float] = field(default_factory=dict) + feature_importance: List[Dict[str, Any]] = field(default_factory=list) class ContinuousGLM: @@ -354,6 +393,25 @@ def fit( ss_res = np.nansum(residuals ** 2) ss_tot = np.nansum((signal - np.nanmean(signal)) ** 2) r2 = 1.0 - ss_res / max(ss_tot, 1e-12) + valid_fit = np.isfinite(signal) & np.isfinite(y_pred) + res_fit = residuals[valid_fit] + y_fit = signal[valid_fit] + pred_fit = y_pred[valid_fit] + mse = float(np.nanmean(res_fit ** 2)) if res_fit.size else float("nan") + rmse = float(np.sqrt(mse)) if np.isfinite(mse) else float("nan") + mae = float(np.nanmean(np.abs(res_fit))) if res_fit.size else float("nan") + resid_std = float(np.nanstd(res_fit)) if res_fit.size else float("nan") + corr = float("nan") + if y_fit.size > 2 and np.nanstd(y_fit) > 1e-12 and np.nanstd(pred_fit) > 1e-12: + corr = float(np.corrcoef(y_fit, pred_fit)[0, 1]) + stats = { + "n_samples": float(np.sum(valid_fit)), + "mse": mse, + "rmse": rmse, + "mae": mae, + "residual_std": resid_std, + "corr": corr, + } # Extract kernels dt = np.median(np.diff(time)) @@ -387,6 +445,7 @@ def fit( r2=r2, coefficients=beta, design_matrix=X, + stats=stats, ) return self._result @@ -414,6 +473,8 @@ class FLMMResult: residuals: Optional[np.ndarray] = None aic: Optional[float] = None summary_text: str = "" + stats: Dict[str, float] = field(default_factory=dict) + feature_importance: List[Dict[str, Any]] = field(default_factory=list) class TrialFLMM: @@ -422,8 +483,8 @@ class TrialFLMM: Wraps fastFMM::fui() via rpy2. The user provides a trial-level data matrix (n_trials x n_timepoints) plus a design dataframe (n_trials rows) - with fixed/random predictors. The backend constructs the long-form - data and calls fui(). + with fixed/random predictors. The backend passes the trace matrix as the + functional response expected by fastFMM. 
""" def __init__(self): @@ -470,51 +531,72 @@ def fit( """ _init_r() import rpy2.robjects as ro - from rpy2.robjects import r as R, pandas2ri, numpy2ri + from rpy2.robjects import r as R from rpy2.robjects.packages import importr n_trials, n_time = mat.shape - # Build long-form dataframe in R - # Columns: Y.obs, .index (timepoint), .obs (trial id), + design vars - Y_long = mat.ravel(order="C") # trial-major - index_long = np.tile(np.arange(n_time), n_trials) - obs_long = np.repeat(np.arange(n_trials), n_time) - - r_df_vars = { - "Y.obs": ro.FloatVector(Y_long), - ".index": ro.IntVector(index_long.astype(int)), - ".obs": ro.IntVector(obs_long.astype(int)), - } + if group_var not in design: + raise ValueError(f"FLMM design is missing group variable '{group_var}'.") + r_df_vars = {} for col_name, col_vals in design.items(): col_vals = np.asarray(col_vals) - repeated = np.repeat(col_vals, n_time) + if col_vals.size != n_trials: + raise ValueError(f"FLMM design column '{col_name}' has {col_vals.size} values for {n_trials} rows.") if np.issubdtype(col_vals.dtype, np.floating): - r_df_vars[col_name] = ro.FloatVector(repeated) + r_df_vars[col_name] = ro.FloatVector(col_vals.astype(float)) elif np.issubdtype(col_vals.dtype, np.integer): - r_df_vars[col_name] = ro.IntVector(repeated.astype(int)) + r_df_vars[col_name] = ro.IntVector(col_vals.astype(int)) else: - r_df_vars[col_name] = ro.StrVector(repeated.astype(str)) + r_df_vars[col_name] = ro.FactorVector(ro.StrVector(col_vals.astype(str))) r_df = ro.DataFrame(r_df_vars) + y_matrix = R.matrix(ro.FloatVector(np.asarray(mat, float).ravel(order="F")), nrow=n_trials, ncol=n_time) + ro.globalenv[".__pyber_flmm_df"] = r_df + ro.globalenv[".__pyber_flmm_y"] = y_matrix + r_df = R(".__pyber_flmm_df$Y.obs <- I(.__pyber_flmm_y); .__pyber_flmm_df") + + formula_text = str(formula_fixed or "Y.obs ~ 1").strip() or "Y.obs ~ 1" + group_vals = np.asarray(design[group_var]).astype(str) + has_repeated_groups = np.unique(group_vals).size < group_vals.size + if "|" not in formula_text: + rand = str(random_effects or "~1").strip() + rand_rhs = rand.split("~", 1)[1].strip() if "~" in rand else rand + if rand_rhs.lower() in {"", "0", "none", "fixed"}: + rand_rhs = "" + if not has_repeated_groups: + raise ValueError( + "FLMM requires repeated rows per subject for random effects. " + "Compute PSTH from per-file trials instead of animal-averaged rows." 
+ ) + if rand_rhs: + formula_text = f"{formula_text} + ({rand_rhs} | {group_var})" # Call fui() fastFMM = importr("fastFMM") kwargs = { - "formula": ro.Formula(formula_fixed), + "formula": ro.Formula(formula_text), "data": r_df, - "id": ro.StrVector([group_var]), - "G": ro.Formula(random_effects), "parallel": ro.BoolVector([parallel]), + "silent": ro.BoolVector([True]), + "subj_id": ro.StrVector([group_var]), + "override_zero_var": ro.BoolVector([True]), } if nknots_min is not None: kwargs["nknots_min"] = ro.IntVector([nknots_min]) + if n_time < 35: + kwargs["nknots_min_cov"] = ro.IntVector([max(4, min(35, max(4, n_time // 2)))]) if num_boots > 0: - kwargs["num_boots"] = ro.IntVector([num_boots]) + kwargs["analytic"] = ro.BoolVector([False]) + kwargs["n_boots"] = ro.IntVector([num_boots]) + kwargs["argvals"] = ro.FloatVector(np.asarray(tvec, float)) + else: + kwargs["analytic"] = ro.BoolVector([True]) + kwargs["n_boots"] = ro.IntVector([0]) _LOG.info("Calling fastFMM::fui() with formula=%s, %d trials, %d timepoints", - formula_fixed, n_trials, n_time) + formula_text, n_trials, n_time) fui_result = fastFMM.fui(**kwargs) @@ -531,9 +613,11 @@ def fit( joint_ci_upper: Dict[str, np.ndarray] = {} try: - beta_hat = np.array(R('as.matrix')(fui_result.rx2("betaHat"))) - beta_lb = np.array(R('as.matrix')(fui_result.rx2("betaHat.LB"))) - beta_ub = np.array(R('as.matrix')(fui_result.rx2("betaHat.UB"))) + try: + result_names = set(str(name) for name in R('names')(fui_result)) + except Exception: + result_names = set() + beta_hat = np.atleast_2d(np.array(R('as.matrix')(fui_result.rx2("betaHat")), dtype=float)) # Term names from rownames try: @@ -541,6 +625,34 @@ def fit( except Exception: term_names = [f"term_{i}" for i in range(beta_hat.shape[0])] + se_mat = None + try: + if result_names and "se_mat" not in result_names: + raise KeyError("se_mat") + se_mat = np.atleast_2d(np.array(R('as.matrix')(fui_result.rx2("se_mat")), dtype=float)) + if se_mat.shape != beta_hat.shape: + se_mat = None + except Exception: + se_mat = None + + if se_mat is not None: + beta_lb = beta_hat - 1.96 * se_mat + beta_ub = beta_hat + 1.96 * se_mat + else: + try: + if result_names and ("betaHat.LB" not in result_names or "betaHat.UB" not in result_names): + raise KeyError("betaHat CI") + beta_lb = np.atleast_2d(np.array(R('as.matrix')(fui_result.rx2("betaHat.LB")), dtype=float)) + beta_ub = np.atleast_2d(np.array(R('as.matrix')(fui_result.rx2("betaHat.UB")), dtype=float)) + except Exception: + beta_lb = beta_hat.copy() + beta_ub = beta_hat.copy() + + try: + qn = np.asarray(fui_result.rx2("qn"), float).ravel() + except Exception: + qn = np.array([], float) + for i, name in enumerate(term_names): coefficients[name] = beta_hat[i, :] ci_lower[name] = beta_lb[i, :] @@ -548,27 +660,58 @@ def fit( # Joint CIs (may not always be present) try: + if result_names and ("betaHat.LB.joint" not in result_names or "betaHat.UB.joint" not in result_names): + raise KeyError("joint CI") jlb = np.array(R('as.matrix')(fui_result.rx2("betaHat.LB.joint"))) jub = np.array(R('as.matrix')(fui_result.rx2("betaHat.UB.joint"))) for i, name in enumerate(term_names): joint_ci_lower[name] = jlb[i, :] joint_ci_upper[name] = jub[i, :] except Exception: - joint_ci_lower = {k: v.copy() for k, v in ci_lower.items()} - joint_ci_upper = {k: v.copy() for k, v in ci_upper.items()} + if se_mat is not None and qn.size: + for i, name in enumerate(term_names): + qcrit = float(qn[min(i, qn.size - 1)]) + joint_ci_lower[name] = beta_hat[i, :] - qcrit * se_mat[i, :] + 
joint_ci_upper[name] = beta_hat[i, :] + qcrit * se_mat[i, :] + else: + joint_ci_lower = {k: v.copy() for k, v in ci_lower.items()} + joint_ci_upper = {k: v.copy() for k, v in ci_upper.items()} aic_val = None - try: - aic_val = float(np.array(fui_result.rx2("AIC"))[0]) - except Exception: - pass + for key in ("aic", "AIC"): + try: + aic_arr = np.array(R('as.matrix')(fui_result.rx2(key)), dtype=float) + if aic_arr.size: + if aic_arr.ndim == 2 and aic_arr.shape[1] >= 1: + aic_val = float(np.nanmean(aic_arr[:, 0])) + else: + aic_val = float(np.nanmean(aic_arr)) + break + except Exception: + continue summary_parts = [f"FLMM fit: {len(term_names)} terms, {n_trials} trials, {n_time} timepoints"] if aic_val is not None: summary_parts.append(f"AIC = {aic_val:.1f}") + coeff_abs_peaks: List[float] = [] + coeff_abs_means: List[float] = [] for name in term_names: - summary_parts.append(f" {name}: mean coef = {np.nanmean(coefficients[name]):.4f}") + coeff = np.asarray(coefficients[name], float) + coeff_abs_peaks.append(float(np.nanmax(np.abs(coeff))) if coeff.size else float("nan")) + coeff_abs_means.append(float(np.nanmean(np.abs(coeff))) if coeff.size else float("nan")) + summary_parts.append( + f" {name}: mean coef = {np.nanmean(coeff):.4f}, " + f"mean abs = {np.nanmean(np.abs(coeff)):.4f}, peak abs = {np.nanmax(np.abs(coeff)):.4f}" + ) summary_text = "\n".join(summary_parts) + stats = { + "n_trials": float(n_trials), + "n_timepoints": float(n_time), + "n_terms": float(len(term_names)), + "aic": float(aic_val) if aic_val is not None else float("nan"), + "mean_abs_coefficient": float(np.nanmean(coeff_abs_means)) if coeff_abs_means else float("nan"), + "peak_abs_coefficient": float(np.nanmax(coeff_abs_peaks)) if coeff_abs_peaks else float("nan"), + } except Exception as exc: _LOG.error("Failed to parse fui() result: %s", exc) @@ -583,6 +726,7 @@ def fit( joint_ci_upper=joint_ci_upper, aic=aic_val, summary_text=summary_text, + stats=stats, ) return self._result @@ -740,6 +884,7 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self._group_mat: Optional[np.ndarray] = None self._group_tvec: Optional[np.ndarray] = None self._group_labels: List[str] = [] + self._flmm_row_meta: List[Dict[str, Any]] = [] self._visual_mode: int = 0 self._group_mode: bool = False @@ -1188,6 +1333,14 @@ def _build_workspace_pages(self): residual_lay.addWidget(self.plot_residuals, 1) self.tabs_workspace.addTab(residual_page, "Residuals") + importance_page = QtWidgets.QWidget() + importance_lay = QtWidgets.QVBoxLayout(importance_page) + importance_lay.setContentsMargins(10, 10, 10, 10) + self.plot_importance = pg.PlotWidget(title="Feature contribution") + self._style_plot(self.plot_importance) + importance_lay.addWidget(self.plot_importance, 1) + self.tabs_workspace.addTab(importance_page, "Importance") + flmm_page = QtWidgets.QWidget() flmm_lay = QtWidgets.QVBoxLayout(flmm_page) flmm_lay.setContentsMargins(10, 10, 10, 10) @@ -1735,6 +1888,47 @@ def _safe_design_name(label: str, used: set[str]) -> str: return name def _flmm_matrix_and_labels(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[str], str]: + self._flmm_row_meta = [] + if self._per_file_mats: + mats: List[np.ndarray] = [] + labels: List[str] = [] + meta_rows: List[Dict[str, Any]] = [] + ref_tvec: Optional[np.ndarray] = None + event_by_file: Dict[str, List[Dict[str, object]]] = {} + for row in self._event_rows: + fid = str(row.get("file_id", "") or "") + if fid: + event_by_file.setdefault(fid, []).append(row) + ordered_ids = 
list(self._file_ids) if self._file_ids else list(self._per_file_mats.keys()) + for file_id in ordered_ids: + if file_id not in self._per_file_mats: + continue + tvec_f, mat_f = self._per_file_mats[file_id] + tvec_f = np.asarray(tvec_f, float) + mat_f = np.asarray(mat_f, float) + if mat_f.ndim != 2 or mat_f.shape[0] == 0 or tvec_f.size != mat_f.shape[1]: + continue + if ref_tvec is None: + ref_tvec = tvec_f + elif ref_tvec.size != tvec_f.size or not np.allclose(ref_tvec, tvec_f, equal_nan=True): + continue + mats.append(mat_f) + rows = event_by_file.get(file_id, []) + for row_idx in range(mat_f.shape[0]): + labels.append(file_id) + ev_row = rows[row_idx] if row_idx < len(rows) else {} + meta_rows.append({ + "file_id": file_id, + "trial_index": row_idx, + "event_time_sec": ev_row.get("event_time_sec", np.nan), + "duration_sec": ev_row.get("duration_sec", np.nan), + }) + if mats and ref_tvec is not None: + mat = np.vstack(mats) + if mat.shape[0] >= 2 and len(labels) == mat.shape[0]: + self._flmm_row_meta = meta_rows + return mat, ref_tvec, labels, "trial" + group_mat = np.asarray(self._group_mat, float) if self._group_mat is not None else None group_tvec = np.asarray(self._group_tvec, float) if self._group_tvec is not None else None if ( @@ -1744,6 +1938,7 @@ def _flmm_matrix_and_labels(self) -> Tuple[Optional[np.ndarray], Optional[np.nda and group_mat.shape[0] >= 2 and len(self._group_labels) == group_mat.shape[0] ): + self._flmm_row_meta = [{"file_id": label, "trial_index": i} for i, label in enumerate(self._group_labels)] return group_mat, group_tvec, list(self._group_labels), "animal" mat = np.asarray(self._psth_mat, float) if self._psth_mat is not None else None @@ -1752,10 +1947,11 @@ def _flmm_matrix_and_labels(self) -> Tuple[Optional[np.ndarray], Optional[np.nda return None, None, [], "none" if len(self._file_ids) == mat.shape[0]: labels = list(self._file_ids) - scope = "animal" + scope = "trial" else: labels = [f"trial_{i + 1}" for i in range(mat.shape[0])] scope = "trial" + self._flmm_row_meta = [{"file_id": labels[i] if i < len(labels) else "", "trial_index": i} for i in range(mat.shape[0])] return mat, tvec, labels, scope def _animal_covariate_value(self, key: str, file_id: str) -> float: @@ -1791,6 +1987,50 @@ def _animal_covariate_value(self, key: str, file_id: str) -> float: return float(np.nanmean(finite)) return np.nan + def _trial_covariate_value(self, key: str, meta: Dict[str, Any]) -> float: + file_id = str(meta.get("file_id", "") or "") + event_time = meta.get("event_time_sec", np.nan) + try: + event_time = float(event_time) + except (TypeError, ValueError): + event_time = np.nan + proc = self._proc_for_file_id(file_id) + if proc is None or not np.isfinite(event_time): + return self._animal_covariate_value(key, file_id) + t = np.asarray(getattr(proc, "time", np.array([], float)), float) + if t.size < 2: + return self._animal_covariate_value(key, file_id) + + if key in {"events", "dio"} or key.startswith("trigger::") or key.startswith("behavior_event::"): + vec, _ = self._predictor_vector_for_proc(key, proc, t) + edges = t[np.asarray(vec, float) > 0.5] + if edges.size == 0: + return 0.0 + dt = float(np.nanmedian(np.diff(np.sort(t)))) + tol = max(dt if np.isfinite(dt) and dt > 0 else 1e-3, 1e-3) + return float(np.any(np.abs(edges - event_time) <= tol)) + + entry = self._predictor_catalog.get(key, {}) + name = str(entry.get("name", "") or key.split("::")[-1]) + info = self._behavior_source_for_proc(proc) + if not isinstance(info, dict): + return 
self._animal_covariate_value(key, file_id) + if key.startswith("behavior_state::"): + behaviors = info.get("behaviors") or {} + values = behaviors.get(name) + source_time = np.asarray(info.get("time", np.array([], float)), float) + if values is None or source_time.size == 0: + return np.nan + return float(self._interp_to_time(np.array([event_time], float), source_time, values)[0]) + if key.startswith("behavior_cont::"): + trajectory = info.get("trajectory") or {} + values = trajectory.get(name) + source_time = np.asarray(info.get("trajectory_time", np.array([], float)), float) + if values is None or source_time.size == 0: + return np.nan + return float(self._interp_to_time(np.array([event_time], float), source_time, values)[0]) + return self._animal_covariate_value(key, file_id) + def _build_flmm_design( self, labels: List[str], @@ -1803,7 +2043,13 @@ def _build_flmm_design( used_names = {group_var, "Y.obs", ".index", ".obs"} for key in self._selected_predictor_keys(): - values = np.asarray([self._animal_covariate_value(key, label) for label in labels], float) + if len(self._flmm_row_meta) == len(labels): + values = np.asarray([ + self._trial_covariate_value(key, self._flmm_row_meta[i]) + for i, _label in enumerate(labels) + ], float) + else: + values = np.asarray([self._animal_covariate_value(key, label) for label in labels], float) finite = np.isfinite(values) if np.sum(finite) < 2: dropped.append(self._predictor_label(key)) @@ -1821,6 +2067,172 @@ def _build_flmm_design( term_labels[term] = self._predictor_label(key) return design, terms, dropped, term_labels + @staticmethod + def _intercept_only_fit_stats(signal: np.ndarray) -> Dict[str, float]: + signal = np.asarray(signal, float) + valid = np.isfinite(signal) + if not np.any(valid): + return {"r2": float("nan"), "mse": float("nan")} + y = signal[valid] + mean = float(np.nanmean(y)) + residuals = y - mean + ss_res = float(np.nansum(residuals ** 2)) + ss_tot = float(np.nansum((y - np.nanmean(y)) ** 2)) + mse = float(np.nanmean(residuals ** 2)) if y.size else float("nan") + return {"r2": 1.0 - ss_res / max(ss_tot, 1e-12), "mse": mse} + + def _compute_glm_leave_one_out( + self, + dataset: Dict[str, Any], + result: GLMResult, + kernel_window: Tuple[float, float], + basis_type: str, + regularization: str, + alpha: float, + ) -> List[Dict[str, Any]]: + predictors = dict(dataset.get("predictors", {}) or {}) + if not predictors or not result.predictor_names: + return [] + full_mse = float(result.stats.get("mse", float("nan"))) + rows: List[Dict[str, Any]] = [] + time = np.asarray(dataset["time"], float) + signal = np.asarray(dataset["signal"], float) + for pred_name in result.predictor_names: + row: Dict[str, Any] = { + "feature": pred_name, + "label": self._predictor_label(pred_name), + "full_r2": float(result.r2), + "full_mse": full_mse, + "reduced_r2": float("nan"), + "reduced_mse": float("nan"), + "delta_r2": float("nan"), + "delta_mse": float("nan"), + "contribution_pct": float("nan"), + "status": "ok", + } + try: + reduced_predictors = {k: v for k, v in predictors.items() if k != pred_name} + if reduced_predictors: + reduced = ContinuousGLM().fit( + time, + signal, + reduced_predictors, + kernel_window=kernel_window, + n_basis=self.spin_n_basis.value(), + basis_type=basis_type, + regularization=regularization, + alpha=alpha, + ) + reduced_r2 = float(reduced.r2) + reduced_mse = float(reduced.stats.get("mse", float("nan"))) + else: + stats = self._intercept_only_fit_stats(signal) + reduced_r2 = float(stats["r2"]) + reduced_mse = 
float(stats["mse"]) + delta_r2 = float(result.r2 - reduced_r2) + delta_mse = float(reduced_mse - full_mse) if np.isfinite(reduced_mse) and np.isfinite(full_mse) else float("nan") + denom = float(result.r2) if np.isfinite(result.r2) and abs(result.r2) > 1e-12 else float("nan") + row.update({ + "reduced_r2": reduced_r2, + "reduced_mse": reduced_mse, + "delta_r2": delta_r2, + "delta_mse": delta_mse, + "contribution_pct": 100.0 * delta_r2 / denom if np.isfinite(denom) else float("nan"), + }) + except Exception as exc: + row["status"] = f"failed: {exc}" + rows.append(row) + QtWidgets.QApplication.processEvents() + rows.sort(key=lambda item: ( + np.isfinite(item.get("delta_r2", np.nan)), + float(item.get("delta_r2", -np.inf)) if np.isfinite(item.get("delta_r2", np.nan)) else -np.inf, + ), reverse=True) + return rows + + @staticmethod + def _simple_formula_terms(formula: str) -> List[str]: + if "~" not in str(formula): + return [] + rhs = str(formula).split("~", 1)[1].replace("\n", " ") + terms: List[str] = [] + for raw in rhs.split("+"): + term = raw.strip().strip("`") + if not term or term in {"0", "1", "-1"}: + continue + terms.append(term) + return terms + + @staticmethod + def _term_mean_abs_coefficient(result: FLMMResult, term: str) -> float: + if not result.coefficients: + return float("nan") + clean = re.sub(r"[^0-9A-Za-z_]+", "", str(term).lower()) + for name, coeff in result.coefficients.items(): + if str(name) == str(term): + vals = np.asarray(coeff, float) + return float(np.nanmean(np.abs(vals))) if vals.size else float("nan") + for name, coeff in result.coefficients.items(): + name_clean = re.sub(r"[^0-9A-Za-z_]+", "", str(name).lower()) + if clean and (clean in name_clean or name_clean in clean): + vals = np.asarray(coeff, float) + return float(np.nanmean(np.abs(vals))) if vals.size else float("nan") + return float("nan") + + def _compute_flmm_leave_one_out( + self, + mat: np.ndarray, + tvec: np.ndarray, + design: Dict[str, np.ndarray], + formula: str, + random_eff: str, + group_var: str, + nknots: Optional[int], + full_result: FLMMResult, + term_labels: Dict[str, str], + ) -> List[Dict[str, Any]]: + terms = [term for term in self._simple_formula_terms(formula) if term in design and term != group_var] + if not terms: + return [] + full_aic = float(full_result.aic) if full_result.aic is not None else float("nan") + rows: List[Dict[str, Any]] = [] + for term in terms: + reduced_terms = [name for name in terms if name != term] + reduced_formula = "Y.obs ~ " + " + ".join(reduced_terms) if reduced_terms else "Y.obs ~ 1" + row: Dict[str, Any] = { + "feature": term, + "label": term_labels.get(term, term), + "full_aic": full_aic, + "reduced_aic": float("nan"), + "delta_aic": float("nan"), + "mean_abs_coefficient": self._term_mean_abs_coefficient(full_result, term), + "status": "ok", + } + try: + reduced = TrialFLMM().fit( + mat, + tvec, + design, + formula_fixed=reduced_formula, + random_effects=random_eff, + group_var=group_var, + nknots_min=nknots, + num_boots=0, + ) + reduced_aic = float(reduced.aic) if reduced.aic is not None else float("nan") + row["reduced_aic"] = reduced_aic + if np.isfinite(full_aic) and np.isfinite(reduced_aic): + row["delta_aic"] = reduced_aic - full_aic + except Exception as exc: + row["status"] = f"failed: {exc}" + rows.append(row) + QtWidgets.QApplication.processEvents() + rows.sort(key=lambda item: ( + np.isfinite(item.get("delta_aic", np.nan)), + float(item.get("delta_aic", -np.inf)) if np.isfinite(item.get("delta_aic", np.nan)) else -np.inf, + 
float(item.get("mean_abs_coefficient", -np.inf)) if np.isfinite(item.get("mean_abs_coefficient", np.nan)) else -np.inf, + ), reverse=True) + return rows + # ------------------------------------------------------------------ # Slots # ------------------------------------------------------------------ @@ -1897,18 +2309,32 @@ def _fit_glm_catalog(self) -> None: basis_map = {"Raised cosine": "raised_cosine", "B-spline": "bspline", "FIR": "fir"} reg_map = {"Ridge": "ridge", "Lasso": "lasso", "OLS": "ols"} kernel_win = (self.spin_kernel_pre.value(), self.spin_kernel_post.value()) + basis_type = basis_map.get(self.combo_basis.currentText(), "raised_cosine") + regularization = reg_map.get(self.combo_reg.currentText(), "ridge") result = self._glm.fit( np.asarray(dataset["time"], float), np.asarray(dataset["signal"], float), dataset["predictors"], kernel_window=kernel_win, n_basis=self.spin_n_basis.value(), - basis_type=basis_map.get(self.combo_basis.currentText(), "raised_cosine"), - regularization=reg_map.get(self.combo_reg.currentText(), "ridge"), + basis_type=basis_type, + regularization=regularization, alpha=self.spin_alpha.value(), ) self._glm_result = result + self.statusMessage.emit("Calculating GLM leave-one-predictor-out contribution...", 0) + QtWidgets.QApplication.processEvents() + importance_rows = self._compute_glm_leave_one_out( + dataset, + result, + kernel_win, + basis_type, + regularization, + self.spin_alpha.value(), + ) + result.feature_importance = importance_rows + used_labels = [self._predictor_label(k) for k in result.predictor_names] dropped_predictors = dataset.get("dropped_predictors", []) or [] used_records = dataset.get("used_records", []) or [] @@ -1924,6 +2350,26 @@ def _fit_glm_catalog(self) -> None: f"Basis: {self.combo_basis.currentText()}, n={self.spin_n_basis.value()}", f"Regularization: {self.combo_reg.currentText()}, alpha={self.spin_alpha.value():.3f}", ] + stats = result.stats or {} + lines.extend([ + "", + "Fit statistics:", + f" RMSE = {stats.get('rmse', float('nan')):.5g}", + f" MAE = {stats.get('mae', float('nan')):.5g}", + f" MSE = {stats.get('mse', float('nan')):.5g}", + f" residual SD = {stats.get('residual_std', float('nan')):.5g}", + f" actual/predicted corr = {stats.get('corr', float('nan')):.5g}", + ]) + if importance_rows: + lines.extend(["", "Leave-one-predictor-out contribution (full - reduced):"]) + for row in importance_rows[:10]: + lines.append( + f" {row['label']}: delta R^2 = {row['delta_r2']:.5g}, " + f"delta MSE = {row['delta_mse']:.5g}, reduced R^2 = {row['reduced_r2']:.5g}" + ) + failed = [row for row in importance_rows if row.get("status") != "ok"] + if failed: + lines.append(f" {len(failed)} reduced fits failed; see log for details.") if dropped_predictors: lines.append(f"Dropped predictors: {', '.join(str(v) for v in dropped_predictors)}") if dropped_records: @@ -1932,8 +2378,14 @@ def _fit_glm_catalog(self) -> None: self._plot_glm_kernels(result) self._plot_glm_fit(result) + self._plot_feature_importance( + importance_rows, + value_key="delta_r2", + title="GLM leave-one-predictor-out contribution", + y_label="Drop in R^2", + ) if hasattr(self, "tabs_workspace"): - self.tabs_workspace.setCurrentWidget(self.plot_kernel.parentWidget()) + self.tabs_workspace.setCurrentWidget(self.plot_importance.parentWidget() if importance_rows else self.plot_kernel.parentWidget()) self.statusMessage.emit(f"GLM fit complete - R^2 = {result.r2:.4f}", 5000) def _fit_glm(self): @@ -2042,6 +2494,51 @@ def _plot_glm_fit(self, result: GLMResult): 
rw.setLabel("bottom", "Time", units="s") rw.setLabel("left", "Residual") + def _plot_feature_importance( + self, + rows: List[Dict[str, Any]], + value_key: str, + title: str, + y_label: str, + ) -> None: + if not hasattr(self, "plot_importance"): + return + pw = self.plot_importance + pw.clear() + try: + pw.getPlotItem().legend.clear() + except Exception: + pass + pw.setTitle(title) + usable = [ + row for row in rows + if np.isfinite(float(row.get(value_key, float("nan")))) + ] + if not usable: + txt = pg.TextItem("No leave-one-out feature contribution is available.", color="#c5d2e3") + pw.addItem(txt) + txt.setPos(0, 0) + pw.setLabel("bottom", "Feature") + pw.setLabel("left", y_label) + return + + x = np.arange(len(usable), dtype=float) + vals = np.asarray([float(row.get(value_key, 0.0)) for row in usable], float) + brushes = [pg.mkBrush("#4b9df8" if val >= 0 else "#ee99a0") for val in vals] + bar = pg.BarGraphItem(x=x, height=vals, width=0.64, brushes=brushes) + pw.addItem(bar) + pw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + labels = [] + for idx, row in enumerate(usable): + label = str(row.get("label", row.get("feature", idx))) + if len(label) > 18: + label = label[:15] + "..." + labels.append((idx, label)) + pw.getAxis("bottom").setTicks([labels]) + pw.setLabel("bottom", "Feature") + pw.setLabel("left", y_label) + pw.enableAutoRange() + # ------------------------------------------------------------------ # FLMM fit # ------------------------------------------------------------------ @@ -2069,13 +2566,7 @@ def _fit_flmm_group(self) -> None: if len(row_labels) != n_rows: row_labels = [f"{scope}_{i + 1}" for i in range(n_rows)] - if scope == "animal": - design, auto_terms, dropped_terms, term_labels = self._build_flmm_design(row_labels, group_var) - else: - design = {group_var: np.asarray(row_labels, dtype=object)} - auto_terms = [] - dropped_terms = [] - term_labels = {} + design, auto_terms, dropped_terms, term_labels = self._build_flmm_design(row_labels, group_var) requested_formula = self.edit_formula.text().strip() if requested_formula in {"", "Y.obs ~ 1", "Y.obs ~ group"}: @@ -2101,6 +2592,26 @@ def _fit_flmm_group(self) -> None: ) self._flmm_result = result + self.statusMessage.emit("Calculating FLMM leave-one-feature-out AIC contribution...", 0) + QtWidgets.QApplication.processEvents() + importance_rows = self._compute_flmm_leave_one_out( + mat, + tvec, + design, + formula, + random_eff, + group_var, + nknots, + result, + term_labels, + ) + result.feature_importance = importance_rows + importance_value_key = ( + "delta_aic" + if any(np.isfinite(float(row.get("delta_aic", float("nan")))) for row in importance_rows) + else "mean_abs_coefficient" + ) + summary_lines = [ result.summary_text, "", @@ -2109,6 +2620,33 @@ def _fit_flmm_group(self) -> None: f"ID variable: {group_var}", f"Formula: {formula}", ] + if result.stats: + summary_lines.extend([ + "", + "Fit statistics:", + f" AIC = {result.stats.get('aic', float('nan')):.5g}", + f" mean abs coefficient = {result.stats.get('mean_abs_coefficient', float('nan')):.5g}", + f" peak abs coefficient = {result.stats.get('peak_abs_coefficient', float('nan')):.5g}", + ]) + if importance_rows: + summary_lines.extend([ + "", + "Leave-one-feature-out contribution (reduced AIC - full AIC):", + ]) + for row in importance_rows[:10]: + delta = float(row.get("delta_aic", float("nan"))) + delta_text = f"{delta:.5g}" if np.isfinite(delta) else "n/a" + reduced = float(row.get("reduced_aic", float("nan"))) + 
reduced_text = f"{reduced:.5g}" if np.isfinite(reduced) else "n/a" + summary_lines.append( + f" {row['label']}: delta AIC = {delta_text}, " + f"reduced AIC = {reduced_text}, mean abs coef = {row['mean_abs_coefficient']:.5g}" + ) + failed = [row for row in importance_rows if row.get("status") != "ok"] + if failed: + summary_lines.append(f" {len(failed)} reduced FLMM fits failed; see log for details.") + if num_boots > 0: + summary_lines.append(" Leave-one-out comparison uses analytic fits; bootstrap CIs are not repeated.") if auto_terms: readable = [f"{term} = {term_labels.get(term, term)}" for term in auto_terms] summary_lines.append("Animal covariates: " + "; ".join(readable)) @@ -2116,6 +2654,14 @@ def _fit_flmm_group(self) -> None: summary_lines.append("Dropped covariates: " + ", ".join(dropped_terms)) self.txt_summary.setPlainText("\n".join(summary_lines)) self._plot_flmm_coefficients(result) + self._plot_feature_importance( + importance_rows, + value_key=importance_value_key, + title="FLMM leave-one-feature-out contribution" if importance_value_key == "delta_aic" else "FLMM coefficient contribution", + y_label="Delta AIC" if importance_value_key == "delta_aic" else "Mean abs coefficient", + ) + if hasattr(self, "tabs_workspace") and importance_rows: + self.tabs_workspace.setCurrentWidget(self.plot_importance.parentWidget()) self.statusMessage.emit("FLMM fit complete.", 5000) def _fit_flmm(self): From 991192c7c806ad5557dde2a9e1a84ec86d5db3ac Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 13:15:12 +0200 Subject: [PATCH 07/14] Improve install and user documentation --- README.md | 294 +++++++++++++++--------------------------------- docs/index.md | 285 ++++++++++++++++++++++++++++++++++++++++++++++ environment.yml | 13 ++- 3 files changed, 389 insertions(+), 203 deletions(-) create mode 100644 docs/index.md diff --git a/README.md b/README.md index 92a0a7f..725f4f2 100644 --- a/README.md +++ b/README.md @@ -1,202 +1,92 @@ -# Fiber Photometry Processing GUI - -pyBer_logo_big - - -A desktop GUI for visualizing photometry recordings, cleaning artifacts, filtering/resampling, baseline estimation, motion-correction, and exporting processed traces for downstream analysis. - -This project is designed for efficient exploratory QC (preview in the GUI) while keeping processing logic deterministic and scriptable (core functions live in `analysis_core.py`). - ---- - -## Key Features - -### Data IO -- Supports raw data in .doric, .h5 or .csv. -- Multi-channel support (e.g., analogue, DIO channels …). -- Optional alignment of analog traces to the DigitalIO timebase when a DIO is selected. - -### Artifact Handling -- Artifact detection on raw 465 using derivative thresholding (`dx`) and MAD: - - **Global MAD (dx)**: one threshold for the full trace - - **Adaptive MAD (windowed)**: windowed thresholds for nonstationary noise -- Optional **padding** around detected artifacts to remove spillover. -- Manual artifact masking by user-defined time regions. -- Masked samples are replaced via **linear interpolation** to preserve time alignment. - -### Signal Conditioning -- Low-pass filtering. -- Decimation/resampling. - -### Baseline Estimation -Baseline is computed after filtering and resampling using **pybaselines**: -- `asls`, `arpls`, `airpls` -- tunable parameters (lambda, diff order, iterations, tolerance) - -### Output Modes (7) -The GUI exposes seven explicit output definitions: - -1. **dFF (non motion corrected)** - `dFF = (signal_filtered - signal_baseline) / signal_baseline` - -2. 
**zscore (non motion corrected)** - `zscore(dFF_nonMC)` - -3. **dFF (motion corrected via subtraction)** - `dFF_mc = dFF_signal - dFF_ref` - where each dFF uses its own baseline. - -4. **zscore (motion corrected via subtraction)** - `zscore(dFF_signal - dFF_ref)` - -5. **zscore (subtractions)** - `zscore(dFF_signal) - zscore(dFF_ref)` - -6. **dFF (motion corrected with fitted ref)** - Fit the isosbestic/reference channel to the signal: - `fitted_ref = a * ref_filtered + b` - then compute: - `dFF = (signal_filtered - fitted_ref) / fitted_ref` - -7. **zscore (motion corrected with fitted ref)** - `zscore( (signal_filtered - fitted_ref) / fitted_ref )` - -### Reference Fitting Methods (for “fitted ref” modes) -- **OLS (recommended)**: fast and stable -- **Lasso**: sparse regression (requires `scikit-learn`) -- **RLM (HuberT)**: robust linear model via IRLS + Huber weighting (no extra dependency) - -### Export -- Export processed output to: - - CSV with configurable fields (`time` always included; raw/isobestic/output/DIO selectable) - - HDF5 with configurable raw/output/DIO/baseline datasets plus metadata -- Export field selection is saved and restored through the preprocessing configuration file. -- Drag-and-drop support for preprocessing and post-processing files. - ---- - -## Repository Structure (typical) - -- `analysis_core.py` - Processing pipeline (loading, filtering, baselines, outputs, export helpers) -- `main.py` (or similar) - PySide6 GUI entry point and UI wiring -- `requirements.yml` - Conda environment definition - ---- - -## Installation - -1. Create the environment: - ```bash - conda env create -f environment.yml -## Run - cd .\pyBer\ - python main.py - -## Usage workflow - -Open a Doric .h5 file - -Choose a channel (e.g., AIN01). - -Optionally select a DigitalIO line to overlay events. - -### QC & artifact removal - -Choose Global MAD (dx) or Adaptive MAD (windowed). - -Tune mad_k, window size, and padding. - -Add manual mask regions if needed. - -### Filtering & resampling - -Set low-pass cutoff (Hz) and filter order. - -Set a target sampling rate (Hz) for consistent downstream analysis. - -### Baseline estimation - -Choose asls, arpls, or airpls. - -Tune lambda and other parameters to avoid baseline leakage into fast transients. - -### Select output - -Pick one of the 7 output modes. - -For “fitted ref” modes, choose the fit method (OLS/Lasso/RLM-HuberT). - - -### Export - -Export CSV/H5 for analysis in Python/MATLAB/R. - ---- - -## Preprocessing: Advanced Options - -The preprocessing panel includes an **Advanced options** button with two features: - -1) **Cut out regions (NaN)** -Define start/end ranges to exclude parts of the trace from downstream analysis. Cutout regions are filled with NaN in the output and can be exported as-is. - -2) **Sections (per-section processing)** -Define multiple start/end sections and **assign per-section processing parameters**. Each section can be exported independently (one CSV/H5 per section). 
- -### Time window -You can optionally set a start time, end time, or both: -- Start only: process from start → end of recording -- End only: process from 0 → end -- Start + end: process that window only -- Empty: process full trace - ---- - -## Post-Processing - -### Align sources -- **DIO**: choose DIO channel, polarity (0→1 or 1→0), and align to onset/offset -- **Behavior (CSV/XLSX)**: load a behavior file and select a column (binary 0/1) - - Align to onset or offset - - **Transitions**: align to A→B transitions with a max gap threshold - -### Behavior file formats -- **CSV**: must include a time column and one or more binary behavior columns -- **Ethovision XLSX**: the loader preprocess and clean the file, the user has to select the sheet when loading - -### PSTH/Heatmap -- Heatmap and PSTH refresh automatically when alignment settings change -- Event lines are overlaid on the trace preview -- Event duration histogram appears to the right of the heatmap -- Metrics bar plot (pre vs post) appears to the right of the PSTH - -### Metrics -Choose AUC or mean z-score and define pre/post windows in seconds. - -### Export results -Export any combination of: -- Heatmap matrix -- Average PSTH + SEM -- Event times -- Event durations -- Metrics table - ---- - -## Group Mode (Multiple Animals) - -Use the **Group** tab in Post-Processing to load multiple processed files (CSV/H5). Each file should represent a single animal, and matching behavior files should share the same base name. In group mode: -- Heatmap rows represent animals (not trials) -- PSTH is averaged across animals - ---- - - - - - - - +# pyBer + +pyBer is a desktop application for fiber photometry analysis. It helps you load +Doric, HDF5, or CSV recordings, clean artifacts, preprocess traces, align signals +to behavior or DIO events, inspect PSTHs and heatmaps, detect transients, and +export results for Python, MATLAB, R, or Prism. + +The app is built for users who want an interactive workflow first, with +deterministic processing code underneath. + +![pyBer logo](https://github.com/user-attachments/assets/e5acb000-17cd-451d-9f49-4218b41519aa) + +## Quick Install + +Install Miniforge or Anaconda first, then run: + +```powershell +cd C:\Analysis\app_project\pyBer +conda env create -f environment.yml +conda activate pyBer +Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')" +python .\pyBer\main.py +``` + +The `fastFMM` step is only needed for the FLMM temporal modeling panel. The rest +of pyBer works without it. + +## Launch From VS Code + +1. Open the repository folder in VS Code. +2. Select the interpreter from the `pyBer` conda environment. +3. Open `pyBer/main.py`. +4. Press Run, or use: + +```powershell +conda activate pyBer +python .\pyBer\main.py +``` + +If VS Code launches the wrong Python, run `Python: Select Interpreter` and choose +the environment created from `environment.yml`. + +## What You Can Do + +- Preprocess raw photometry traces with filtering, resampling, baseline + correction, motion correction, and artifact handling. +- Detect and inspect artifacts with interpolation, cutout, local low-pass + filtering, or no-op handling. +- Export processed CSV or HDF5 files with selectable fields and metadata. +- Align processed signals to DIO, behavior states, behavior onsets, or behavior + transitions. +- Detect signal events and compare transient amplitude with baseline-prominence + normalized metrics. +- Build individual or group PSTHs, heatmaps, event duration plots, and metrics. 
+- Fit temporal models with continuous GLM or trial-level FLMM. +- Rank GLM/FLMM feature contribution with leave-one-feature-out summaries. + +## Documentation + +The full user guide is here: + +- [pyBer Documentation](docs/index.md) + +It includes installation, first launch, preprocessing, postprocessing, transient +detection, temporal modeling, group workflows, export, and troubleshooting. + +## Repository Layout + +- `pyBer/main.py`: application entry point. +- `pyBer/analysis_core.py`: preprocessing and signal processing backend. +- `pyBer/gui_preprocessing.py`: preprocessing panels. +- `pyBer/gui_postprocessing.py`: postprocessing, PSTH, metrics, and export panels. +- `pyBer/temporal_modeling.py`: GLM and FLMM modeling panel. +- `environment.yml`: conda environment for development and user installs. +- `pyBer.spec`: PyInstaller build configuration. + +## Build The Executable + +From an activated environment: + +```powershell +conda activate pyBer +python -m PyInstaller --noconfirm --clean pyBer.spec +``` + +The executable is written to `dist/pyBer.exe`. + +## Notes + +pyBer sets `PYTHONNOUSERSITE=1` in the environment so old packages from the user +Python folder do not interfere with the conda environment. This is important for +Qt, pyqtgraph, numpy, and rpy2 stability on Windows. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..3027b22 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,285 @@ +# pyBer Documentation + +This page is a practical guide for installing and using pyBer. It is written for +lab users who want to process recordings without reading the code first. + +## 1. Install pyBer + +### Recommended Windows install + +Install Miniforge or Anaconda, open an Anaconda/Miniforge Prompt, then run: + +```powershell +cd C:\Analysis\app_project\pyBer +conda env create -f environment.yml +conda activate pyBer +Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')" +python .\pyBer\main.py +``` + +The `fastFMM` command installs the R package used by the FLMM temporal modeling +panel. It can take a few minutes the first time because R downloads dependencies. + +### Update an existing environment + +If the environment already exists: + +```powershell +conda activate pyBer +conda env update -f environment.yml --prune +Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')" +``` + +### Test the install + +Run: + +```powershell +conda activate pyBer +python .\pyBer\main.py +``` + +The app should open with Preprocessing and Postprocessing tabs. + +## 2. Launch From VS Code + +1. Open the pyBer repository folder. +2. Press `Ctrl+Shift+P`. +3. Choose `Python: Select Interpreter`. +4. Select the `pyBer` conda environment. +5. Open `pyBer/main.py`. +6. Press Run. + +If the Run button still fails, use the terminal: + +```powershell +conda activate pyBer +python .\pyBer\main.py +``` + +## 3. Preprocessing Workflow + +Use Preprocessing when you want to clean and export photometry traces. + +1. Load one or more raw files. +2. Select the signal channel and optional reference channel. +3. Choose artifact detection and handling. +4. Set filtering and resampling. +5. Choose baseline correction. +6. Choose the output signal definition. +7. Preview the result. +8. Export CSV or HDF5. + +### Output definitions + +pyBer exposes explicit output modes so exported traces are reproducible: + +- dFF without motion correction. +- z-score without motion correction. +- dFF with motion correction by subtraction. 
+- z-score with motion correction by subtraction. +- z-score signal minus z-score reference. +- dFF with fitted reference. +- z-score with fitted reference. + +For fitted-reference modes, pyBer fits the reference channel to the signal before +computing dFF. The usual choice is OLS. Lasso and robust Huber fitting are also +available. + +## 4. Artifact Handling + +Artifact settings let you choose how masked windows are handled: + +- Interpolation: replace artifact samples by linear interpolation. +- Cut: keep artifact samples as NaN so downstream analysis ignores them. +- Strong local low-pass filtering: smooth only inside the artifact window. +- Do nothing: detect or mark artifacts without changing the trace. + +Use interpolation when you need continuous traces. Use cut when the artifact +window should not contribute to statistics. + +## 5. Postprocessing Workflow + +Use Postprocessing when you want to align processed traces to events or behavior. + +1. Load processed files from preprocessing. +2. Load behavior files if needed. +3. Choose the alignment source. +4. Click `Compute PSTH`. +5. Inspect the trace preview, heatmap, average PSTH, duration plot, and metrics. +6. Export matrices, event times, metrics, and figures. + +### Alignment sources + +pyBer can align to: + +- DIO onset or offset. +- Behavior onset or offset from CSV or XLSX files. +- Binary behavior state columns. +- Behavior transitions. +- Signal events detected from the processed trace. + +### Group mode + +Use Group mode when each processed file represents one animal. pyBer keeps +per-file trial matrices for temporal modeling and can also display animal-level +group summaries. + +For best GLM and FLMM results, load matching behavior files whose base names +match the processed files. + +## 6. Signal Event Analyzer + +The Signal Event Analyzer detects transients and reports metrics. Useful options: + +- Auto MAD noise thresholding for transient detection. +- Min prominence, min height, min distance, and smoothing. +- Optional detected-peak overlay on the trace. +- Optional noise trace overlay. +- Baseline-prominence normalized amplitude for comparing recordings. + +Baseline-prominence normalized amplitude is useful when recordings differ in +baseline level or noise scale. It normalizes each detected transient relative to +its local baseline/prominence context. + +## 7. Temporal Modeling + +The Temporal Modeling panel supports two approaches. + +### Continuous GLM + +Use GLM when you want to model the continuous photometry trace from event and +behavior predictors. + +Typical predictors: + +- DIO events. +- Behavior onsets. +- Behavior states. +- Numeric behavior columns. +- Signal event times. + +The GLM output includes: + +- R-squared. +- RMSE, MAE, MSE, residual SD, and actual/predicted correlation. +- Estimated kernels for each predictor. +- Actual vs predicted signal. +- Residual trace. +- Leave-one-predictor-out feature contribution. + +The leave-one-predictor-out ranking refits the model after removing each +predictor. Larger `delta R^2` means that predictor explains more of the signal. + +### Trial-level FLMM + +Use FLMM when you want trial-level functional modeling with random effects. + +Requirements: + +- R installed through the conda environment or available on the system. +- Python package `rpy2`. +- R package `fastFMM`. +- Repeated rows per subject or animal. + +The environment installs R and rpy2. 
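+Before launching pyBer you can confirm that the R bridge works. This is a
+minimal sketch, assuming only the `rpy2` package from the environment and its
+standard `importr` API:
+
+```python
+# Check that rpy2 can start R and that the fastFMM package is importable.
+from rpy2.robjects.packages import importr
+
+try:
+    importr("fastFMM")  # raises if the R package is not installed
+    print("fastFMM backend is available")
+except Exception as exc:
+    print(f"fastFMM backend is not available: {exc}")
+```
+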
Install fastFMM with: + +```powershell +conda activate pyBer +Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')" +``` + +The FLMM output includes: + +- Fixed-effect coefficient curves. +- Pointwise and joint confidence bands when available. +- AIC summary. +- Coefficient magnitude statistics. +- Leave-one-feature-out AIC contribution when the reduced models are estimable. + +If a reduced FLMM cannot be estimated, pyBer still reports the coefficient-based +contribution so the feature ranking remains usable. + +## 8. Export + +Preprocessing can export processed traces as: + +- CSV. +- HDF5. + +Postprocessing can export: + +- Heatmap matrix. +- Average PSTH and SEM. +- Event times. +- Event durations. +- Metrics tables. +- Group-level outputs. + +Use HDF5 when you want metadata and multiple arrays in one file. Use CSV when you +want easy loading into spreadsheets or Prism. + +## 9. Troubleshooting + +### The app does not launch from VS Code + +Make sure VS Code is using the conda environment: + +```powershell +conda activate pyBer +python .\pyBer\main.py +``` + +If this works but the Run button fails, select the interpreter again in VS Code. + +### Dark mode or Qt styling looks broken + +This is usually a mixed Python environment. Recreate the environment and keep +`PYTHONNOUSERSITE=1` enabled: + +```powershell +conda env remove -n pyBer +conda env create -f environment.yml +conda activate pyBer +``` + +### FLMM says fastFMM is unavailable + +Run: + +```powershell +conda activate pyBer +Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')" +``` + +Then restart pyBer. + +### FLMM says random effects cannot be estimated + +FLMM needs repeated rows per subject. In practice, each animal should have +multiple trials. If you only provide one animal-averaged row per animal, the GLM +panel is usually the better choice. + +### The heatmap looks wrong after switching Individual and Group + +Click `Compute PSTH` again after changing loaded files or behavior alignment. +pyBer stores both per-file trial matrices and group matrices, but recomputing is +the clearest way to refresh all derived views after a major setup change. + +## 10. Build A Windows Executable + +From the repository root: + +```powershell +conda activate pyBer +python -m PyInstaller --noconfirm --clean pyBer.spec +``` + +The app is written to: + +```text +dist\pyBer.exe +``` + +When building with FLMM support, make sure `fastFMM` is installed before running +PyInstaller. diff --git a/environment.yml b/environment.yml index ddefbec..895b999 100644 --- a/environment.yml +++ b/environment.yml @@ -1,3 +1,12 @@ +# Create with: +# conda env create -f environment.yml +# conda activate pyBer +# +# Optional FLMM backend: +# Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')" +# +# The R package is installed after environment creation because fastFMM is a +# CRAN package, not a conda package. name: pyBer channels: - conda-forge @@ -26,8 +35,10 @@ dependencies: - pyqtgraph>=0.13 - matplotlib>=3.8 - # FLMM backend via R fastFMM/rpy2 + # FLMM backend via R fastFMM/rpy2. + - r-base>=4.4 - rpy2>=3.6 + # Windows R/rpy2 startup uses sh/make when validating R's config. 
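+# The m2 packages below provide those tools from MSYS2 on Windows.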
- m2-base - m2-make From acbd22c6e042d3f441d1790c243eaadb5abcb0c4 Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 13:45:06 +0200 Subject: [PATCH 08/14] Improve temporal modeling significance workflow --- pyBer/temporal_modeling.py | 441 +++++++++++++++++++++++++++++++++---- 1 file changed, 402 insertions(+), 39 deletions(-) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index de971fe..a40b62e 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -188,7 +188,7 @@ class GLMResult: r2: float coefficients: np.ndarray # raw beta vector design_matrix: np.ndarray - stats: Dict[str, float] = field(default_factory=dict) + stats: Dict[str, Any] = field(default_factory=dict) feature_importance: List[Dict[str, Any]] = field(default_factory=list) @@ -539,7 +539,32 @@ def fit( if group_var not in design: raise ValueError(f"FLMM design is missing group variable '{group_var}'.") + design = dict(design) r_df_vars = {} + group_var_model = group_var + fallback_grouping = "" + group_vals = np.asarray(design[group_var]).astype(str) + unique_groups = np.unique(group_vals) + has_repeated_groups = 1 < unique_groups.size < n_trials + if not has_repeated_groups: + if n_trials < 4: + raise ValueError( + "FLMM needs at least 4 rows when the selected grouping variable has fewer than two repeated levels. " + "Compute per-file/per-trial PSTH rows or choose a grouping variable with repeated samples." + ) + block_name = "pyber_block" + i = 2 + while block_name in design: + block_name = f"pyber_block_{i}" + i += 1 + n_blocks = min(4, max(2, n_trials // 2)) + design[block_name] = np.asarray([f"block_{(idx % n_blocks) + 1}" for idx in range(n_trials)], dtype=object) + group_var_model = block_name + fallback_grouping = ( + f"Grouping variable '{group_var}' had {unique_groups.size} level(s) across {n_trials} rows. " + f"Used exploratory block grouping '{group_var_model}' with {n_blocks} levels so fastFMM can fit." + ) + for col_name, col_vals in design.items(): col_vals = np.asarray(col_vals) if col_vals.size != n_trials: @@ -558,20 +583,19 @@ def fit( r_df = R(".__pyber_flmm_df$Y.obs <- I(.__pyber_flmm_y); .__pyber_flmm_df") formula_text = str(formula_fixed or "Y.obs ~ 1").strip() or "Y.obs ~ 1" - group_vals = np.asarray(design[group_var]).astype(str) - has_repeated_groups = np.unique(group_vals).size < group_vals.size + if group_var_model != group_var: + formula_text = re.sub( + rf"\|\s*`?{re.escape(group_var)}`?\s*\)", + f"| {group_var_model})", + formula_text, + ) if "|" not in formula_text: rand = str(random_effects or "~1").strip() rand_rhs = rand.split("~", 1)[1].strip() if "~" in rand else rand if rand_rhs.lower() in {"", "0", "none", "fixed"}: rand_rhs = "" - if not has_repeated_groups: - raise ValueError( - "FLMM requires repeated rows per subject for random effects. " - "Compute PSTH from per-file trials instead of animal-averaged rows." 
- ) if rand_rhs: - formula_text = f"{formula_text} + ({rand_rhs} | {group_var})" + formula_text = f"{formula_text} + ({rand_rhs} | {group_var_model})" # Call fui() fastFMM = importr("fastFMM") @@ -580,7 +604,7 @@ def fit( "data": r_df, "parallel": ro.BoolVector([parallel]), "silent": ro.BoolVector([True]), - "subj_id": ro.StrVector([group_var]), + "subj_id": ro.StrVector([group_var_model]), "override_zero_var": ro.BoolVector([True]), } if nknots_min is not None: @@ -691,6 +715,8 @@ def fit( continue summary_parts = [f"FLMM fit: {len(term_names)} terms, {n_trials} trials, {n_time} timepoints"] + if fallback_grouping: + summary_parts.append(f"Note: {fallback_grouping}") if aic_val is not None: summary_parts.append(f"AIC = {aic_val:.1f}") coeff_abs_peaks: List[float] = [] @@ -711,6 +737,10 @@ def fit( "aic": float(aic_val) if aic_val is not None else float("nan"), "mean_abs_coefficient": float(np.nanmean(coeff_abs_means)) if coeff_abs_means else float("nan"), "peak_abs_coefficient": float(np.nanmax(coeff_abs_peaks)) if coeff_abs_peaks else float("nan"), + "formula": formula_text, + "group_var": group_var_model, + "requested_group_var": group_var, + "fallback_grouping": fallback_grouping, } except Exception as exc: @@ -799,6 +829,18 @@ def fit( border: 0; width: 24px; } +QProgressBar { + color: #e9f0fb; + background: #0f1724; + border: 1px solid #314963; + border-radius: 6px; + min-height: 18px; + text-align: center; +} +QProgressBar::chunk { + background: #2d8cff; + border-radius: 5px; +} QListWidget, QTextEdit { color: #e6edf8; background: #0d1420; @@ -866,6 +908,9 @@ class TemporalModelingWidget(QtWidgets.QWidget): def __init__(self, parent: Optional[QtWidgets.QWidget] = None): super().__init__(parent) + self._settings = QtCore.QSettings("BelloneLab", "pyBer") + self._loading_settings = True + self._saved_predictor_keys: List[str] = [] self._glm = ContinuousGLM() self._flmm = TrialFLMM() self._glm_result: Optional[GLMResult] = None @@ -889,7 +934,9 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self._group_mode: bool = False self._build_compact_ui() + self._load_temporal_settings() self._connect_signals() + self._on_model_type_changed(self.combo_model_type.currentIndex()) # ------------------------------------------------------------------ # UI construction @@ -958,6 +1005,13 @@ def _build_ui(self): self.spin_kernel_post.setSuffix(" s") gl.addRow("Kernel post:", self.spin_kernel_post) + self.spin_glm_bootstrap = QtWidgets.QSpinBox() + self.spin_glm_bootstrap.setRange(0, 2000) + self.spin_glm_bootstrap.setValue(100) + self.spin_glm_bootstrap.setSpecialValueText("off") + self.spin_glm_bootstrap.setToolTip("Circular-shift bootstraps for leave-one-out contribution p-values.") + gl.addRow("Shift bootstraps:", self.spin_glm_bootstrap) + root.addWidget(self.grp_glm) # ---- FLMM settings ---- @@ -1109,6 +1163,12 @@ def _build_compact_ui(self): self.btn_fit.setProperty("class", "primary") self.btn_fit.setMinimumWidth(120) h.addWidget(self.btn_fit) + + self.progress_model = QtWidgets.QProgressBar() + self.progress_model.setMinimumWidth(260) + self.progress_model.setMaximumWidth(360) + self.progress_model.setVisible(False) + h.addWidget(self.progress_model) root.addWidget(header) split = QtWidgets.QSplitter(QtCore.Qt.Orientation.Horizontal) @@ -1208,6 +1268,13 @@ def _build_model_page(self): self.spin_kernel_post.setDecimals(1) self.spin_kernel_post.setSuffix(" s") gl.addRow("Kernel post", self.spin_kernel_post) + + self.spin_glm_bootstrap = QtWidgets.QSpinBox() + 
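+        # Permutations per predictor for the circular-shift test; 0 is shown
+        # as "off" via setSpecialValueText below and skips the test entirely.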
self.spin_glm_bootstrap.setRange(0, 2000) + self.spin_glm_bootstrap.setValue(100) + self.spin_glm_bootstrap.setSpecialValueText("off") + self.spin_glm_bootstrap.setToolTip("Circular-shift bootstraps for leave-one-out contribution p-values.") + gl.addRow("Shift bootstraps", self.spin_glm_bootstrap) lay.addWidget(self.grp_glm) self.grp_flmm = QtWidgets.QGroupBox("FLMM Settings") @@ -1376,6 +1443,42 @@ def _style_plot(self, plot: pg.PlotWidget) -> None: pi.getAxis("left").setTextPen(pg.mkPen("#c5d2e3")) pi.titleLabel.item.setDefaultTextColor(QtGui.QColor("#d7e0ee")) + def _progress_start(self, label: str, maximum: int = 0) -> None: + if not hasattr(self, "progress_model"): + return + self.progress_model.setVisible(True) + if maximum <= 0: + self.progress_model.setRange(0, 0) + self.progress_model.setFormat(label) + else: + self.progress_model.setRange(0, maximum) + self.progress_model.setValue(0) + self.progress_model.setFormat(f"{label} %p%") + QtWidgets.QApplication.processEvents() + + def _progress_update(self, value: int, label: Optional[str] = None) -> None: + if not hasattr(self, "progress_model") or not self.progress_model.isVisible(): + return + if label: + self.progress_model.setFormat(f"{label} %p%" if self.progress_model.maximum() > 0 else label) + if self.progress_model.maximum() > 0: + self.progress_model.setValue(max(0, min(int(value), self.progress_model.maximum()))) + QtWidgets.QApplication.processEvents() + + def _progress_finish(self) -> None: + if not hasattr(self, "progress_model"): + return + self.progress_model.setVisible(False) + self.progress_model.setRange(0, 100) + self.progress_model.setValue(0) + QtWidgets.QApplication.processEvents() + + def _set_fit_enabled(self, enabled: bool) -> None: + for attr in ("btn_fit", "btn_fit_side"): + btn = getattr(self, attr, None) + if btn is not None: + btn.setEnabled(enabled) + # ------------------------------------------------------------------ # Signal wiring # ------------------------------------------------------------------ @@ -1385,6 +1488,71 @@ def _connect_signals(self): self.btn_fit.clicked.connect(self._on_fit_clicked) self.btn_add_predictor.clicked.connect(self._on_add_predictor) self.btn_remove_predictor.clicked.connect(self._on_remove_predictor) + for widget in ( + self.combo_basis, + self.combo_reg, + self.spin_n_basis, + self.spin_alpha, + self.spin_kernel_pre, + self.spin_kernel_post, + self.spin_glm_bootstrap, + self.spin_nknots, + self.spin_boots, + ): + signal = getattr(widget, "currentIndexChanged", None) or getattr(widget, "valueChanged", None) + if signal is not None: + signal.connect(lambda *_: self._save_temporal_settings()) + for edit in (self.edit_formula, self.edit_random, self.edit_group_var): + edit.editingFinished.connect(self._save_temporal_settings) + + def _load_temporal_settings(self) -> None: + self._loading_settings = True + try: + prefix = "temporal_modeling/" + self.combo_model_type.setCurrentIndex(int(self._settings.value(prefix + "model_type", 0))) + for combo, key in ((self.combo_basis, "basis"), (self.combo_reg, "regularization")): + text = str(self._settings.value(prefix + key, "") or "") + idx = combo.findText(text) + if idx >= 0: + combo.setCurrentIndex(idx) + self.spin_n_basis.setValue(int(self._settings.value(prefix + "n_basis", self.spin_n_basis.value()))) + self.spin_alpha.setValue(float(self._settings.value(prefix + "alpha", self.spin_alpha.value()))) + self.spin_kernel_pre.setValue(float(self._settings.value(prefix + "kernel_pre", self.spin_kernel_pre.value()))) + 
self.spin_kernel_post.setValue(float(self._settings.value(prefix + "kernel_post", self.spin_kernel_post.value()))) + self.spin_glm_bootstrap.setValue(int(self._settings.value(prefix + "glm_shift_bootstraps", self.spin_glm_bootstrap.value()))) + self.edit_formula.setText(str(self._settings.value(prefix + "flmm_formula", self.edit_formula.text()) or "Y.obs ~ 1")) + self.edit_random.setText(str(self._settings.value(prefix + "flmm_random", self.edit_random.text()) or "~1")) + self.edit_group_var.setText(str(self._settings.value(prefix + "flmm_group_var", self.edit_group_var.text()) or "subject")) + self.spin_nknots.setValue(int(self._settings.value(prefix + "flmm_nknots", self.spin_nknots.value()))) + self.spin_boots.setValue(int(self._settings.value(prefix + "flmm_boots", self.spin_boots.value()))) + raw_predictors = str(self._settings.value(prefix + "predictor_keys", "") or "") + self._saved_predictor_keys = [key for key in raw_predictors.split("\n") if key.strip()] + finally: + self._loading_settings = False + + def _save_temporal_settings(self) -> None: + if getattr(self, "_loading_settings", False): + return + prefix = "temporal_modeling/" + self._settings.setValue(prefix + "model_type", self.combo_model_type.currentIndex()) + self._settings.setValue(prefix + "basis", self.combo_basis.currentText()) + self._settings.setValue(prefix + "regularization", self.combo_reg.currentText()) + self._settings.setValue(prefix + "n_basis", self.spin_n_basis.value()) + self._settings.setValue(prefix + "alpha", self.spin_alpha.value()) + self._settings.setValue(prefix + "kernel_pre", self.spin_kernel_pre.value()) + self._settings.setValue(prefix + "kernel_post", self.spin_kernel_post.value()) + self._settings.setValue(prefix + "glm_shift_bootstraps", self.spin_glm_bootstrap.value()) + self._settings.setValue(prefix + "flmm_formula", self.edit_formula.text().strip()) + self._settings.setValue(prefix + "flmm_random", self.edit_random.text().strip()) + self._settings.setValue(prefix + "flmm_group_var", self.edit_group_var.text().strip()) + self._settings.setValue(prefix + "flmm_nknots", self.spin_nknots.value()) + self._settings.setValue(prefix + "flmm_boots", self.spin_boots.value()) + predictors = self._selected_predictor_keys() if hasattr(self, "list_predictors") else self._saved_predictor_keys + if predictors or getattr(self, "_predictor_catalog", None): + self._saved_predictor_keys = list(predictors) + else: + predictors = list(self._saved_predictor_keys) + self._settings.setValue(prefix + "predictor_keys", "\n".join(predictors)) # ------------------------------------------------------------------ # Public API — called by PostProcessingPanel @@ -1465,6 +1633,8 @@ def _predictor_label(self, key: str) -> str: entry = self._predictor_catalog.get(str(key), {}) label = str(entry.get("label", "") or "").strip() if label: + if label.startswith("Numeric column:"): + return label.split(":", 1)[1].strip() return label if key == "events": return "PSTH alignment events" @@ -1477,9 +1647,19 @@ def _predictor_label(self, key: str) -> str: if key.startswith("behavior_state::"): return f"Behavior state: {key.split('::', 1)[1]}" if key.startswith("behavior_cont::"): - return f"Numeric column: {key.split('::', 1)[1]}" + return key.split("::", 1)[1] return str(key) + @staticmethod + def _compact_feature_label(label: object, max_len: int = 42) -> str: + text = str(label or "").strip() + if text.startswith("Numeric column:"): + text = text.split(":", 1)[1].strip() + text = re.sub(r"\s+", " ", text) + if len(text) > 
max_len: + return text[:max_len - 3] + "..." + return text + def _selected_predictor_keys(self) -> List[str]: keys: List[str] = [] for i in range(self.list_predictors.count()): @@ -1580,19 +1760,22 @@ def _refresh_predictor_catalog(self) -> None: catalog[f"behavior_cont::{name}"] = { "kind": "continuous", "name": name, - "label": f"Numeric column: {name}", + "label": name, } previous_keys = self._selected_predictor_keys() if hasattr(self, "list_predictors") else [] + restore_keys = previous_keys or [key for key in self._saved_predictor_keys if key in catalog] self._predictor_catalog = catalog self._refresh_predictor_combo() if hasattr(self, "list_predictors"): self.list_predictors.clear() - for key in previous_keys: + for key in restore_keys: if key in catalog: self._add_predictor_item(key) if self.list_predictors.count() == 0 and "events" in catalog: self._add_predictor_item("events") + if self.list_predictors.count() > 0 or previous_keys: + self._save_temporal_settings() def _behavior_source_for_proc(self, proc: Any) -> Optional[Dict[str, Any]]: if not self._behavior_sources: @@ -1791,16 +1974,21 @@ def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: vec_parts: Dict[str, List[np.ndarray]] = {key: [] for key in selected} pred_types: Dict[str, str] = {} used_records: List[str] = [] + segment_slices: List[Tuple[int, int]] = [] offset = 0.0 + cursor = 0 for seg_idx, (file_id, t, y, proc) in enumerate(segments): dt = float(np.nanmedian(np.diff(t))) if not np.isfinite(dt) or dt <= 0: dropped_records.append(file_id) continue t_shift = (t - float(t[0])) + offset + start = cursor time_parts.append(t_shift) signal_parts.append(y.astype(float, copy=True)) used_records.append(file_id) + cursor += int(t.size) + segment_slices.append((start, cursor)) for key in selected: vec, ptype = self._predictor_vector_for_proc(key, proc, t) vec = np.asarray(vec, float) @@ -1817,6 +2005,7 @@ def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: signal_parts.append(np.full(pad_n, np.nan, float)) for key in selected: vec_parts[key].append(np.zeros(pad_n, float)) + cursor += int(pad_n) offset = float(pad_t[-1] + dt) if not time_parts: @@ -1859,6 +2048,7 @@ def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: "dropped_records": dropped_records, "dropped_predictors": dropped_predictors, "valid_samples": int(np.sum(valid_signal)), + "segment_slices": segment_slices, } def _proc_for_file_id(self, file_id: str) -> Optional[Any]: @@ -2081,6 +2271,116 @@ def _intercept_only_fit_stats(signal: np.ndarray) -> Dict[str, float]: mse = float(np.nanmean(residuals ** 2)) if y.size else float("nan") return {"r2": 1.0 - ss_res / max(ss_tot, 1e-12), "mse": mse} + @staticmethod + def _shift_vector_by_segment( + values: np.ndarray, + segment_slices: List[Tuple[int, int]], + rng: np.random.Generator, + ) -> np.ndarray: + vec = np.asarray(values, float) + shifted = np.zeros_like(vec) + slices = segment_slices or [(0, int(vec.size))] + for lo, hi in slices: + lo = max(0, int(lo)) + hi = min(int(hi), int(vec.size)) + n = hi - lo + if n <= 1: + continue + shift = int(rng.integers(1, n)) + shifted[lo:hi] = np.roll(vec[lo:hi], shift) + shifted[~np.isfinite(shifted)] = 0.0 + return shifted + + def _compute_glm_shift_bootstrap_significance( + self, + dataset: Dict[str, Any], + rows: List[Dict[str, Any]], + kernel_window: Tuple[float, float], + basis_type: str, + regularization: str, + alpha: float, + n_boot: int, + ) -> None: + if n_boot <= 0 or not rows: + for row in rows: + 
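+                # Bootstrapping is disabled: flag every contribution as
+                # untested (NaN p-value) rather than fabricating significance.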
row["p_value"] = float("nan") + row["significant"] = False + row["bootstrap_n"] = 0 + return + + predictors = dict(dataset.get("predictors", {}) or {}) + time = np.asarray(dataset["time"], float) + signal = np.asarray(dataset["signal"], float) + segment_slices = list(dataset.get("segment_slices", []) or []) + rng = np.random.default_rng() + total = max(1, int(n_boot) * len(rows)) + self._progress_start("GLM circular-shift test", total) + step = 0 + for row in rows: + feature = str(row.get("feature", "")) + obs_delta = float(row.get("delta_r2", float("nan"))) + reduced_r2 = float(row.get("reduced_r2", float("nan"))) + spec = predictors.get(feature) + if ( + not feature + or spec is None + or not np.isfinite(obs_delta) + or obs_delta <= 0 + or not np.isfinite(reduced_r2) + ): + row["p_value"] = 1.0 + row["significant"] = False + row["bootstrap_n"] = 0 + step += int(n_boot) + self._progress_update(step, "GLM circular-shift test") + continue + + base_vec = ContinuousGLM._predictor_vector(time, spec) + null_deltas: List[float] = [] + for _ in range(int(n_boot)): + shifted_predictors = dict(predictors) + shifted_predictors[feature] = { + "kind": "vector", + "values": self._shift_vector_by_segment(base_vec, segment_slices, rng), + } + try: + shifted = ContinuousGLM().fit( + time, + signal, + shifted_predictors, + kernel_window=kernel_window, + n_basis=self.spin_n_basis.value(), + basis_type=basis_type, + regularization=regularization, + alpha=alpha, + ) + null_delta = float(shifted.r2 - reduced_r2) + if np.isfinite(null_delta): + null_deltas.append(null_delta) + except Exception as exc: + _LOG.debug("Circular-shift GLM bootstrap failed for %s: %s", feature, exc) + step += 1 + if step == total or step % 5 == 0: + self._progress_update(step, "GLM circular-shift test") + + null_arr = np.asarray(null_deltas, float) + if null_arr.size: + p_value = float((1 + np.sum(null_arr >= obs_delta)) / (null_arr.size + 1)) + row["p_value"] = p_value + row["significant"] = bool(p_value < 0.05 and obs_delta > 0) + row["bootstrap_n"] = int(null_arr.size) + row["null_delta_mean"] = float(np.nanmean(null_arr)) + row["null_delta_q95"] = float(np.nanpercentile(null_arr, 95)) + else: + row["p_value"] = float("nan") + row["significant"] = False + row["bootstrap_n"] = 0 + rows.sort(key=lambda item: ( + bool(item.get("significant", False)), + np.isfinite(item.get("delta_r2", np.nan)), + float(item.get("delta_r2", -np.inf)) if np.isfinite(item.get("delta_r2", np.nan)) else -np.inf, + ), reverse=True) + def _compute_glm_leave_one_out( self, dataset: Dict[str, Any], @@ -2097,6 +2397,7 @@ def _compute_glm_leave_one_out( rows: List[Dict[str, Any]] = [] time = np.asarray(dataset["time"], float) signal = np.asarray(dataset["signal"], float) + self._progress_start("GLM leave-one-out", len(result.predictor_names)) for pred_name in result.predictor_names: row: Dict[str, Any] = { "feature": pred_name, @@ -2142,7 +2443,7 @@ def _compute_glm_leave_one_out( except Exception as exc: row["status"] = f"failed: {exc}" rows.append(row) - QtWidgets.QApplication.processEvents() + self._progress_update(len(rows), "GLM leave-one-out") rows.sort(key=lambda item: ( np.isfinite(item.get("delta_r2", np.nan)), float(item.get("delta_r2", -np.inf)) if np.isfinite(item.get("delta_r2", np.nan)) else -np.inf, @@ -2195,6 +2496,7 @@ def _compute_flmm_leave_one_out( return [] full_aic = float(full_result.aic) if full_result.aic is not None else float("nan") rows: List[Dict[str, Any]] = [] + self._progress_start("FLMM leave-one-out", len(terms)) for term in 
terms: reduced_terms = [name for name in terms if name != term] reduced_formula = "Y.obs ~ " + " + ".join(reduced_terms) if reduced_terms else "Y.obs ~ 1" @@ -2225,7 +2527,7 @@ def _compute_flmm_leave_one_out( except Exception as exc: row["status"] = f"failed: {exc}" rows.append(row) - QtWidgets.QApplication.processEvents() + self._progress_update(len(rows), "FLMM leave-one-out") rows.sort(key=lambda item: ( np.isfinite(item.get("delta_aic", np.nan)), float(item.get("delta_aic", -np.inf)) if np.isfinite(item.get("delta_aic", np.nan)) else -np.inf, @@ -2261,6 +2563,7 @@ def _on_model_type_changed(self, index: int): "then install rpy2 (pip install rpy2)." ) self.lbl_flmm_status.setStyleSheet("color: #f5a97f;") + self._save_temporal_settings() def _on_add_predictor(self): key = "" @@ -2271,15 +2574,18 @@ def _on_add_predictor(self): self.statusMessage.emit("No predictor is available yet. Load or compute behavior/events first.", 5000) return if self._add_predictor_item(key): + self._save_temporal_settings() self.statusMessage.emit(f"Added predictor: {self._predictor_label(key)}", 3000) def _on_remove_predictor(self): sel = self.list_predictors.currentRow() if sel >= 0: self.list_predictors.takeItem(sel) + self._save_temporal_settings() def _on_fit_clicked(self): model_idx = self.combo_model_type.currentIndex() + self._set_fit_enabled(False) try: if model_idx == 0: self._fit_glm() @@ -2289,12 +2595,16 @@ def _on_fit_clicked(self): _LOG.error("Temporal modeling fit failed: %s\n%s", exc, traceback.format_exc()) self.txt_summary.setPlainText(f"Error: {exc}") self.statusMessage.emit(f"Temporal model fit failed: {exc}", 8000) + finally: + self._set_fit_enabled(True) + self._progress_finish() # ------------------------------------------------------------------ # GLM fit # ------------------------------------------------------------------ def _fit_glm_catalog(self) -> None: + self._save_temporal_settings() dataset = self._build_glm_dataset_from_selected_predictors() if "error" in dataset: msg = str(dataset.get("error", "Could not build GLM dataset.")) @@ -2311,6 +2621,7 @@ def _fit_glm_catalog(self) -> None: kernel_win = (self.spin_kernel_pre.value(), self.spin_kernel_post.value()) basis_type = basis_map.get(self.combo_basis.currentText(), "raised_cosine") regularization = reg_map.get(self.combo_reg.currentText(), "ridge") + self._progress_start("Fitting GLM", 0) result = self._glm.fit( np.asarray(dataset["time"], float), np.asarray(dataset["signal"], float), @@ -2333,6 +2644,16 @@ def _fit_glm_catalog(self) -> None: regularization, self.spin_alpha.value(), ) + n_boot = int(self.spin_glm_bootstrap.value()) + self._compute_glm_shift_bootstrap_significance( + dataset, + importance_rows, + kernel_win, + basis_type, + regularization, + self.spin_alpha.value(), + n_boot, + ) result.feature_importance = importance_rows used_labels = [self._predictor_label(k) for k in result.predictor_names] @@ -2349,6 +2670,7 @@ def _fit_glm_catalog(self) -> None: f"Predictors: {', '.join(used_labels)}", f"Basis: {self.combo_basis.currentText()}, n={self.spin_n_basis.value()}", f"Regularization: {self.combo_reg.currentText()}, alpha={self.spin_alpha.value():.3f}", + f"Circular-shift bootstraps: {n_boot if n_boot > 0 else 'off'}", ] stats = result.stats or {} lines.extend([ @@ -2363,10 +2685,17 @@ def _fit_glm_catalog(self) -> None: if importance_rows: lines.extend(["", "Leave-one-predictor-out contribution (full - reduced):"]) for row in importance_rows[:10]: + p_value = float(row.get("p_value", float("nan"))) + p_text = 
f", p = {p_value:.4g}" if np.isfinite(p_value) else "" + sig_text = " [significant]" if row.get("significant", False) else "" lines.append( f" {row['label']}: delta R^2 = {row['delta_r2']:.5g}, " f"delta MSE = {row['delta_mse']:.5g}, reduced R^2 = {row['reduced_r2']:.5g}" + f"{p_text}{sig_text}" ) + if n_boot > 0: + significant = [row for row in importance_rows if row.get("significant", False)] + lines.append(f" Significant predictors at p < 0.05: {len(significant)}") failed = [row for row in importance_rows if row.get("status") != "ok"] if failed: lines.append(f" {len(failed)} reduced fits failed; see log for details.") @@ -2518,32 +2847,49 @@ def _plot_feature_importance( txt = pg.TextItem("No leave-one-out feature contribution is available.", color="#c5d2e3") pw.addItem(txt) txt.setPos(0, 0) - pw.setLabel("bottom", "Feature") - pw.setLabel("left", y_label) + pw.setLabel("bottom", y_label) + pw.setLabel("left", "Feature") return - x = np.arange(len(usable), dtype=float) + usable = usable[:25] vals = np.asarray([float(row.get(value_key, 0.0)) for row in usable], float) - brushes = [pg.mkBrush("#4b9df8" if val >= 0 else "#ee99a0") for val in vals] - bar = pg.BarGraphItem(x=x, height=vals, width=0.64, brushes=brushes) + y_pos = np.arange(len(usable), dtype=float)[::-1] + brushes = [] + for row, val in zip(usable, vals): + if row.get("significant", False): + brushes.append(pg.mkBrush("#f5c542")) + else: + brushes.append(pg.mkBrush("#4b9df8" if val >= 0 else "#ee99a0")) + bar = pg.BarGraphItem(x0=np.zeros_like(vals), x1=vals, y=y_pos, height=0.62, brushes=brushes) pw.addItem(bar) - pw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + pw.addLine(x=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) labels = [] - for idx, row in enumerate(usable): - label = str(row.get("label", row.get("feature", idx))) - if len(label) > 18: - label = label[:15] + "..." - labels.append((idx, label)) - pw.getAxis("bottom").setTicks([labels]) - pw.setLabel("bottom", "Feature") - pw.setLabel("left", y_label) - pw.enableAutoRange() + max_abs = max(float(np.nanmax(np.abs(vals))) if vals.size else 1.0, 1e-9) + label_offset = max_abs * 0.025 + for pos, row, val in zip(y_pos, usable, vals): + label = self._compact_feature_label(row.get("label", row.get("feature", "")), 46) + labels.append((float(pos), label)) + p_value = float(row.get("p_value", float("nan"))) + if row.get("significant", False) and np.isfinite(p_value): + p_txt = pg.TextItem(f"p={p_value:.3g}", color="#f5c542", anchor=(0.0, 0.5)) + pw.addItem(p_txt) + p_txt.setPos(float(val) + label_offset, float(pos)) + pw.getAxis("left").setTicks([labels]) + pw.getAxis("bottom").setTicks(None) + pw.setLabel("bottom", y_label) + pw.setLabel("left", "Feature") + lo = min(0.0, float(np.nanmin(vals)) if vals.size else 0.0) + hi = max(0.0, float(np.nanmax(vals)) if vals.size else 1.0) + pad = max((hi - lo) * 0.15, max_abs * 0.12, 1e-6) + pw.setXRange(lo - pad, hi + pad, padding=0.0) + pw.setYRange(-1, len(usable), padding=0.02) # ------------------------------------------------------------------ # FLMM fit # ------------------------------------------------------------------ def _fit_flmm_group(self) -> None: + self._save_temporal_settings() if not self._flmm.available: self.statusMessage.emit( "R + fastFMM not available. 
Please install R, rpy2, and the fastFMM R package.", 8000 @@ -2569,17 +2915,27 @@ def _fit_flmm_group(self) -> None: design, auto_terms, dropped_terms, term_labels = self._build_flmm_design(row_labels, group_var) requested_formula = self.edit_formula.text().strip() - if requested_formula in {"", "Y.obs ~ 1", "Y.obs ~ group"}: - formula = "Y.obs ~ " + " + ".join(auto_terms) if auto_terms else "Y.obs ~ 1" - if requested_formula != formula: - self.edit_formula.setText(formula) - else: - formula = requested_formula + auto_formula = "Y.obs ~ " + " + ".join(auto_terms) if auto_terms else "Y.obs ~ 1" + requested_terms = [ + term for term in self._simple_formula_terms(requested_formula) + if "|" not in term and "(" not in term and ")" not in term + ] + missing_terms = [term for term in requested_terms if term not in design] + use_auto = ( + requested_formula in {"", "Y.obs ~ 1", "Y.obs ~ group"} + or "~" not in requested_formula + or not requested_formula.lstrip().startswith("Y.obs") + or bool(missing_terms) + ) + formula = auto_formula if use_auto else requested_formula + if requested_formula != formula: + self.edit_formula.setText(formula) random_eff = self.edit_random.text().strip() or "~1" nknots = self.spin_nknots.value() if self.spin_nknots.value() > 0 else None num_boots = self.spin_boots.value() self.statusMessage.emit("Fitting FLMM via fastFMM - this may take a while...", 0) + self._progress_start("Fitting FLMM via fastFMM", 0) QtWidgets.QApplication.processEvents() result = self._flmm.fit( @@ -2612,14 +2968,21 @@ def _fit_flmm_group(self) -> None: else "mean_abs_coefficient" ) + effective_group_var = str(result.stats.get("group_var", group_var)) if result.stats else group_var + effective_formula = str(result.stats.get("formula", formula)) if result.stats else formula + fallback_grouping = str(result.stats.get("fallback_grouping", "")) if result.stats else "" summary_lines = [ result.summary_text, "", f"Scope: {'animal/group rows' if scope == 'animal' else 'trial rows'}", f"Rows: {n_rows}", - f"ID variable: {group_var}", - f"Formula: {formula}", + f"ID variable: {effective_group_var}", + f"Formula: {effective_formula}", ] + if missing_terms: + summary_lines.append(f"Auto formula used because saved terms were unavailable: {', '.join(missing_terms)}") + if fallback_grouping: + summary_lines.append(f"Grouping note: {fallback_grouping}") if result.stats: summary_lines.extend([ "", @@ -2649,7 +3012,7 @@ def _fit_flmm_group(self) -> None: summary_lines.append(" Leave-one-out comparison uses analytic fits; bootstrap CIs are not repeated.") if auto_terms: readable = [f"{term} = {term_labels.get(term, term)}" for term in auto_terms] - summary_lines.append("Animal covariates: " + "; ".join(readable)) + summary_lines.append("Covariates: " + "; ".join(readable)) if dropped_terms: summary_lines.append("Dropped covariates: " + ", ".join(dropped_terms)) self.txt_summary.setPlainText("\n".join(summary_lines)) From a12e7f6cdb6f708db242adb4c6081531b5e30b90 Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 14:27:51 +0200 Subject: [PATCH 09/14] Speed up GLM bootstrap and filter kernels --- pyBer/temporal_modeling.py | 214 ++++++++++++++++++++++++++++++------- 1 file changed, 178 insertions(+), 36 deletions(-) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index a40b62e..33f7120 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -13,6 +13,8 @@ """ from __future__ import annotations +import concurrent.futures +import hashlib import logging import os 
import re @@ -911,6 +913,8 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self._settings = QtCore.QSettings("BelloneLab", "pyBer") self._loading_settings = True self._saved_predictor_keys: List[str] = [] + self._kernel_visible: Dict[str, bool] = {} + self._kernel_filter_guard = False self._glm = ContinuousGLM() self._flmm = TrialFLMM() self._glm_result: Optional[GLMResult] = None @@ -1012,6 +1016,13 @@ def _build_ui(self): self.spin_glm_bootstrap.setToolTip("Circular-shift bootstraps for leave-one-out contribution p-values.") gl.addRow("Shift bootstraps:", self.spin_glm_bootstrap) + self.spin_glm_jobs = QtWidgets.QSpinBox() + max_jobs = max(1, os.cpu_count() or 1) + self.spin_glm_jobs.setRange(1, max_jobs) + self.spin_glm_jobs.setValue(min(4, max_jobs)) + self.spin_glm_jobs.setToolTip("Parallel jobs used for circular-shift bootstrap fits.") + gl.addRow("Bootstrap jobs:", self.spin_glm_jobs) + root.addWidget(self.grp_glm) # ---- FLMM settings ---- @@ -1275,6 +1286,13 @@ def _build_model_page(self): self.spin_glm_bootstrap.setSpecialValueText("off") self.spin_glm_bootstrap.setToolTip("Circular-shift bootstraps for leave-one-out contribution p-values.") gl.addRow("Shift bootstraps", self.spin_glm_bootstrap) + + self.spin_glm_jobs = QtWidgets.QSpinBox() + max_jobs = max(1, os.cpu_count() or 1) + self.spin_glm_jobs.setRange(1, max_jobs) + self.spin_glm_jobs.setValue(min(4, max_jobs)) + self.spin_glm_jobs.setToolTip("Parallel jobs used for circular-shift bootstrap fits.") + gl.addRow("Bootstrap jobs", self.spin_glm_jobs) lay.addWidget(self.grp_glm) self.grp_flmm = QtWidgets.QGroupBox("FLMM Settings") @@ -1379,6 +1397,24 @@ def _build_workspace_pages(self): kernel_page = QtWidgets.QWidget() kernel_lay = QtWidgets.QVBoxLayout(kernel_page) kernel_lay.setContentsMargins(10, 10, 10, 10) + filter_row = QtWidgets.QHBoxLayout() + filter_row.setSpacing(8) + lbl_filter = QtWidgets.QLabel("Show kernels") + lbl_filter.setProperty("class", "muted") + filter_row.addWidget(lbl_filter) + self.btn_kernel_all = QtWidgets.QPushButton("All") + self.btn_kernel_none = QtWidgets.QPushButton("None") + filter_row.addWidget(self.btn_kernel_all) + filter_row.addWidget(self.btn_kernel_none) + self.list_kernel_filter = QtWidgets.QListWidget() + self.list_kernel_filter.setMaximumHeight(86) + self.list_kernel_filter.setFlow(QtWidgets.QListView.Flow.LeftToRight) + self.list_kernel_filter.setWrapping(True) + self.list_kernel_filter.setResizeMode(QtWidgets.QListView.ResizeMode.Adjust) + self.list_kernel_filter.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.list_kernel_filter.setVerticalScrollMode(QtWidgets.QAbstractItemView.ScrollMode.ScrollPerPixel) + filter_row.addWidget(self.list_kernel_filter, 1) + kernel_lay.addLayout(filter_row) self.plot_kernel = pg.PlotWidget(title="Estimated kernels") self._style_plot(self.plot_kernel) kernel_lay.addWidget(self.plot_kernel, 1) @@ -1496,6 +1532,7 @@ def _connect_signals(self): self.spin_kernel_pre, self.spin_kernel_post, self.spin_glm_bootstrap, + self.spin_glm_jobs, self.spin_nknots, self.spin_boots, ): @@ -1504,6 +1541,10 @@ def _connect_signals(self): signal.connect(lambda *_: self._save_temporal_settings()) for edit in (self.edit_formula, self.edit_random, self.edit_group_var): edit.editingFinished.connect(self._save_temporal_settings) + if hasattr(self, "list_kernel_filter"): + self.list_kernel_filter.itemChanged.connect(self._on_kernel_filter_changed) + self.btn_kernel_all.clicked.connect(lambda: 
self._set_all_kernels_visible(True)) + self.btn_kernel_none.clicked.connect(lambda: self._set_all_kernels_visible(False)) def _load_temporal_settings(self) -> None: self._loading_settings = True @@ -1520,6 +1561,7 @@ def _load_temporal_settings(self) -> None: self.spin_kernel_pre.setValue(float(self._settings.value(prefix + "kernel_pre", self.spin_kernel_pre.value()))) self.spin_kernel_post.setValue(float(self._settings.value(prefix + "kernel_post", self.spin_kernel_post.value()))) self.spin_glm_bootstrap.setValue(int(self._settings.value(prefix + "glm_shift_bootstraps", self.spin_glm_bootstrap.value()))) + self.spin_glm_jobs.setValue(int(self._settings.value(prefix + "glm_bootstrap_jobs", self.spin_glm_jobs.value()))) self.edit_formula.setText(str(self._settings.value(prefix + "flmm_formula", self.edit_formula.text()) or "Y.obs ~ 1")) self.edit_random.setText(str(self._settings.value(prefix + "flmm_random", self.edit_random.text()) or "~1")) self.edit_group_var.setText(str(self._settings.value(prefix + "flmm_group_var", self.edit_group_var.text()) or "subject")) @@ -1542,6 +1584,7 @@ def _save_temporal_settings(self) -> None: self._settings.setValue(prefix + "kernel_pre", self.spin_kernel_pre.value()) self._settings.setValue(prefix + "kernel_post", self.spin_kernel_post.value()) self._settings.setValue(prefix + "glm_shift_bootstraps", self.spin_glm_bootstrap.value()) + self._settings.setValue(prefix + "glm_bootstrap_jobs", self.spin_glm_jobs.value()) self._settings.setValue(prefix + "flmm_formula", self.edit_formula.text().strip()) self._settings.setValue(prefix + "flmm_random", self.edit_random.text().strip()) self._settings.setValue(prefix + "flmm_group_var", self.edit_group_var.text().strip()) @@ -1660,6 +1703,62 @@ def _compact_feature_label(label: object, max_len: int = 42) -> str: return text[:max_len - 3] + "..." 
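+        # Label fits within max_len; return it unchanged.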
return text + @staticmethod + def _kernel_color(key: object) -> str: + palette = [ + "#4b9df8", "#f5a97f", "#6bdb74", "#ee99a0", "#c6a0f6", + "#89dceb", "#f5c542", "#5fd0c5", "#ff7ab2", "#a6e3a1", + "#fab387", "#74c7ec", "#b4befe", "#f38ba8", "#94e2d5", + "#eba0ac", "#8bd5ca", "#eed49f", "#91d7e3", "#f5bde6", + ] + digest = hashlib.blake2s(str(key).encode("utf-8", errors="replace"), digest_size=2).digest() + return palette[int.from_bytes(digest, "little") % len(palette)] + + def _sync_kernel_filter(self, result: GLMResult) -> None: + if not hasattr(self, "list_kernel_filter"): + return + names = list(result.predictor_names or result.kernels.keys()) + self._kernel_filter_guard = True + try: + self.list_kernel_filter.clear() + for name in names: + visible = bool(self._kernel_visible.get(name, True)) + self._kernel_visible[name] = visible + label = self._compact_feature_label(self._predictor_label(name), 34) + item = QtWidgets.QListWidgetItem(label) + item.setData(QtCore.Qt.ItemDataRole.UserRole, name) + item.setFlags(item.flags() | QtCore.Qt.ItemFlag.ItemIsUserCheckable) + item.setCheckState(QtCore.Qt.CheckState.Checked if visible else QtCore.Qt.CheckState.Unchecked) + item.setForeground(QtGui.QColor(self._kernel_color(name))) + self.list_kernel_filter.addItem(item) + finally: + self._kernel_filter_guard = False + + def _on_kernel_filter_changed(self, item: QtWidgets.QListWidgetItem) -> None: + if self._kernel_filter_guard: + return + key = item.data(QtCore.Qt.ItemDataRole.UserRole) + if isinstance(key, str) and key: + self._kernel_visible[key] = item.checkState() == QtCore.Qt.CheckState.Checked + if self._glm_result is not None: + self._plot_glm_kernels(self._glm_result, refresh_filter=False) + + def _set_all_kernels_visible(self, visible: bool) -> None: + if not hasattr(self, "list_kernel_filter"): + return + self._kernel_filter_guard = True + try: + for i in range(self.list_kernel_filter.count()): + item = self.list_kernel_filter.item(i) + key = item.data(QtCore.Qt.ItemDataRole.UserRole) + if isinstance(key, str) and key: + self._kernel_visible[key] = bool(visible) + item.setCheckState(QtCore.Qt.CheckState.Checked if visible else QtCore.Qt.CheckState.Unchecked) + finally: + self._kernel_filter_guard = False + if self._glm_result is not None: + self._plot_glm_kernels(self._glm_result, refresh_filter=False) + def _selected_predictor_keys(self) -> List[str]: keys: List[str] = [] for i in range(self.list_predictors.count()): @@ -2314,8 +2413,13 @@ def _compute_glm_shift_bootstrap_significance( segment_slices = list(dataset.get("segment_slices", []) or []) rng = np.random.default_rng() total = max(1, int(n_boot) * len(rows)) - self._progress_start("GLM circular-shift test", total) + max_jobs = max(1, min(int(getattr(self, "spin_glm_jobs", None).value() if hasattr(self, "spin_glm_jobs") else 1), os.cpu_count() or 1)) + n_basis = int(self.spin_n_basis.value()) + label = f"GLM circular-shift test ({max_jobs} job{'s' if max_jobs != 1 else ''})" + self._progress_start(label, total) step = 0 + work_items: List[Tuple[int, str, np.ndarray, float, int]] = [] + row_meta: Dict[int, Tuple[Dict[str, Any], float]] = {} for row in rows: feature = str(row.get("feature", "")) obs_delta = float(row.get("delta_r2", float("nan"))) @@ -2332,38 +2436,64 @@ def _compute_glm_shift_bootstrap_significance( row["significant"] = False row["bootstrap_n"] = 0 step += int(n_boot) - self._progress_update(step, "GLM circular-shift test") + self._progress_update(step, label) continue base_vec = 
ContinuousGLM._predictor_vector(time, spec) - null_deltas: List[float] = [] - for _ in range(int(n_boot)): - shifted_predictors = dict(predictors) - shifted_predictors[feature] = { - "kind": "vector", - "values": self._shift_vector_by_segment(base_vec, segment_slices, rng), - } - try: - shifted = ContinuousGLM().fit( - time, - signal, - shifted_predictors, - kernel_window=kernel_window, - n_basis=self.spin_n_basis.value(), - basis_type=basis_type, - regularization=regularization, - alpha=alpha, - ) - null_delta = float(shifted.r2 - reduced_r2) - if np.isfinite(null_delta): - null_deltas.append(null_delta) - except Exception as exc: - _LOG.debug("Circular-shift GLM bootstrap failed for %s: %s", feature, exc) - step += 1 - if step == total or step % 5 == 0: - self._progress_update(step, "GLM circular-shift test") - - null_arr = np.asarray(null_deltas, float) + row_idx = id(row) + row_meta[row_idx] = (row, obs_delta) + for seed in rng.integers(0, np.iinfo(np.uint32).max, size=int(n_boot), dtype=np.uint32): + work_items.append((row_idx, feature, base_vec, reduced_r2, int(seed))) + + def _one_shift_fit(job: Tuple[int, str, np.ndarray, float, int]) -> Tuple[int, float]: + row_idx, feature, base_vec, reduced_r2, seed = job + local_rng = np.random.default_rng(seed) + shifted_predictors = dict(predictors) + shifted_predictors[feature] = { + "kind": "vector", + "values": self._shift_vector_by_segment(base_vec, segment_slices, local_rng), + } + shifted = ContinuousGLM().fit( + time, + signal, + shifted_predictors, + kernel_window=kernel_window, + n_basis=n_basis, + basis_type=basis_type, + regularization=regularization, + alpha=alpha, + ) + return row_idx, float(shifted.r2 - reduced_r2) + + null_by_row: Dict[int, List[float]] = {row_idx: [] for row_idx in row_meta} + if work_items: + if max_jobs == 1: + for job in work_items: + try: + row_idx, null_delta = _one_shift_fit(job) + if np.isfinite(null_delta): + null_by_row[row_idx].append(null_delta) + except Exception as exc: + _LOG.debug("Circular-shift GLM bootstrap failed: %s", exc) + step += 1 + if step == total or step % 5 == 0: + self._progress_update(step, label) + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_jobs) as executor: + futures = [executor.submit(_one_shift_fit, job) for job in work_items] + for future in concurrent.futures.as_completed(futures): + try: + row_idx, null_delta = future.result() + if np.isfinite(null_delta): + null_by_row[row_idx].append(null_delta) + except Exception as exc: + _LOG.debug("Circular-shift GLM bootstrap failed: %s", exc) + step += 1 + if step == total or step % 5 == 0: + self._progress_update(step, label) + + for row_idx, (row, obs_delta) in row_meta.items(): + null_arr = np.asarray(null_by_row.get(row_idx, []), float) if null_arr.size: p_value = float((1 + np.sum(null_arr >= obs_delta)) / (null_arr.size + 1)) row["p_value"] = p_value @@ -2657,6 +2787,7 @@ def _fit_glm_catalog(self) -> None: result.feature_importance = importance_rows used_labels = [self._predictor_label(k) for k in result.predictor_names] + n_jobs = int(self.spin_glm_jobs.value()) if hasattr(self, "spin_glm_jobs") else 1 dropped_predictors = dataset.get("dropped_predictors", []) or [] used_records = dataset.get("used_records", []) or [] dropped_records = dataset.get("dropped_records", []) or [] @@ -2670,7 +2801,7 @@ def _fit_glm_catalog(self) -> None: f"Predictors: {', '.join(used_labels)}", f"Basis: {self.combo_basis.currentText()}, n={self.spin_n_basis.value()}", f"Regularization: {self.combo_reg.currentText()}, 
alpha={self.spin_alpha.value():.3f}", - f"Circular-shift bootstraps: {n_boot if n_boot > 0 else 'off'}", + f"Circular-shift bootstraps: {n_boot if n_boot > 0 else 'off'} ({n_jobs} job{'s' if n_jobs != 1 else ''})", ] stats = result.stats or {} lines.extend([ @@ -2785,18 +2916,29 @@ def _fit_glm(self): self.tabs_workspace.setCurrentWidget(self.plot_kernel.parentWidget()) self.statusMessage.emit(f"GLM fit complete — R² = {result.r2:.4f}", 5000) - def _plot_glm_kernels(self, result: GLMResult): + def _plot_glm_kernels(self, result: GLMResult, refresh_filter: bool = True): pw = self.plot_kernel + if refresh_filter: + self._sync_kernel_filter(result) pw.clear() try: pw.getPlotItem().legend.clear() except Exception: pass - colors = ["#4b9df8", "#f5a97f", "#6bdb74", "#ee99a0", "#c6a0f6", - "#f5e0dc", "#89dceb", "#fab387"] - for i, (name, kernel) in enumerate(result.kernels.items()): - color = colors[i % len(colors)] + plotted = 0 + for name in result.predictor_names: + if not self._kernel_visible.get(name, True): + continue + kernel = result.kernels.get(name) + if kernel is None: + continue + color = self._kernel_color(name) pw.plot(result.kernel_tvec, kernel, pen=pg.mkPen(color, width=2), name=self._predictor_label(name)) + plotted += 1 + if plotted == 0: + txt = pg.TextItem("No kernels selected.", color="#c5d2e3") + pw.addItem(txt) + txt.setPos(0, 0) pw.setLabel("bottom", "Time", units="s") pw.setLabel("left", "Kernel weight") # Zero line From 01626e209274a079a7ac90b7ee53e85142de9710 Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 14:37:30 +0200 Subject: [PATCH 10/14] Add GLM illustration overlay and fix FLMM names --- pyBer/temporal_modeling.py | 222 ++++++++++++++++++++++++++++++++++++- 1 file changed, 221 insertions(+), 1 deletion(-) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index 33f7120..ff8ad0c 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -915,6 +915,7 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None): self._saved_predictor_keys: List[str] = [] self._kernel_visible: Dict[str, bool] = {} self._kernel_filter_guard = False + self._illustration_vb: Optional[pg.ViewBox] = None self._glm = ContinuousGLM() self._flmm = TrialFLMM() self._glm_result: Optional[GLMResult] = None @@ -1428,6 +1429,33 @@ def _build_workspace_pages(self): prediction_lay.addWidget(self.plot_prediction, 1) self.tabs_workspace.addTab(prediction_page, "Prediction") + illustration_page = QtWidgets.QWidget() + illustration_lay = QtWidgets.QVBoxLayout(illustration_page) + illustration_lay.setContentsMargins(10, 10, 10, 10) + controls = QtWidgets.QHBoxLayout() + controls.setSpacing(8) + lbl_feature = QtWidgets.QLabel("Overlay feature") + lbl_feature.setProperty("class", "muted") + controls.addWidget(lbl_feature) + self.combo_illustration_feature = QtWidgets.QComboBox() + self.combo_illustration_feature.setMinimumWidth(260) + controls.addWidget(self.combo_illustration_feature) + self.lbl_illustration_stats = QtWidgets.QLabel("") + self.lbl_illustration_stats.setProperty("class", "muted") + controls.addWidget(self.lbl_illustration_stats, 1) + illustration_lay.addLayout(controls) + self.plot_illustration = pg.PlotWidget(title="Signal and selected feature contribution") + self._style_plot(self.plot_illustration) + pi = self.plot_illustration.getPlotItem() + pi.showAxis("right") + self._illustration_vb = pg.ViewBox() + pi.scene().addItem(self._illustration_vb) + pi.getAxis("right").linkToView(self._illustration_vb) + 
self._illustration_vb.setXLink(pi) + pi.vb.sigResized.connect(self._update_illustration_view) + illustration_lay.addWidget(self.plot_illustration, 1) + self.tabs_workspace.addTab(illustration_page, "Illustration") + residual_page = QtWidgets.QWidget() residual_lay = QtWidgets.QVBoxLayout(residual_page) residual_lay.setContentsMargins(10, 10, 10, 10) @@ -1545,6 +1573,8 @@ def _connect_signals(self): self.list_kernel_filter.itemChanged.connect(self._on_kernel_filter_changed) self.btn_kernel_all.clicked.connect(lambda: self._set_all_kernels_visible(True)) self.btn_kernel_none.clicked.connect(lambda: self._set_all_kernels_visible(False)) + if hasattr(self, "combo_illustration_feature"): + self.combo_illustration_feature.currentIndexChanged.connect(self._on_illustration_feature_changed) def _load_temporal_settings(self) -> None: self._loading_settings = True @@ -1759,6 +1789,98 @@ def _set_all_kernels_visible(self, visible: bool) -> None: if self._glm_result is not None: self._plot_glm_kernels(self._glm_result, refresh_filter=False) + def _update_illustration_view(self) -> None: + if not hasattr(self, "plot_illustration") or self._illustration_vb is None: + return + plot_item = self.plot_illustration.getPlotItem() + self._illustration_vb.setGeometry(plot_item.vb.sceneBoundingRect()) + self._illustration_vb.linkedViewChanged(plot_item.vb, self._illustration_vb.XAxis) + + def _glm_feature_order(self, result: GLMResult) -> List[str]: + ordered: List[str] = [] + for row in result.feature_importance or []: + key = str(row.get("feature", "") or "") + if key in result.predictor_names and key not in ordered: + ordered.append(key) + for key in result.predictor_names: + if key not in ordered: + ordered.append(key) + return ordered + + def _sync_illustration_features(self, result: GLMResult) -> None: + if not hasattr(self, "combo_illustration_feature"): + return + current = self.combo_illustration_feature.currentData(QtCore.Qt.ItemDataRole.UserRole) + self.combo_illustration_feature.blockSignals(True) + try: + self.combo_illustration_feature.clear() + for key in self._glm_feature_order(result): + self.combo_illustration_feature.addItem(self._predictor_label(key), key) + if isinstance(current, str) and current: + idx = self.combo_illustration_feature.findData(current, QtCore.Qt.ItemDataRole.UserRole) + if idx >= 0: + self.combo_illustration_feature.setCurrentIndex(idx) + finally: + self.combo_illustration_feature.blockSignals(False) + + @staticmethod + def _pearson_stats(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, int]: + xv = np.asarray(x, float) + yv = np.asarray(y, float) + m = np.isfinite(xv) & np.isfinite(yv) + xv = xv[m] + yv = yv[m] + if xv.size < 3 or np.nanstd(xv) <= 1e-12 or np.nanstd(yv) <= 1e-12: + return float("nan"), float("nan"), int(xv.size) + try: + from scipy import stats as _stats + res = _stats.pearsonr(xv, yv) + return float(res.statistic), float(res.pvalue), int(xv.size) + except Exception: + return float(np.corrcoef(xv, yv)[0, 1]), float("nan"), int(xv.size) + + @staticmethod + def _p_label(p_value: float) -> str: + if not np.isfinite(p_value): + return "p = n/a" + if p_value < 1e-4: + return "p < 1e-4" + return f"p = {p_value:.3g}" + + @staticmethod + def _p_stars(p_value: float) -> str: + if not np.isfinite(p_value): + return "" + if p_value < 0.001: + return "***" + if p_value < 0.01: + return "**" + if p_value < 0.05: + return "*" + return "n.s." 
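
[Editor's note] The add-one permutation rule used by `_p_label`/`_p_stars` here, and by the circular-shift test earlier in the series, is worth spelling out: with B finite null statistics, p = (1 + #{null >= observed}) / (B + 1), which is bounded below by 1/(B + 1) and so can never be exactly zero; in particular, with fewer than 1000 shifts `_p_stars` can never reach "***". A minimal self-contained sketch of the convention (the helper name `permutation_p` and the synthetic null are illustrative, not part of this patch):

    import numpy as np

    def permutation_p(observed: float, null: np.ndarray) -> float:
        """Add-one permutation p-value; bounded below by 1/(B + 1)."""
        null = np.asarray(null, float)
        null = null[np.isfinite(null)]
        if null.size == 0:
            return float("nan")
        return float((1 + np.sum(null >= observed)) / (null.size + 1))

    rng = np.random.default_rng(0)
    null_deltas = rng.normal(0.0, 0.01, size=999)  # stand-in for shifted-fit delta-R^2 values
    print(permutation_p(0.03, null_deltas))        # typically ~0.002, i.e. "**" under _p_stars
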
+ + def _glm_feature_contribution(self, result: GLMResult, key: str) -> Optional[np.ndarray]: + if key not in result.predictor_names: + return None + n_pred = len(result.predictor_names) + if n_pred <= 0 or result.design_matrix is None or result.coefficients is None: + return None + n_basis = (int(np.asarray(result.coefficients).size) - 1) // n_pred + if n_basis <= 0: + return None + pred_idx = result.predictor_names.index(key) + lo = 1 + pred_idx * n_basis + hi = lo + n_basis + X = np.asarray(result.design_matrix, float) + beta = np.asarray(result.coefficients, float) + if X.ndim != 2 or hi > X.shape[1] or hi > beta.size: + return None + return X[:, lo:hi] @ beta[lo:hi] + + def _on_illustration_feature_changed(self, *_args) -> None: + if self._glm_result is not None: + self._plot_glm_illustration(self._glm_result) + def _selected_predictor_keys(self) -> List[str]: keys: List[str] = [] for i in range(self.list_predictors.count()): @@ -2176,6 +2298,17 @@ def _safe_design_name(label: str, used: set[str]) -> str: used.add(name) return name + @staticmethod + def _flmm_term_name(index: int, used: set[str]) -> str: + base = f"pyber_x{max(1, int(index)):06d}z" + name = base + i = 2 + while name in used: + name = f"{base}_{i}" + i += 1 + used.add(name) + return name + def _flmm_matrix_and_labels(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[str], str]: self._flmm_row_meta = [] if self._per_file_mats: @@ -2350,7 +2483,10 @@ def _build_flmm_design( continue values[~finite] = mean values = (values - mean) / std - term = self._safe_design_name(self._predictor_label(key), used_names) + # fastFMM detects functional covariates by shared name prefixes. + # Human labels such as "Distance" and "Distance to point" collide, + # so use neutral non-prefixing IDs and keep labels separately. 
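+ # (Colliding prefixes would make fui() treat the two columns as one
+ # functional covariate; the fixed-width, zero-padded IDs generated
+ # below cannot be prefixes of one another.)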
+ term = self._flmm_term_name(len(terms) + 1, used_names) design[term] = values terms.append(term) term_labels[term] = self._predictor_label(key) @@ -2838,6 +2974,7 @@ def _fit_glm_catalog(self) -> None: self._plot_glm_kernels(result) self._plot_glm_fit(result) + self._plot_glm_illustration(result) self._plot_feature_importance( importance_rows, value_key="delta_r2", @@ -2912,6 +3049,7 @@ def _fit_glm(self): # Plot kernels self._plot_glm_kernels(result) self._plot_glm_fit(result) + self._plot_glm_illustration(result) if hasattr(self, "tabs_workspace"): self.tabs_workspace.setCurrentWidget(self.plot_kernel.parentWidget()) self.statusMessage.emit(f"GLM fit complete — R² = {result.r2:.4f}", 5000) @@ -2965,6 +3103,88 @@ def _plot_glm_fit(self, result: GLMResult): rw.setLabel("bottom", "Time", units="s") rw.setLabel("left", "Residual") + def _plot_glm_illustration(self, result: GLMResult) -> None: + if not hasattr(self, "plot_illustration") or self._illustration_vb is None: + return + self._sync_illustration_features(result) + key = self.combo_illustration_feature.currentData(QtCore.Qt.ItemDataRole.UserRole) + if not isinstance(key, str) or not key: + ordered = self._glm_feature_order(result) + key = ordered[0] if ordered else "" + + pw = self.plot_illustration + vb = self._illustration_vb + pw.clear() + vb.clear() + try: + pw.getPlotItem().legend.clear() + except Exception: + pass + pw.setTitle("Signal and selected feature contribution") + if not key: + self.lbl_illustration_stats.setText("No fitted GLM feature is available.") + return + + x = np.asarray(result.time, float) + signal = np.asarray(result.y_actual, float) + contribution = self._glm_feature_contribution(result, key) + if contribution is None: + self.lbl_illustration_stats.setText("No contribution trace is available for the selected feature.") + return + contribution = np.asarray(contribution, float) + n = min(x.size, signal.size, contribution.size) + x = x[:n] + signal = signal[:n] + contribution = contribution[:n] + valid = np.isfinite(x) & np.isfinite(signal) & np.isfinite(contribution) + r_value, p_value, n_corr = self._pearson_stats(signal[valid], contribution[valid]) + p_text = self._p_label(p_value) + stars = self._p_stars(p_value) + stats_text = f"Pearson r = {r_value:.3f}, {p_text}, n = {n_corr}" + if stars: + stats_text += f" ({stars})" + self.lbl_illustration_stats.setText(stats_text) + + signal_color = "#4b9df8" + feature_color = self._kernel_color(key) + pw.plot(x, signal, pen=pg.mkPen(signal_color, width=1.25), name="signal") + feat_curve = pg.PlotDataItem(x, contribution, pen=pg.mkPen(feature_color, width=1.8), name=self._predictor_label(key)) + vb.addItem(feat_curve) + + plot_item = pw.getPlotItem() + plot_item.getAxis("right").setPen(pg.mkPen(feature_color)) + plot_item.getAxis("right").setTextPen(pg.mkPen(feature_color)) + plot_item.getAxis("left").setPen(pg.mkPen(signal_color)) + plot_item.getAxis("left").setTextPen(pg.mkPen(signal_color)) + plot_item.setLabel("left", "Signal") + plot_item.setLabel("right", "Feature contribution") + plot_item.setLabel("bottom", "Time", units="s") + pw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + + finite_signal = signal[np.isfinite(signal)] + finite_feature = contribution[np.isfinite(contribution)] + if finite_signal.size: + y0 = float(np.nanmin(finite_signal)) + y1 = float(np.nanmax(finite_signal)) + pad = max((y1 - y0) * 0.08, 1e-9) + pw.setYRange(y0 - pad, y1 + pad, padding=0.0) + if finite_feature.size: + f0 = 
float(np.nanmin(finite_feature)) + f1 = float(np.nanmax(finite_feature)) + pad = max((f1 - f0) * 0.08, 1e-9) + vb.setYRange(f0 - pad, f1 + pad, padding=0.0) + self._update_illustration_view() + + finite_x = x[np.isfinite(x)] + if finite_x.size and finite_signal.size: + xr0 = float(np.nanmin(finite_x)) + xr1 = float(np.nanmax(finite_x)) + yr0 = float(np.nanmin(finite_signal)) + yr1 = float(np.nanmax(finite_signal)) + txt = pg.TextItem(stats_text, color="#e9f0fb", anchor=(0.0, 0.0)) + pw.addItem(txt) + txt.setPos(xr0 + 0.02 * max(xr1 - xr0, 1e-9), yr1 - 0.08 * max(yr1 - yr0, 1e-9)) + def _plot_feature_importance( self, rows: List[Dict[str, Any]], From 290afc58283ffbfc4e545e13d2be74c7dc4929de Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 15:10:23 +0200 Subject: [PATCH 11/14] Stabilize FLMM fitting for singular designs --- pyBer/temporal_modeling.py | 78 +++++++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index ff8ad0c..e91f94a 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -604,6 +604,7 @@ def fit( kwargs = { "formula": ro.Formula(formula_text), "data": r_df, + "var": ro.BoolVector([True]), "parallel": ro.BoolVector([parallel]), "silent": ro.BoolVector([True]), "subj_id": ro.StrVector([group_var_model]), @@ -624,7 +625,27 @@ def fit( _LOG.info("Calling fastFMM::fui() with formula=%s, %d trials, %d timepoints", formula_text, n_trials, n_time) - fui_result = fastFMM.fui(**kwargs) + fit_notes: List[str] = [] + try: + fui_result = fastFMM.fui(**kwargs) + except Exception as exc: + msg = str(exc) + recoverable = any( + token in msg.lower() + for token in ("downdated vtv", "not positive definite", "singular", "rank deficient") + ) + if not recoverable: + raise + retry_kwargs = dict(kwargs) + retry_kwargs["var"] = ro.BoolVector([False]) + retry_kwargs["analytic"] = ro.BoolVector([True]) + retry_kwargs["n_boots"] = ro.IntVector([0]) + fit_notes.append( + "fastFMM variance inference failed because the mixed-model design was near-singular; " + "refit with variance/CIs disabled." + ) + _LOG.warning("Retrying fastFMM without variance inference after: %s", msg) + fui_result = fastFMM.fui(**retry_kwargs) # Parse the result. 
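+ # (The "downdated VtV" / "not positive definite" messages handled above
+ # come from the underlying mixed-model solver when the random-effects
+ # design is near-singular; after a var=FALSE retry, no variance or CI
+ # entries are expected in this result.)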
# fui() returns a list with elements: @@ -719,6 +740,8 @@ def fit( summary_parts = [f"FLMM fit: {len(term_names)} terms, {n_trials} trials, {n_time} timepoints"] if fallback_grouping: summary_parts.append(f"Note: {fallback_grouping}") + for note in fit_notes: + summary_parts.append(f"Note: {note}") if aic_val is not None: summary_parts.append(f"AIC = {aic_val:.1f}") coeff_abs_peaks: List[float] = [] @@ -743,6 +766,7 @@ def fit( "group_var": group_var_model, "requested_group_var": group_var, "fallback_grouping": fallback_grouping, + "fit_notes": list(fit_notes), } except Exception as exc: @@ -2492,6 +2516,41 @@ def _build_flmm_design( term_labels[term] = self._predictor_label(key) return design, terms, dropped, term_labels + def _prune_flmm_terms( + self, + design: Dict[str, np.ndarray], + terms: List[str], + term_labels: Dict[str, str], + n_rows: int, + ) -> Tuple[List[str], List[str]]: + if not terms: + return [], [] + max_terms = max(0, min(len(terms), int(n_rows) - 2)) + if max_terms <= 0: + return [], [f"{term_labels.get(term, term)} (not enough rows)" for term in terms] + + X = np.ones((int(n_rows), 1), float) + rank = int(np.linalg.matrix_rank(X)) + kept: List[str] = [] + dropped: List[str] = [] + for term in terms: + if len(kept) >= max_terms: + dropped.append(f"{term_labels.get(term, term)} (too many predictors for {n_rows} rows)") + continue + col = np.asarray(design.get(term, np.array([], float)), float).reshape(-1) + if col.size != n_rows or not np.all(np.isfinite(col)): + dropped.append(f"{term_labels.get(term, term)} (invalid values)") + continue + candidate = np.column_stack([X, col]) + new_rank = int(np.linalg.matrix_rank(candidate)) + if new_rank > rank: + kept.append(term) + X = candidate + rank = new_rank + else: + dropped.append(f"{term_labels.get(term, term)} (collinear with existing predictors)") + return kept, dropped + @staticmethod def _intercept_only_fit_stats(signal: np.ndarray) -> Dict[str, float]: signal = np.asarray(signal, float) @@ -3275,9 +3334,11 @@ def _fit_flmm_group(self) -> None: row_labels = [f"{scope}_{i + 1}" for i in range(n_rows)] design, auto_terms, dropped_terms, term_labels = self._build_flmm_design(row_labels, group_var) + fit_terms, pruned_terms = self._prune_flmm_terms(design, auto_terms, term_labels, n_rows) + if pruned_terms: + dropped_terms.extend(pruned_terms) requested_formula = self.edit_formula.text().strip() - auto_formula = "Y.obs ~ " + " + ".join(auto_terms) if auto_terms else "Y.obs ~ 1" requested_terms = [ term for term in self._simple_formula_terms(requested_formula) if "|" not in term and "(" not in term and ")" not in term @@ -3289,7 +3350,14 @@ def _fit_flmm_group(self) -> None: or not requested_formula.lstrip().startswith("Y.obs") or bool(missing_terms) ) - formula = auto_formula if use_auto else requested_formula + if use_auto: + formula_terms = list(fit_terms) + else: + formula_terms = [term for term in requested_terms if term in fit_terms] + removed_manual = [term for term in requested_terms if term in design and term not in fit_terms] + if removed_manual: + missing_terms.extend(removed_manual) + formula = "Y.obs ~ " + " + ".join(formula_terms) if formula_terms else "Y.obs ~ 1" if requested_formula != formula: self.edit_formula.setText(formula) random_eff = self.edit_random.text().strip() or "~1" @@ -3372,8 +3440,8 @@ def _fit_flmm_group(self) -> None: summary_lines.append(f" {len(failed)} reduced FLMM fits failed; see log for details.") if num_boots > 0: summary_lines.append(" Leave-one-out comparison uses analytic fits; 
bootstrap CIs are not repeated.") - if auto_terms: - readable = [f"{term} = {term_labels.get(term, term)}" for term in auto_terms] + if fit_terms: + readable = [f"{term} = {term_labels.get(term, term)}" for term in fit_terms] summary_lines.append("Covariates: " + "; ".join(readable)) if dropped_terms: summary_lines.append("Dropped covariates: " + ", ".join(dropped_terms)) From 4c5f09be7b385ab1af9a194616463683ed761d4c Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 15:53:18 +0200 Subject: [PATCH 12/14] Make FLMM contribution analysis optional --- pyBer/temporal_modeling.py | 170 ++++++++++++++++++++++++++++--------- 1 file changed, 128 insertions(+), 42 deletions(-) diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index e91f94a..e71760a 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -600,6 +600,7 @@ def fit( formula_text = f"{formula_text} + ({rand_rhs} | {group_var_model})" # Call fui() + base = importr("base") fastFMM = importr("fastFMM") kwargs = { "formula": ro.Formula(formula_text), @@ -626,26 +627,34 @@ def fit( formula_text, n_trials, n_time) fit_notes: List[str] = [] + old_warn = R("getOption('warn')") + base.options(warn=ro.IntVector([-1])) try: - fui_result = fastFMM.fui(**kwargs) - except Exception as exc: - msg = str(exc) - recoverable = any( - token in msg.lower() - for token in ("downdated vtv", "not positive definite", "singular", "rank deficient") - ) - if not recoverable: - raise - retry_kwargs = dict(kwargs) - retry_kwargs["var"] = ro.BoolVector([False]) - retry_kwargs["analytic"] = ro.BoolVector([True]) - retry_kwargs["n_boots"] = ro.IntVector([0]) - fit_notes.append( - "fastFMM variance inference failed because the mixed-model design was near-singular; " - "refit with variance/CIs disabled." - ) - _LOG.warning("Retrying fastFMM without variance inference after: %s", msg) - fui_result = fastFMM.fui(**retry_kwargs) + try: + fui_result = fastFMM.fui(**kwargs) + except Exception as exc: + msg = str(exc) + recoverable = any( + token in msg.lower() + for token in ("downdated vtv", "not positive definite", "singular", "rank deficient") + ) + if not recoverable: + raise + retry_kwargs = dict(kwargs) + retry_kwargs["var"] = ro.BoolVector([False]) + retry_kwargs["analytic"] = ro.BoolVector([True]) + retry_kwargs["n_boots"] = ro.IntVector([0]) + fit_notes.append( + "fastFMM variance inference failed because the mixed-model design was near-singular; " + "refit with variance/CIs disabled." + ) + _LOG.warning("Retrying fastFMM without variance inference after: %s", msg) + fui_result = fastFMM.fui(**retry_kwargs) + finally: + try: + base.options(warn=old_warn) + except Exception: + base.options(warn=ro.IntVector([0])) # Parse the result. # fui() returns a list with elements: @@ -1078,6 +1087,14 @@ def _build_ui(self): self.spin_boots.setValue(0) self.spin_boots.setSpecialValueText("analytic") fl.addRow("Bootstrap iter:", self.spin_boots) + self.combo_flmm_importance = QtWidgets.QComboBox() + self.combo_flmm_importance.addItem("Fast coefficient ranking", "fast") + self.combo_flmm_importance.addItem("Leave-one-out AIC (slow)", "loo") + self.combo_flmm_importance.addItem("Off", "off") + self.combo_flmm_importance.setToolTip( + "Leave-one-out refits fastFMM once per predictor and can be very slow." 
+ ) + fl.addRow("Contribution:", self.combo_flmm_importance) root.addWidget(self.grp_flmm) @@ -1349,6 +1366,14 @@ def _build_model_page(self): self.spin_boots.setValue(0) self.spin_boots.setSpecialValueText("analytic") fl.addRow("Bootstrap iter", self.spin_boots) + self.combo_flmm_importance = QtWidgets.QComboBox() + self.combo_flmm_importance.addItem("Fast coefficient ranking", "fast") + self.combo_flmm_importance.addItem("Leave-one-out AIC (slow)", "loo") + self.combo_flmm_importance.addItem("Off", "off") + self.combo_flmm_importance.setToolTip( + "Leave-one-out refits fastFMM once per predictor and can be very slow." + ) + fl.addRow("Contribution", self.combo_flmm_importance) lay.addWidget(self.grp_flmm) lay.addStretch(1) self.stack_controls.addWidget(page) @@ -1579,6 +1604,7 @@ def _connect_signals(self): for widget in ( self.combo_basis, self.combo_reg, + self.combo_flmm_importance, self.spin_n_basis, self.spin_alpha, self.spin_kernel_pre, @@ -1621,6 +1647,12 @@ def _load_temporal_settings(self) -> None: self.edit_group_var.setText(str(self._settings.value(prefix + "flmm_group_var", self.edit_group_var.text()) or "subject")) self.spin_nknots.setValue(int(self._settings.value(prefix + "flmm_nknots", self.spin_nknots.value()))) self.spin_boots.setValue(int(self._settings.value(prefix + "flmm_boots", self.spin_boots.value()))) + mode = str(self._settings.value(prefix + "flmm_importance_mode", "fast") or "fast") + idx = self.combo_flmm_importance.findData(mode, QtCore.Qt.ItemDataRole.UserRole) + if idx < 0: + idx = self.combo_flmm_importance.findText(mode) + if idx >= 0: + self.combo_flmm_importance.setCurrentIndex(idx) raw_predictors = str(self._settings.value(prefix + "predictor_keys", "") or "") self._saved_predictor_keys = [key for key in raw_predictors.split("\n") if key.strip()] finally: @@ -1644,6 +1676,7 @@ def _save_temporal_settings(self) -> None: self._settings.setValue(prefix + "flmm_group_var", self.edit_group_var.text().strip()) self._settings.setValue(prefix + "flmm_nknots", self.spin_nknots.value()) self._settings.setValue(prefix + "flmm_boots", self.spin_boots.value()) + self._settings.setValue(prefix + "flmm_importance_mode", self.combo_flmm_importance.currentData(QtCore.Qt.ItemDataRole.UserRole) or "fast") predictors = self._selected_predictor_keys() if hasattr(self, "list_predictors") else self._saved_predictor_keys if predictors or getattr(self, "_predictor_catalog", None): self._saved_predictor_keys = list(predictors) @@ -2790,19 +2823,53 @@ def _simple_formula_terms(formula: str) -> List[str]: @staticmethod def _term_mean_abs_coefficient(result: FLMMResult, term: str) -> float: - if not result.coefficients: + coeff = TemporalModelingWidget._term_coefficient_curve(result, term) + if coeff is None: return float("nan") + vals = np.asarray(coeff, float) + return float(np.nanmean(np.abs(vals))) if vals.size else float("nan") + + @staticmethod + def _term_coefficient_curve(result: FLMMResult, term: str) -> Optional[np.ndarray]: + if not result.coefficients: + return None clean = re.sub(r"[^0-9A-Za-z_]+", "", str(term).lower()) for name, coeff in result.coefficients.items(): if str(name) == str(term): - vals = np.asarray(coeff, float) - return float(np.nanmean(np.abs(vals))) if vals.size else float("nan") + return np.asarray(coeff, float) for name, coeff in result.coefficients.items(): name_clean = re.sub(r"[^0-9A-Za-z_]+", "", str(name).lower()) if clean and (clean in name_clean or name_clean in clean): - vals = np.asarray(coeff, float) - return 
float(np.nanmean(np.abs(vals))) if vals.size else float("nan") - return float("nan") + return np.asarray(coeff, float) + return None + + def _compute_flmm_coefficient_importance( + self, + result: FLMMResult, + terms: List[str], + term_labels: Dict[str, str], + ) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + for term in terms: + coeff = self._term_coefficient_curve(result, term) + if coeff is None: + continue + vals = np.asarray(coeff, float) + if vals.size == 0: + continue + rows.append({ + "feature": term, + "label": term_labels.get(term, term), + "mean_abs_coefficient": float(np.nanmean(np.abs(vals))), + "peak_abs_coefficient": float(np.nanmax(np.abs(vals))), + "delta_aic": float("nan"), + "status": "ok", + }) + rows.sort(key=lambda item: ( + np.isfinite(item.get("mean_abs_coefficient", np.nan)), + float(item.get("mean_abs_coefficient", -np.inf)) if np.isfinite(item.get("mean_abs_coefficient", np.nan)) else -np.inf, + ), reverse=True) + return rows def _compute_flmm_leave_one_out( self, @@ -3265,7 +3332,7 @@ def _plot_feature_importance( if np.isfinite(float(row.get(value_key, float("nan")))) ] if not usable: - txt = pg.TextItem("No leave-one-out feature contribution is available.", color="#c5d2e3") + txt = pg.TextItem("No feature contribution is available.", color="#c5d2e3") pw.addItem(txt) txt.setPos(0, 0) pw.setLabel("bottom", y_label) @@ -3378,23 +3445,29 @@ def _fit_flmm_group(self) -> None: ) self._flmm_result = result - self.statusMessage.emit("Calculating FLMM leave-one-feature-out AIC contribution...", 0) - QtWidgets.QApplication.processEvents() - importance_rows = self._compute_flmm_leave_one_out( - mat, - tvec, - design, - formula, - random_eff, - group_var, - nknots, - result, - term_labels, - ) + importance_mode = str(self.combo_flmm_importance.currentData(QtCore.Qt.ItemDataRole.UserRole) or "fast") + if importance_mode == "loo": + self.statusMessage.emit("Calculating FLMM leave-one-feature-out AIC contribution...", 0) + QtWidgets.QApplication.processEvents() + importance_rows = self._compute_flmm_leave_one_out( + mat, + tvec, + design, + formula, + random_eff, + group_var, + nknots, + result, + term_labels, + ) + elif importance_mode == "fast": + importance_rows = self._compute_flmm_coefficient_importance(result, formula_terms, term_labels) + else: + importance_rows = [] result.feature_importance = importance_rows importance_value_key = ( "delta_aic" - if any(np.isfinite(float(row.get("delta_aic", float("nan")))) for row in importance_rows) + if importance_mode == "loo" and any(np.isfinite(float(row.get("delta_aic", float("nan")))) for row in importance_rows) else "mean_abs_coefficient" ) @@ -3421,7 +3494,7 @@ def _fit_flmm_group(self) -> None: f" mean abs coefficient = {result.stats.get('mean_abs_coefficient', float('nan')):.5g}", f" peak abs coefficient = {result.stats.get('peak_abs_coefficient', float('nan')):.5g}", ]) - if importance_rows: + if importance_rows and importance_mode == "loo": summary_lines.extend([ "", "Leave-one-feature-out contribution (reduced AIC - full AIC):", @@ -3440,6 +3513,19 @@ def _fit_flmm_group(self) -> None: summary_lines.append(f" {len(failed)} reduced FLMM fits failed; see log for details.") if num_boots > 0: summary_lines.append(" Leave-one-out comparison uses analytic fits; bootstrap CIs are not repeated.") + elif importance_rows and importance_mode == "fast": + summary_lines.extend([ + "", + "Fast coefficient contribution (single FLMM fit):", + ]) + for row in importance_rows[:10]: + summary_lines.append( + f" 
{row['label']}: mean abs coef = {row['mean_abs_coefficient']:.5g}, " + f"peak abs coef = {row['peak_abs_coefficient']:.5g}" + ) + summary_lines.append(" Leave-one-out AIC is disabled by default because it refits fastFMM once per predictor.") + elif importance_mode == "off": + summary_lines.extend(["", "Feature contribution: off."]) if fit_terms: readable = [f"{term} = {term_labels.get(term, term)}" for term in fit_terms] summary_lines.append("Covariates: " + "; ".join(readable)) @@ -3450,7 +3536,7 @@ def _fit_flmm_group(self) -> None: self._plot_feature_importance( importance_rows, value_key=importance_value_key, - title="FLMM leave-one-feature-out contribution" if importance_value_key == "delta_aic" else "FLMM coefficient contribution", + title="FLMM leave-one-feature-out contribution" if importance_mode == "loo" else "FLMM coefficient contribution", y_label="Delta AIC" if importance_value_key == "delta_aic" else "Mean abs coefficient", ) if hasattr(self, "tabs_workspace") and importance_rows: From b4c5d9315de6758566ffec9ccf766b2e3489be2b Mon Sep 17 00:00:00 2001 From: andrianj Date: Fri, 8 May 2026 18:08:32 +0200 Subject: [PATCH 13/14] Improve temporal modeling group workflow --- pyBer/gui_postprocessing.py | 18 + pyBer/temporal_modeling.py | 954 +++++++++++++++++++++++++++++++++--- 2 files changed, 905 insertions(+), 67 deletions(-) diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py index 7cb5627..91995e7 100644 --- a/pyBer/gui_postprocessing.py +++ b/pyBer/gui_postprocessing.py @@ -7499,6 +7499,13 @@ def _restore_cached_analysis_outputs(self, payload: Dict[str, object]) -> None: self.last_behavior_analysis = behavior_analysis self._render_behavior_analysis_outputs() + temporal_state = payload.get("temporal_modeling") + if isinstance(temporal_state, dict) and temporal_state and hasattr(self, "section_temporal"): + try: + self.section_temporal.deserialize_state(temporal_state) + except Exception: + _LOG.debug("Could not restore temporal modeling state", exc_info=True) + def _save_project_h5(self, path: str) -> None: with h5py.File(path, "w") as f: f.attrs["project_type"] = "pyber_postprocessing_project" @@ -7610,6 +7617,13 @@ def _save_project_h5(self, path: str) -> None: analysis_group = f.create_group("analysis") self._save_signal_events_h5(analysis_group) self._save_behavior_analysis_h5(analysis_group) + try: + if hasattr(self, "section_temporal") and self.section_temporal is not None: + state = self.section_temporal.serialize_state() + if state: + self._write_h5_json_any(analysis_group, "temporal_modeling_json", state) + except Exception: + _LOG.debug("Could not save temporal modeling state to project", exc_info=True) def _load_project_h5(self, path: str) -> Dict[str, object]: payload: Dict[str, object] = { @@ -7768,6 +7782,10 @@ def _aligned(values: Optional[np.ndarray], fill_nan: bool = True) -> np.ndarray: if isinstance(analysis_group, h5py.Group): payload["signal_events"] = self._load_signal_events_h5(analysis_group) payload["behavior_analysis"] = self._load_behavior_analysis_h5(analysis_group) + try: + payload["temporal_modeling"] = self._read_h5_json_any(analysis_group, "temporal_modeling_json", {}) + except Exception: + payload["temporal_modeling"] = {} payload["processed"] = loaded_processed payload["behavior_sources"] = loaded_behavior diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index e71760a..30df44d 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -954,6 +954,16 @@ def __init__(self, parent: 
Optional[QtWidgets.QWidget] = None): self._glm_result: Optional[GLMResult] = None self._flmm_result: Optional[FLMMResult] = None + # Per-file fits (filled by batch / group runs). + self._glm_results_by_file: Dict[str, GLMResult] = {} + self._flmm_results_by_file: Dict[str, FLMMResult] = {} + self._fit_summary_by_file: Dict[str, str] = {} + self._group_glm_summary: Dict[str, Any] = {} + # Batch state + self._batch_cancel_requested: bool = False + self._fit_mode: str = "all" # one of "active" | "all" | "batch" + self._active_file_id: str = "" + # Data references (set by host panel) self._processed_trials = [] self._psth_mat: Optional[np.ndarray] = None @@ -1209,21 +1219,74 @@ def _build_compact_ui(self): self.combo_model_type = QtWidgets.QComboBox() self.combo_model_type.addItems(["Continuous GLM", "Trial-level FLMM (fastFMM)"]) - self.combo_model_type.setMinimumWidth(230) + self.combo_model_type.setMinimumWidth(220) h.addWidget(self.combo_model_type) - self.btn_fit = QtWidgets.QPushButton("Fit model") + self.btn_fit = QtWidgets.QPushButton("Fit") self.btn_fit.setProperty("class", "primary") - self.btn_fit.setMinimumWidth(120) + self.btn_fit.setMinimumWidth(86) + self.btn_fit.setToolTip("Fit the model using the current scope (Active file / All / Per-file batch).") h.addWidget(self.btn_fit) + self.btn_cancel = QtWidgets.QPushButton("Cancel") + self.btn_cancel.setMinimumWidth(78) + self.btn_cancel.setEnabled(False) + self.btn_cancel.setToolTip("Stop the current batch fit (single fits run synchronously).") + h.addWidget(self.btn_cancel) + self.progress_model = QtWidgets.QProgressBar() - self.progress_model.setMinimumWidth(260) - self.progress_model.setMaximumWidth(360) + self.progress_model.setMinimumWidth(220) + self.progress_model.setMaximumWidth(320) self.progress_model.setVisible(False) h.addWidget(self.progress_model) root.addWidget(header) + # ---- Scope strip: Active / All / Per-file (batch) + file selector ---- + scope_bar = QtWidgets.QFrame() + scope_bar.setObjectName("temporalScopeBar") + sb = QtWidgets.QHBoxLayout(scope_bar) + sb.setContentsMargins(14, 4, 14, 6) + sb.setSpacing(10) + + scope_lbl = QtWidgets.QLabel("Scope") + scope_lbl.setProperty("class", "muted") + sb.addWidget(scope_lbl) + + self.combo_fit_scope = QtWidgets.QComboBox() + self.combo_fit_scope.addItem("Active file (single)", "active") + self.combo_fit_scope.addItem("All loaded (concatenated)", "all") + self.combo_fit_scope.addItem("Per-file batch + group", "batch") + self.combo_fit_scope.setMinimumWidth(220) + self.combo_fit_scope.setToolTip( + "Active: fit the selected animal only.\n" + "All: concatenate every loaded recording into one GLM.\n" + "Per-file batch: fit each animal independently, then aggregate for the Group tab." 
+ ) + sb.addWidget(self.combo_fit_scope) + + file_lbl = QtWidgets.QLabel("File") + file_lbl.setProperty("class", "muted") + sb.addWidget(file_lbl) + + self.combo_active_file = QtWidgets.QComboBox() + self.combo_active_file.setMinimumWidth(280) + self.combo_active_file.setToolTip("Active animal/recording (used when Scope = Active file).") + sb.addWidget(self.combo_active_file, 1) + + self.btn_prev_file = QtWidgets.QToolButton() + self.btn_prev_file.setText("◀") + self.btn_prev_file.setToolTip("Previous file") + self.btn_next_file = QtWidgets.QToolButton() + self.btn_next_file.setText("▶") + self.btn_next_file.setToolTip("Next file") + sb.addWidget(self.btn_prev_file) + sb.addWidget(self.btn_next_file) + + self.lbl_fit_state = QtWidgets.QLabel("No fit yet") + self.lbl_fit_state.setProperty("class", "muted") + sb.addWidget(self.lbl_fit_state) + root.addWidget(scope_bar) + split = QtWidgets.QSplitter(QtCore.Qt.Orientation.Horizontal) split.setChildrenCollapsible(False) split.setHandleWidth(6) @@ -1242,8 +1305,9 @@ def _build_compact_ui(self): nav_lay.setSpacing(8) self.btn_nav_model = self._make_nav_button("Model") self.btn_nav_predictors = self._make_nav_button("Predictors") + self.btn_nav_files = self._make_nav_button("Files\n& Group") self.btn_nav_fit = self._make_nav_button("Fit") - for btn in (self.btn_nav_model, self.btn_nav_predictors, self.btn_nav_fit): + for btn in (self.btn_nav_model, self.btn_nav_predictors, self.btn_nav_files, self.btn_nav_fit): nav_lay.addWidget(btn) nav_lay.addStretch(1) left_lay.addWidget(nav) @@ -1267,14 +1331,22 @@ def _build_compact_ui(self): self._build_model_page() self._build_predictor_page() + self._build_files_page() self._build_fit_page() self._build_workspace_pages() self.btn_nav_model.setChecked(True) self.btn_nav_model.clicked.connect(lambda: self._select_control_page(0)) self.btn_nav_predictors.clicked.connect(lambda: self._select_control_page(1)) - self.btn_nav_fit.clicked.connect(lambda: self._select_control_page(2)) + self.btn_nav_files.clicked.connect(lambda: self._select_control_page(2)) + self.btn_nav_fit.clicked.connect(lambda: self._select_control_page(3)) self.btn_fit_side.clicked.connect(self._on_fit_clicked) + self.btn_fit_all_files.clicked.connect(self._on_fit_all_files_clicked) + self.btn_cancel.clicked.connect(self._on_cancel_clicked) + self.combo_fit_scope.currentIndexChanged.connect(self._on_fit_scope_changed) + self.combo_active_file.currentIndexChanged.connect(self._on_active_file_changed) + self.btn_prev_file.clicked.connect(lambda: self._step_active_file(-1)) + self.btn_next_file.clicked.connect(lambda: self._step_active_file(1)) self._on_model_type_changed(0) def _build_model_page(self): @@ -1414,13 +1486,57 @@ def _build_predictor_page(self): lay.addWidget(self.grp_predictors, 1) self.stack_controls.addWidget(page) + def _build_files_page(self): + page = QtWidgets.QWidget() + lay = QtWidgets.QVBoxLayout(page) + lay.setContentsMargins(0, 0, 0, 0) + lay.setSpacing(8) + + grp = QtWidgets.QGroupBox("Loaded recordings") + gl = QtWidgets.QVBoxLayout(grp) + gl.setContentsMargins(12, 18, 12, 12) + gl.setSpacing(8) + + hint = QtWidgets.QLabel( + "Select an animal to make it the active recording. The Group tab aggregates " + "per-file fits once you run a Per-file batch." 
+ ) + hint.setProperty("class", "muted") + hint.setWordWrap(True) + gl.addWidget(hint) + + self.list_files = QtWidgets.QListWidget() + self.list_files.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.SingleSelection) + self.list_files.setMinimumHeight(220) + gl.addWidget(self.list_files, 1) + + self.btn_fit_all_files = QtWidgets.QPushButton("Fit each file (per-file batch)") + self.btn_fit_all_files.setProperty("class", "primary") + gl.addWidget(self.btn_fit_all_files) + + self.lbl_batch_status = QtWidgets.QLabel("") + self.lbl_batch_status.setProperty("class", "muted") + self.lbl_batch_status.setWordWrap(True) + gl.addWidget(self.lbl_batch_status) + + self.btn_clear_fits = QtWidgets.QPushButton("Clear cached fits") + self.btn_clear_fits.setToolTip("Discard all cached per-file results and group aggregates.") + gl.addWidget(self.btn_clear_fits) + try: + self.btn_clear_fits.clicked.connect(self._on_clear_cached_fits) + except Exception: + pass + lay.addWidget(grp, 1) + self.stack_controls.addWidget(page) + self.list_files.itemSelectionChanged.connect(self._on_file_list_selection_changed) + def _build_fit_page(self): page = QtWidgets.QWidget() lay = QtWidgets.QVBoxLayout(page) lay.setContentsMargins(0, 0, 0, 0) lay.setSpacing(8) - grp = QtWidgets.QGroupBox("Fit Control") + grp = QtWidgets.QGroupBox("Fit control") gl = QtWidgets.QVBoxLayout(grp) gl.setContentsMargins(12, 18, 12, 12) gl.setSpacing(10) @@ -1428,9 +1544,25 @@ def _build_fit_page(self): self.lbl_data_status.setProperty("class", "muted") self.lbl_data_status.setWordWrap(True) gl.addWidget(self.lbl_data_status) - self.btn_fit_side = QtWidgets.QPushButton("Fit model") + + # Mode-aware Fit button (the same handler as the top-bar button). + self.btn_fit_side = QtWidgets.QPushButton("Fit current scope") self.btn_fit_side.setProperty("class", "primary") + self.btn_fit_side.setToolTip( + "Run the model on the current scope. Equivalent to the Fit button in the top bar." + ) gl.addWidget(self.btn_fit_side) + + contrib_lay = QtWidgets.QHBoxLayout() + self.chk_run_contrib = QtWidgets.QCheckBox("Run leave-one-predictor-out contribution") + self.chk_run_contrib.setChecked(True) + self.chk_run_contrib.setToolTip( + "Disable to skip the per-predictor reduced-fit comparison (saves time on large designs)." 
+ ) + contrib_lay.addWidget(self.chk_run_contrib) + contrib_lay.addStretch(1) + gl.addLayout(contrib_lay) + gl.addStretch(1) lay.addWidget(grp, 1) self.stack_controls.addWidget(page) @@ -1487,8 +1619,25 @@ def _build_workspace_pages(self): lbl_feature.setProperty("class", "muted") controls.addWidget(lbl_feature) self.combo_illustration_feature = QtWidgets.QComboBox() - self.combo_illustration_feature.setMinimumWidth(260) + self.combo_illustration_feature.setMinimumWidth(220) controls.addWidget(self.combo_illustration_feature) + + self.chk_show_signal = QtWidgets.QCheckBox("Signal") + self.chk_show_signal.setChecked(True) + self.chk_show_predicted = QtWidgets.QCheckBox("Predicted") + self.chk_show_predicted.setChecked(False) + self.chk_show_contribution = QtWidgets.QCheckBox("Contribution") + self.chk_show_contribution.setChecked(True) + self.chk_show_raw_predictor = QtWidgets.QCheckBox("Raw predictor") + self.chk_show_raw_predictor.setChecked(False) + for chk in ( + self.chk_show_signal, + self.chk_show_predicted, + self.chk_show_contribution, + self.chk_show_raw_predictor, + ): + controls.addWidget(chk) + chk.toggled.connect(self._on_illustration_overlay_toggled) self.lbl_illustration_stats = QtWidgets.QLabel("") self.lbl_illustration_stats.setProperty("class", "muted") controls.addWidget(self.lbl_illustration_stats, 1) @@ -1529,6 +1678,27 @@ def _build_workspace_pages(self): flmm_lay.addWidget(self.plot_coeff, 1) self.tabs_workspace.addTab(flmm_page, "FLMM") + # Group tab — aggregates per-file fits. + group_page = QtWidgets.QWidget() + group_lay = QtWidgets.QVBoxLayout(group_page) + group_lay.setContentsMargins(10, 10, 10, 10) + group_top = QtWidgets.QHBoxLayout() + group_top.setSpacing(8) + self.lbl_group_summary = QtWidgets.QLabel( + "Run a Per-file batch fit to populate the Group view." + ) + self.lbl_group_summary.setProperty("class", "muted") + self.lbl_group_summary.setWordWrap(True) + group_top.addWidget(self.lbl_group_summary, 1) + group_lay.addLayout(group_top) + self.plot_group_kernels = pg.PlotWidget(title="Group kernels (mean +/- SEM across animals)") + self._style_plot(self.plot_group_kernels) + group_lay.addWidget(self.plot_group_kernels, 1) + self.plot_group_importance = pg.PlotWidget(title="Group leave-one-out contribution (mean across animals)") + self._style_plot(self.plot_group_importance) + group_lay.addWidget(self.plot_group_importance, 1) + self.tabs_workspace.addTab(group_page, "Group") + def _make_nav_button(self, text: str) -> QtWidgets.QToolButton: btn = QtWidgets.QToolButton() btn.setText(text) @@ -1540,7 +1710,7 @@ def _make_nav_button(self, text: str) -> QtWidgets.QToolButton: def _select_control_page(self, index: int) -> None: self.stack_controls.setCurrentIndex(index) - buttons = (self.btn_nav_model, self.btn_nav_predictors, self.btn_nav_fit) + buttons = (self.btn_nav_model, self.btn_nav_predictors, self.btn_nav_files, self.btn_nav_fit) for i, btn in enumerate(buttons): btn.setChecked(i == index) @@ -1719,6 +1889,13 @@ def set_data( self._visual_mode = int(visual_mode) if isinstance(visual_mode, (int, np.integer)) else 0 self._group_mode = bool(group_mode) self._refresh_predictor_catalog() + self._refresh_file_widgets() + # Drop cached fits for files that disappeared. 
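+ # (Stale entries would otherwise be re-rendered on file switch and
+ # persisted into the project via serialize_state.)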
+ live_ids = {self._proc_file_id(p, fallback=f"file_{i + 1}") for i, p in enumerate(self._processed_trials)} + for fid in list(self._glm_results_by_file): + if fid not in live_ids: + self._glm_results_by_file.pop(fid, None) + self._fit_summary_by_file.pop(fid, None) n_trials = len(self._processed_trials) psth_shape = tuple(np.shape(psth_mat)) if psth_mat is not None else None @@ -1735,6 +1912,128 @@ def set_data( bits.append(f"Available predictors: {len(self._predictor_catalog)}") self.lbl_data_status.setText("\n".join(bits)) + # ------------------------------------------------------------------ + # State serialization (used by project save/load) + # ------------------------------------------------------------------ + + @staticmethod + def _serialize_glm_result(r: GLMResult) -> Dict[str, Any]: + return { + "predictor_names": list(r.predictor_names or []), + "kernels": {k: np.asarray(v, float).tolist() for k, v in (r.kernels or {}).items()}, + "kernel_tvec": np.asarray(r.kernel_tvec, float).tolist(), + "time": np.asarray(r.time, float).tolist() if r.time is not None else [], + "y_pred": np.asarray(r.y_pred, float).tolist() if r.y_pred is not None else [], + "y_actual": np.asarray(r.y_actual, float).tolist() if r.y_actual is not None else [], + "residuals": np.asarray(r.residuals, float).tolist() if r.residuals is not None else [], + "r2": float(r.r2), + "coefficients": np.asarray(r.coefficients, float).tolist() if r.coefficients is not None else [], + "stats": {k: (float(v) if isinstance(v, (int, float, np.floating, np.integer)) else v) for k, v in (r.stats or {}).items()}, + "feature_importance": [ + {k: (float(v) if isinstance(v, (int, float, np.floating, np.integer)) else v) + for k, v in (row or {}).items()} + for row in (r.feature_importance or []) + ], + } + + @staticmethod + def _deserialize_glm_result(payload: Dict[str, Any]) -> Optional[GLMResult]: + if not isinstance(payload, dict): + return None + try: + kernels = {k: np.asarray(v, float) for k, v in (payload.get("kernels", {}) or {}).items()} + return GLMResult( + predictor_names=list(payload.get("predictor_names", []) or []), + kernels=kernels, + kernel_tvec=np.asarray(payload.get("kernel_tvec", []), float), + time=np.asarray(payload.get("time", []), float), + y_pred=np.asarray(payload.get("y_pred", []), float), + y_actual=np.asarray(payload.get("y_actual", []), float), + residuals=np.asarray(payload.get("residuals", []), float), + r2=float(payload.get("r2", 0.0)), + coefficients=np.asarray(payload.get("coefficients", []), float), + # design_matrix is intentionally not persisted (large, redundant); + # contribution overlays will simply be unavailable until refit. 
+ design_matrix=np.zeros((0, 0), float), + stats=dict(payload.get("stats", {}) or {}), + feature_importance=list(payload.get("feature_importance", []) or []), + ) + except Exception as exc: + _LOG.warning("Could not deserialize GLM result: %s", exc) + return None + + def serialize_state(self) -> Dict[str, Any]: + """Return a JSON-serializable snapshot of the latest GLM/FLMM fits.""" + state: Dict[str, Any] = { + "version": 1, + "fit_mode": self._fit_mode, + "active_file_id": self._active_file_id, + "selected_predictors": self._selected_predictor_keys() if hasattr(self, "list_predictors") else [], + "settings": { + "model_type": int(self.combo_model_type.currentIndex()) if hasattr(self, "combo_model_type") else 0, + "basis": self.combo_basis.currentText() if hasattr(self, "combo_basis") else "", + "n_basis": int(self.spin_n_basis.value()) if hasattr(self, "spin_n_basis") else 8, + "regularization": self.combo_reg.currentText() if hasattr(self, "combo_reg") else "", + "alpha": float(self.spin_alpha.value()) if hasattr(self, "spin_alpha") else 1.0, + "kernel_pre": float(self.spin_kernel_pre.value()) if hasattr(self, "spin_kernel_pre") else -1.0, + "kernel_post": float(self.spin_kernel_post.value()) if hasattr(self, "spin_kernel_post") else 3.0, + }, + "glm_results_by_file": { + fid: self._serialize_glm_result(res) + for fid, res in self._glm_results_by_file.items() + }, + "fit_summary_by_file": dict(self._fit_summary_by_file), + "group_summary": dict(self._group_glm_summary or {}), + } + if self._glm_result is not None: + state["glm_result_active"] = self._serialize_glm_result(self._glm_result) + return state + + def deserialize_state(self, state: Dict[str, Any]) -> None: + """Restore a snapshot produced by serialize_state().""" + if not isinstance(state, dict): + return + try: + self._fit_mode = str(state.get("fit_mode", "all")) + if hasattr(self, "combo_fit_scope"): + idx = self.combo_fit_scope.findData(self._fit_mode) + if idx >= 0: + self.combo_fit_scope.blockSignals(True) + self.combo_fit_scope.setCurrentIndex(idx) + self.combo_fit_scope.blockSignals(False) + self._active_file_id = str(state.get("active_file_id", "") or "") + + by_file = state.get("glm_results_by_file", {}) or {} + self._glm_results_by_file = {} + for fid, payload in by_file.items(): + res = self._deserialize_glm_result(payload) + if res is not None: + self._glm_results_by_file[str(fid)] = res + self._fit_summary_by_file = dict(state.get("fit_summary_by_file", {}) or {}) + self._group_glm_summary = dict(state.get("group_summary", {}) or {}) + + active_payload = state.get("glm_result_active") + if active_payload: + res = self._deserialize_glm_result(active_payload) + if res is not None: + self._glm_result = res + self.txt_summary.setPlainText(self._fit_summary_by_file.get(self._active_file_id, "")) + self._plot_glm_kernels(res) + self._plot_glm_fit(res) + self._plot_glm_illustration(res) + self._plot_feature_importance( + res.feature_importance or [], + value_key="delta_r2", + title="GLM leave-one-predictor-out contribution", + y_label="Drop in R^2", + ) + if self._glm_results_by_file: + self._aggregate_group_results() + self._refresh_file_widgets() + self._update_fit_state_label() + except Exception as exc: + _LOG.warning("Could not restore temporal modeling state: %s", exc) + # ------------------------------------------------------------------ # Predictor catalog and extraction # ------------------------------------------------------------------ @@ -2209,7 +2508,9 @@ def _predictor_vector_for_proc(self, key: str, proc: 
Any, time: np.ndarray) -> T return self._interp_to_time(time, source_time, values), "continuous" return np.zeros(time.size, float), str(entry.get("kind", "event")) - def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: + def _build_glm_dataset_from_selected_predictors( + self, file_filter: Optional[str] = None + ) -> Dict[str, Any]: selected = self._selected_predictor_keys() if not selected: return {"error": "Choose at least one predictor before fitting."} @@ -2224,6 +2525,8 @@ def _build_glm_dataset_from_selected_predictors(self) -> Dict[str, Any]: y_raw = getattr(proc, "output", None) y = np.asarray(y_raw, float) if y_raw is not None else np.array([], float) file_id = self._proc_file_id(proc, fallback=f"file_{idx + 1}") + if file_filter is not None and not self._ids_match(file_id, file_filter) and file_id != file_filter: + continue if t.size < 3 or y.size != t.size: dropped_records.append(file_id) continue @@ -2647,6 +2950,60 @@ def _compute_glm_shift_bootstrap_significance( self._progress_start(label, total) step = 0 work_items: List[Tuple[int, str, np.ndarray, float, int]] = [] + + # Pre-compute the basis matrix and per-predictor design columns ONCE, + # so each bootstrap only rebuilds the columns of the shifted predictor. + try: + dt = float(np.nanmedian(np.diff(time))) + except Exception: + dt = 0.0 + if not np.isfinite(dt) or dt <= 0: + dt = 1.0 + pre_samp = int(round(abs(kernel_window[0]) / dt)) + post_samp = int(round(abs(kernel_window[1]) / dt)) + kernel_len = max(2, pre_samp + post_samp) + if basis_type == "bspline": + basis_mat = _bspline_basis(n_basis, kernel_len) + elif basis_type == "fir": + basis_mat = _fir_basis(n_basis, kernel_len) + else: + basis_mat = _raised_cosine_basis(n_basis, kernel_len) + + T_full = int(time.size) + + def _columns_for_vector(input_vec: np.ndarray) -> Optional[np.ndarray]: + v = np.asarray(input_vec, float) + if v.size != T_full or not np.any(np.abs(v) > 1e-12): + return None + cols = np.zeros((T_full, n_basis), float) + for b in range(n_basis): + kernel = basis_mat[:, b] + conv = np.convolve(v, kernel, mode="full")[:T_full] + if pre_samp > 0: + conv = np.roll(conv, -pre_samp) + conv[-pre_samp:] = 0.0 + cols[:, b] = conv + return cols + + # Cache columns for every predictor in its un-shifted form. + static_cols_cache: Dict[str, np.ndarray] = {} + for name, spec in predictors.items(): + v = ContinuousGLM._predictor_vector(time, spec) + cols = _columns_for_vector(v) + if cols is not None: + static_cols_cache[name] = cols + used_predictors = list(static_cols_cache.keys()) + intercept_col = np.ones((T_full, 1), float) + valid_mask = np.isfinite(signal) + signal_v = signal[valid_mask] + ss_tot_full = float(np.nansum((signal - np.nanmean(signal)) ** 2)) + alpha_val = float(alpha) + + def _ridge_solve(X: np.ndarray, y: np.ndarray) -> np.ndarray: + n_cols = X.shape[1] + I = np.eye(n_cols) + I[0, 0] = 0.0 + return np.linalg.solve(X.T @ X + alpha_val * I, X.T @ y) row_meta: Dict[int, Tuple[Dict[str, Any], float]] = {} for row in rows: feature = str(row.get("feature", "")) @@ -2674,24 +3031,49 @@ def _compute_glm_shift_bootstrap_significance( work_items.append((row_idx, feature, base_vec, reduced_r2, int(seed))) def _one_shift_fit(job: Tuple[int, str, np.ndarray, float, int]) -> Tuple[int, float]: + """ + Fast circular-shift refit. Only the shifted predictor's columns + are recomputed; columns for the other predictors come from the + shared static cache. Falls back to the full GLM path if the cache + for this feature is unavailable. 
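+ The inline ridge solve leaves the intercept unpenalized (the identity
+ penalty's [0, 0] entry is zeroed) and scores R^2 against the shared,
+ precomputed ss_tot_full, so all bootstrap deltas use one baseline.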
+ """ row_idx, feature, base_vec, reduced_r2, seed = job local_rng = np.random.default_rng(seed) - shifted_predictors = dict(predictors) - shifted_predictors[feature] = { - "kind": "vector", - "values": self._shift_vector_by_segment(base_vec, segment_slices, local_rng), - } - shifted = ContinuousGLM().fit( - time, - signal, - shifted_predictors, - kernel_window=kernel_window, - n_basis=n_basis, - basis_type=basis_type, - regularization=regularization, - alpha=alpha, - ) - return row_idx, float(shifted.r2 - reduced_r2) + shifted_vec = self._shift_vector_by_segment(base_vec, segment_slices, local_rng) + if feature not in static_cols_cache or regularization not in ("ridge", "ols"): + shifted_predictors = dict(predictors) + shifted_predictors[feature] = {"kind": "vector", "values": shifted_vec} + shifted = ContinuousGLM().fit( + time, signal, shifted_predictors, + kernel_window=kernel_window, n_basis=n_basis, + basis_type=basis_type, regularization=regularization, + alpha=alpha, + ) + return row_idx, float(shifted.r2 - reduced_r2) + + shifted_cols = _columns_for_vector(shifted_vec) + if shifted_cols is None: + return row_idx, float("nan") + blocks = [intercept_col] + for name in used_predictors: + if name == feature: + blocks.append(shifted_cols) + else: + blocks.append(static_cols_cache[name]) + X = np.hstack(blocks) + Xv = X[valid_mask] + try: + if regularization == "ols": + beta, *_ = np.linalg.lstsq(Xv, signal_v, rcond=None) + else: + beta = _ridge_solve(Xv, signal_v) + except np.linalg.LinAlgError: + return row_idx, float("nan") + y_pred = X @ beta + residuals = signal - y_pred + ss_res = float(np.nansum(residuals ** 2)) + r2 = 1.0 - ss_res / max(ss_tot_full, 1e-12) + return row_idx, float(r2 - reduced_r2) null_by_row: Dict[int, List[float]] = {row_idx: [] for row_idx in row_meta} if work_items: @@ -2975,12 +3357,190 @@ def _on_remove_predictor(self): self.list_predictors.takeItem(sel) self._save_temporal_settings() + # ------------------------------------------------------------------ + # Scope (Active / All / Per-file batch) handling + # ------------------------------------------------------------------ + + def _refresh_file_widgets(self) -> None: + """Populate combo_active_file and list_files from currently-loaded recordings.""" + if not hasattr(self, "combo_active_file"): + return + ids: List[Tuple[str, str]] = [] + for idx, proc in enumerate(self._processed_trials): + file_id = self._proc_file_id(proc, fallback=f"file_{idx + 1}") + label = file_id + ids.append((file_id, label)) + + # Combo + self.combo_active_file.blockSignals(True) + try: + self.combo_active_file.clear() + for fid, label in ids: + self.combo_active_file.addItem(label, fid) + if not ids: + self.combo_active_file.addItem("No files loaded", "") + target = self._active_file_id + if target: + idx = self.combo_active_file.findData(target) + if idx >= 0: + self.combo_active_file.setCurrentIndex(idx) + elif ids: + self.combo_active_file.setCurrentIndex(0) + self._active_file_id = ids[0][0] + elif ids: + self.combo_active_file.setCurrentIndex(0) + self._active_file_id = ids[0][0] + finally: + self.combo_active_file.blockSignals(False) + + # List + if hasattr(self, "list_files"): + self.list_files.blockSignals(True) + try: + self.list_files.clear() + for fid, label in ids: + item = QtWidgets.QListWidgetItem(label) + item.setData(QtCore.Qt.ItemDataRole.UserRole, fid) + if fid in self._glm_results_by_file: + item.setText(f"{label} [fit cached]") + item.setForeground(QtGui.QColor("#6bdb74")) + self.list_files.addItem(item) + # 
Sync selection + for i in range(self.list_files.count()): + if self.list_files.item(i).data(QtCore.Qt.ItemDataRole.UserRole) == self._active_file_id: + self.list_files.setCurrentRow(i) + break + finally: + self.list_files.blockSignals(False) + self._update_fit_state_label() + + def _step_active_file(self, delta: int) -> None: + if not hasattr(self, "combo_active_file"): + return + n = self.combo_active_file.count() + if n == 0: + return + idx = max(0, min(n - 1, self.combo_active_file.currentIndex() + int(delta))) + self.combo_active_file.setCurrentIndex(idx) + + def _on_active_file_changed(self, *_): + if not hasattr(self, "combo_active_file"): + return + data = self.combo_active_file.currentData() + if isinstance(data, str) and data: + self._active_file_id = data + # Mirror selection in the list + if hasattr(self, "list_files"): + for i in range(self.list_files.count()): + if self.list_files.item(i).data(QtCore.Qt.ItemDataRole.UserRole) == data: + if self.list_files.currentRow() != i: + self.list_files.blockSignals(True) + self.list_files.setCurrentRow(i) + self.list_files.blockSignals(False) + break + # If we are in Active scope and a fit is cached for this file, render it. + if self._fit_mode == "active": + cached = self._glm_results_by_file.get(self._active_file_id) + if cached is not None: + self._glm_result = cached + self.txt_summary.setPlainText(self._fit_summary_by_file.get(self._active_file_id, "")) + self._plot_glm_kernels(cached) + self._plot_glm_fit(cached) + self._plot_glm_illustration(cached) + self._plot_feature_importance( + cached.feature_importance or [], + value_key="delta_r2", + title=f"GLM contribution - {self._active_file_id}", + y_label="Drop in R^2", + ) + self._update_fit_state_label() + + def _on_file_list_selection_changed(self) -> None: + if not hasattr(self, "list_files"): + return + item = self.list_files.currentItem() + if item is None: + return + fid = item.data(QtCore.Qt.ItemDataRole.UserRole) + if isinstance(fid, str) and fid and fid != self._active_file_id: + idx = self.combo_active_file.findData(fid) + if idx >= 0: + self.combo_active_file.setCurrentIndex(idx) + + def _on_fit_scope_changed(self, *_): + data = self.combo_fit_scope.currentData() + if isinstance(data, str): + self._fit_mode = data + self._update_fit_state_label() + + def _update_fit_state_label(self) -> None: + if not hasattr(self, "lbl_fit_state"): + return + n_cached = len(self._glm_results_by_file) + n_files = len(self._processed_trials) + bits = [] + if self._glm_result is not None: + bits.append(f"R^2 = {self._glm_result.r2:.3f}") + if n_cached > 0: + bits.append(f"{n_cached}/{n_files} files fit") + if not bits: + bits.append("No fit yet") + self.lbl_fit_state.setText(" | ".join(bits)) + + def _on_clear_cached_fits(self) -> None: + self._glm_results_by_file.clear() + self._flmm_results_by_file.clear() + self._fit_summary_by_file.clear() + self._group_glm_summary = {} + self._refresh_file_widgets() + if hasattr(self, "lbl_batch_status"): + self.lbl_batch_status.setText("Cleared cached per-file fits.") + if hasattr(self, "lbl_group_summary"): + self.lbl_group_summary.setText( + "Run a Per-file batch fit to populate the Group view." 
+ ) + if hasattr(self, "plot_group_kernels"): + self.plot_group_kernels.clear() + if hasattr(self, "plot_group_importance"): + self.plot_group_importance.clear() + + def _on_cancel_clicked(self): + self._batch_cancel_requested = True + self.btn_cancel.setEnabled(False) + self.statusMessage.emit("Cancelling current batch...", 3000) + + def _on_fit_all_files_clicked(self): + if not self._processed_trials: + self.statusMessage.emit("No recordings loaded.", 5000) + return + # Force scope to per-file batch and run. + idx = self.combo_fit_scope.findData("batch") + if idx >= 0: + self.combo_fit_scope.setCurrentIndex(idx) + self._fit_mode = "batch" + self._on_fit_clicked() + + def _on_illustration_overlay_toggled(self, *_): + if self._glm_result is not None: + self._plot_glm_illustration(self._glm_result) + def _on_fit_clicked(self): model_idx = self.combo_model_type.currentIndex() + self._batch_cancel_requested = False self._set_fit_enabled(False) + if hasattr(self, "btn_cancel"): + self.btn_cancel.setEnabled(self._fit_mode == "batch") try: if model_idx == 0: - self._fit_glm() + if self._fit_mode == "active": + if not self._active_file_id: + self.statusMessage.emit("No active file selected.", 5000) + return + self._fit_glm_catalog(file_filter=self._active_file_id) + elif self._fit_mode == "batch": + self._fit_glm_per_file_batch() + else: + self._fit_glm_catalog(file_filter=None) else: self._fit_flmm() except Exception as exc: @@ -2989,15 +3549,18 @@ def _on_fit_clicked(self): self.statusMessage.emit(f"Temporal model fit failed: {exc}", 8000) finally: self._set_fit_enabled(True) + if hasattr(self, "btn_cancel"): + self.btn_cancel.setEnabled(False) self._progress_finish() + self._update_fit_state_label() # ------------------------------------------------------------------ # GLM fit # ------------------------------------------------------------------ - def _fit_glm_catalog(self) -> None: + def _fit_glm_catalog(self, file_filter: Optional[str] = None) -> Optional[GLMResult]: self._save_temporal_settings() - dataset = self._build_glm_dataset_from_selected_predictors() + dataset = self._build_glm_dataset_from_selected_predictors(file_filter=file_filter) if "error" in dataset: msg = str(dataset.get("error", "Could not build GLM dataset.")) dropped = dataset.get("dropped_predictors", []) or [] @@ -3006,14 +3569,15 @@ def _fit_glm_catalog(self) -> None: self.txt_summary.setPlainText(msg) self.statusMessage.emit(msg.splitlines()[0], 7000) self._select_control_page(1) - return + return None basis_map = {"Raised cosine": "raised_cosine", "B-spline": "bspline", "FIR": "fir"} reg_map = {"Ridge": "ridge", "Lasso": "lasso", "OLS": "ols"} kernel_win = (self.spin_kernel_pre.value(), self.spin_kernel_post.value()) basis_type = basis_map.get(self.combo_basis.currentText(), "raised_cosine") regularization = reg_map.get(self.combo_reg.currentText(), "ridge") - self._progress_start("Fitting GLM", 0) + scope_label = file_filter if file_filter else "all files" + self._progress_start(f"Fitting GLM ({scope_label})", 0) result = self._glm.fit( np.asarray(dataset["time"], float), np.asarray(dataset["signal"], float), @@ -3026,26 +3590,31 @@ def _fit_glm_catalog(self) -> None: ) self._glm_result = result - self.statusMessage.emit("Calculating GLM leave-one-predictor-out contribution...", 0) - QtWidgets.QApplication.processEvents() - importance_rows = self._compute_glm_leave_one_out( - dataset, - result, - kernel_win, - basis_type, - regularization, - self.spin_alpha.value(), - ) - n_boot = int(self.spin_glm_bootstrap.value()) 
- self._compute_glm_shift_bootstrap_significance( - dataset, - importance_rows, - kernel_win, - basis_type, - regularization, - self.spin_alpha.value(), - n_boot, - ) + run_contrib = bool(getattr(self, "chk_run_contrib", None) is None or self.chk_run_contrib.isChecked()) + if run_contrib: + self.statusMessage.emit("Calculating GLM leave-one-predictor-out contribution...", 0) + QtWidgets.QApplication.processEvents() + importance_rows = self._compute_glm_leave_one_out( + dataset, + result, + kernel_win, + basis_type, + regularization, + self.spin_alpha.value(), + ) + n_boot = int(self.spin_glm_bootstrap.value()) + self._compute_glm_shift_bootstrap_significance( + dataset, + importance_rows, + kernel_win, + basis_type, + regularization, + self.spin_alpha.value(), + n_boot, + ) + else: + importance_rows = [] + n_boot = 0 result.feature_importance = importance_rows used_labels = [self._predictor_label(k) for k in result.predictor_names] @@ -3096,7 +3665,14 @@ def _fit_glm_catalog(self) -> None: lines.append(f"Dropped predictors: {', '.join(str(v) for v in dropped_predictors)}") if dropped_records: lines.append(f"Dropped recordings: {', '.join(str(v) for v in dropped_records)}") - self.txt_summary.setPlainText("\n".join(lines)) + summary_text = "\n".join(lines) + self.txt_summary.setPlainText(summary_text) + + # Cache per-file fit if scope was a single file. + if file_filter: + self._glm_results_by_file[file_filter] = result + self._fit_summary_by_file[file_filter] = summary_text + self._refresh_file_widgets() self._plot_glm_kernels(result) self._plot_glm_fit(result) @@ -3110,10 +3686,200 @@ def _fit_glm_catalog(self) -> None: if hasattr(self, "tabs_workspace"): self.tabs_workspace.setCurrentWidget(self.plot_importance.parentWidget() if importance_rows else self.plot_kernel.parentWidget()) self.statusMessage.emit(f"GLM fit complete - R^2 = {result.r2:.4f}", 5000) + return result + + def _fit_glm_per_file_batch(self) -> None: + """Fit each loaded file independently and populate the Group tab.""" + if not self._processed_trials: + self.statusMessage.emit("No recordings loaded.", 5000) + return + file_ids = [ + self._proc_file_id(p, fallback=f"file_{i + 1}") + for i, p in enumerate(self._processed_trials) + ] + n = len(file_ids) + self._progress_start(f"Per-file batch ({n} files)", n) + ok_results: Dict[str, GLMResult] = {} + ok_summaries: Dict[str, str] = {} + for idx, fid in enumerate(file_ids, 1): + if self._batch_cancel_requested: + self.statusMessage.emit(f"Batch cancelled after {idx - 1}/{n}.", 6000) + break + self._progress_update(idx - 1, f"Per-file batch ({idx}/{n})") + if hasattr(self, "lbl_batch_status"): + self.lbl_batch_status.setText(f"Fitting {idx}/{n}: {fid}") + QtWidgets.QApplication.processEvents() + try: + result = self._fit_glm_catalog(file_filter=fid) + except Exception as exc: + _LOG.warning("Per-file fit failed for %s: %s", fid, exc) + result = None + if result is not None: + ok_results[fid] = result + ok_summaries[fid] = self._fit_summary_by_file.get(fid, "") + self._progress_update(idx, f"Per-file batch ({idx}/{n})") + if hasattr(self, "lbl_batch_status"): + self.lbl_batch_status.setText( + f"Batch complete: {len(ok_results)}/{n} files fit successfully." 
+ ) + self._aggregate_group_results() + if hasattr(self, "tabs_workspace") and ok_results: + self.tabs_workspace.setCurrentWidget(self.plot_group_kernels.parentWidget()) + + def _aggregate_group_results(self) -> None: + """Average per-file kernels and importance for the Group tab.""" + if not hasattr(self, "plot_group_kernels"): + return + results = self._glm_results_by_file + self.plot_group_kernels.clear() + self.plot_group_importance.clear() + try: + self.plot_group_kernels.getPlotItem().legend.clear() + except Exception: + pass + try: + self.plot_group_importance.getPlotItem().legend.clear() + except Exception: + pass + if not results: + self.lbl_group_summary.setText( + "Run a Per-file batch fit to populate the Group view." + ) + return + + # Find common predictors and a shared kernel time vector. + predictor_lists = [list(r.predictor_names) for r in results.values()] + common_keys = set(predictor_lists[0]) + for lst in predictor_lists[1:]: + common_keys &= set(lst) + if not common_keys: + self.lbl_group_summary.setText( + "Per-file fits do not share any common predictor; cannot aggregate." + ) + return + + # Use the first result's kernel_tvec; resample others to it if shapes differ. + ref = next(iter(results.values())) + ref_t = np.asarray(ref.kernel_tvec, float) + kernel_stack: Dict[str, List[np.ndarray]] = {k: [] for k in common_keys} + for fid, r in results.items(): + t = np.asarray(r.kernel_tvec, float) + for key in common_keys: + kern = r.kernels.get(key) + if kern is None: + continue + kern = np.asarray(kern, float) + if kern.size != ref_t.size: + if t.size and kern.size == t.size: + kern = np.interp(ref_t, t, kern, left=np.nan, right=np.nan) + else: + continue + kernel_stack[key].append(kern) + + # Plot mean +/- SEM per predictor. + plotted = 0 + for key, stack in kernel_stack.items(): + if not stack: + continue + arr = np.vstack(stack) + mean = np.nanmean(arr, axis=0) + sem = np.nanstd(arr, axis=0, ddof=1) / np.sqrt(max(arr.shape[0], 1)) + color = self._kernel_color(key) + qcol = QtGui.QColor(color) + qcol.setAlpha(60) + lo = mean - sem + hi = mean + sem + self.plot_group_kernels.plot( + ref_t, mean, pen=pg.mkPen(color, width=2), + name=f"{self._predictor_label(key)} (n={arr.shape[0]})", + ) + fill = pg.FillBetweenItem( + pg.PlotDataItem(ref_t, lo), + pg.PlotDataItem(ref_t, hi), + brush=qcol, + ) + self.plot_group_kernels.addItem(fill) + plotted += 1 + self.plot_group_kernels.setLabel("bottom", "Time", units="s") + self.plot_group_kernels.setLabel("left", "Kernel weight") + self.plot_group_kernels.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + self.plot_group_kernels.addLine(x=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + + # Aggregate leave-one-out importance: mean delta_r2 across animals. 
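For reference, the aggregation below reduces to this standalone sketch, with
`per_animal` a hypothetical mapping from feature name to per-animal delta-R^2
values (SEM is the ddof=1 sample standard deviation over sqrt(n)):

    import numpy as np

    def aggregate_delta_r2(per_animal):
        # per_animal, e.g. {"lick": [0.04, 0.06, 0.05], "reward": [0.10, 0.12]}
        rows = []
        for feature, vals in per_animal.items():
            arr = np.asarray(vals, float)
            sem = float(np.nanstd(arr, ddof=1) / np.sqrt(arr.size)) if arr.size > 1 else 0.0
            rows.append({"feature": feature,
                         "delta_r2": float(np.nanmean(arr)),
                         "delta_r2_sem": sem,
                         "n_animals": int(arr.size)})
        rows.sort(key=lambda r: r["delta_r2"], reverse=True)
        return rows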
+ importance_acc: Dict[str, List[float]] = {} + importance_labels: Dict[str, str] = {} + for r in results.values(): + for row in r.feature_importance or []: + feat = str(row.get("feature", "")) + if not feat: + continue + val = float(row.get("delta_r2", float("nan"))) + if not np.isfinite(val): + continue + importance_acc.setdefault(feat, []).append(val) + importance_labels[feat] = str(row.get("label", feat) or feat) + agg_rows: List[Dict[str, Any]] = [] + for feat, vals in importance_acc.items(): + arr = np.asarray(vals, float) + agg_rows.append({ + "feature": feat, + "label": importance_labels.get(feat, feat), + "delta_r2": float(np.nanmean(arr)), + "delta_r2_sem": float(np.nanstd(arr, ddof=1) / np.sqrt(max(arr.size, 1))) if arr.size > 1 else 0.0, + "n_animals": int(arr.size), + "significant": False, + }) + agg_rows.sort(key=lambda r: r.get("delta_r2", -np.inf), reverse=True) + self._group_glm_summary = { + "n_files": len(results), + "common_predictors": len(common_keys), + "importance": agg_rows, + "kernel_tvec": ref_t.tolist(), + } + self._render_group_importance(agg_rows) + self.lbl_group_summary.setText( + f"Group GLM aggregate: {len(results)} animals, {len(common_keys)} common predictors. " + f"Kernels show mean +/- SEM across animals." + ) + + def _render_group_importance(self, rows: List[Dict[str, Any]]) -> None: + if not hasattr(self, "plot_group_importance"): + return + pw = self.plot_group_importance + pw.clear() + try: + pw.getPlotItem().legend.clear() + except Exception: + pass + usable = [r for r in rows if np.isfinite(float(r.get("delta_r2", float("nan"))))] + if not usable: + txt = pg.TextItem("No group importance available.", color="#c5d2e3") + pw.addItem(txt) + txt.setPos(0, 0) + return + usable = usable[:25] + vals = np.asarray([float(r.get("delta_r2", 0.0)) for r in usable], float) + sems = np.asarray([float(r.get("delta_r2_sem", 0.0)) for r in usable], float) + y_pos = np.arange(len(usable), dtype=float)[::-1] + brushes = [pg.mkBrush("#4b9df8" if v >= 0 else "#ee99a0") for v in vals] + bar = pg.BarGraphItem(x0=np.zeros_like(vals), x1=vals, y=y_pos, height=0.62, brushes=brushes) + pw.addItem(bar) + # Error bars + err = pg.ErrorBarItem(x=vals, y=y_pos, left=sems, right=sems, beam=0.18, pen=pg.mkPen("#c5d2e3")) + pw.addItem(err) + pw.addLine(x=0, pen=pg.mkPen("#5a6274", width=1, style=QtCore.Qt.PenStyle.DashLine)) + labels = [] + for pos, row in zip(y_pos, usable): + label = self._compact_feature_label(row.get("label", row.get("feature", "")), 46) + n_an = int(row.get("n_animals", 0)) + labels.append((float(pos), f"{label} (n={n_an})")) + pw.getAxis("left").setTicks([labels]) + pw.setLabel("bottom", "Mean drop in R^2 +/- SEM") + pw.setLabel("left", "Feature") def _fit_glm(self): self._fit_glm_catalog() - return + return # legacy path below kept for reference if not self._processed_trials: self.statusMessage.emit("No processed data — run preprocessing first.", 5000) return @@ -3238,6 +4004,11 @@ def _plot_glm_illustration(self, result: GLMResult) -> None: ordered = self._glm_feature_order(result) key = ordered[0] if ordered else "" + show_signal = bool(getattr(self, "chk_show_signal", None) is None or self.chk_show_signal.isChecked()) + show_pred = bool(getattr(self, "chk_show_predicted", None) is not None and self.chk_show_predicted.isChecked()) + show_contrib = bool(getattr(self, "chk_show_contribution", None) is None or self.chk_show_contribution.isChecked()) + show_raw = bool(getattr(self, "chk_show_raw_predictor", None) is not None and 
self.chk_show_raw_predictor.isChecked()) + pw = self.plot_illustration vb = self._illustration_vb pw.clear() @@ -3253,15 +4024,41 @@ def _plot_glm_illustration(self, result: GLMResult) -> None: x = np.asarray(result.time, float) signal = np.asarray(result.y_actual, float) + predicted = np.asarray(result.y_pred, float) if result.y_pred is not None else None contribution = self._glm_feature_contribution(result, key) if contribution is None: self.lbl_illustration_stats.setText("No contribution trace is available for the selected feature.") return contribution = np.asarray(contribution, float) - n = min(x.size, signal.size, contribution.size) + + # Raw predictor input (model-time vector) if requested. + raw_predictor: Optional[np.ndarray] = None + if show_raw and result.design_matrix is not None and result.predictor_names: + try: + pred_idx = result.predictor_names.index(key) + n_pred = len(result.predictor_names) + n_basis = max(1, (int(np.asarray(result.coefficients).size) - 1) // n_pred) + lo = 1 + pred_idx * n_basis + # Best proxy for the raw input: the un-convolved indicator (sum across basis cols). + raw_predictor = np.asarray(result.design_matrix[:, lo:lo + n_basis], float).sum(axis=1) + except Exception: + raw_predictor = None + + n = min( + int(x.size), + int(signal.size), + int(contribution.size), + int(predicted.size) if predicted is not None else int(signal.size), + int(raw_predictor.size) if raw_predictor is not None else int(signal.size), + ) x = x[:n] signal = signal[:n] contribution = contribution[:n] + if predicted is not None: + predicted = predicted[:n] + if raw_predictor is not None: + raw_predictor = raw_predictor[:n] + valid = np.isfinite(x) & np.isfinite(signal) & np.isfinite(contribution) r_value, p_value, n_corr = self._pearson_stats(signal[valid], contribution[valid]) p_text = self._p_label(p_value) @@ -3272,10 +4069,26 @@ def _plot_glm_illustration(self, result: GLMResult) -> None: self.lbl_illustration_stats.setText(stats_text) signal_color = "#4b9df8" + predicted_color = "#f5a97f" feature_color = self._kernel_color(key) - pw.plot(x, signal, pen=pg.mkPen(signal_color, width=1.25), name="signal") - feat_curve = pg.PlotDataItem(x, contribution, pen=pg.mkPen(feature_color, width=1.8), name=self._predictor_label(key)) - vb.addItem(feat_curve) + raw_color = "#94e2d5" + + if show_signal: + pw.plot(x, signal, pen=pg.mkPen(signal_color, width=1.25), name="signal") + if show_pred and predicted is not None: + pw.plot(x, predicted, pen=pg.mkPen(predicted_color, width=1.0, style=QtCore.Qt.PenStyle.DashLine), name="predicted") + if show_contrib: + feat_curve = pg.PlotDataItem( + x, contribution, pen=pg.mkPen(feature_color, width=1.8), + name=self._predictor_label(key), + ) + vb.addItem(feat_curve) + if show_raw and raw_predictor is not None: + raw_curve = pg.PlotDataItem( + x, raw_predictor, pen=pg.mkPen(raw_color, width=1.0, style=QtCore.Qt.PenStyle.DotLine), + name=f"{self._predictor_label(key)} (raw)", + ) + vb.addItem(raw_curve) plot_item = pw.getPlotItem() plot_item.getAxis("right").setPen(pg.mkPen(feature_color)) @@ -3283,22 +4096,29 @@ def _plot_glm_illustration(self, result: GLMResult) -> None: plot_item.getAxis("left").setPen(pg.mkPen(signal_color)) plot_item.getAxis("left").setTextPen(pg.mkPen(signal_color)) plot_item.setLabel("left", "Signal") - plot_item.setLabel("right", "Feature contribution") + plot_item.setLabel("right", "Feature contribution / raw") plot_item.setLabel("bottom", "Time", units="s") pw.addLine(y=0, pen=pg.mkPen("#5a6274", width=1, 
style=QtCore.Qt.PenStyle.DashLine)) finite_signal = signal[np.isfinite(signal)] - finite_feature = contribution[np.isfinite(contribution)] - if finite_signal.size: + if finite_signal.size and (show_signal or show_pred): y0 = float(np.nanmin(finite_signal)) y1 = float(np.nanmax(finite_signal)) pad = max((y1 - y0) * 0.08, 1e-9) pw.setYRange(y0 - pad, y1 + pad, padding=0.0) - if finite_feature.size: - f0 = float(np.nanmin(finite_feature)) - f1 = float(np.nanmax(finite_feature)) - pad = max((f1 - f0) * 0.08, 1e-9) - vb.setYRange(f0 - pad, f1 + pad, padding=0.0) + # Right viewbox range from whichever overlay is shown there. + right_arrays: List[np.ndarray] = [] + if show_contrib: + right_arrays.append(contribution[np.isfinite(contribution)]) + if show_raw and raw_predictor is not None: + right_arrays.append(raw_predictor[np.isfinite(raw_predictor)]) + if right_arrays: + stacked = np.concatenate([np.asarray(a, float) for a in right_arrays if a.size]) + if stacked.size: + f0 = float(np.nanmin(stacked)) + f1 = float(np.nanmax(stacked)) + pad = max((f1 - f0) * 0.08, 1e-9) + vb.setYRange(f0 - pad, f1 + pad, padding=0.0) self._update_illustration_view() finite_x = x[np.isfinite(x)] From bf2321dab0bcd1433248fa9bc9f9f5fb408a854a Mon Sep 17 00:00:00 2001 From: andrianj Date: Sun, 10 May 2026 19:04:56 +0200 Subject: [PATCH 14/14] Add onboarding and temporal workflow polish --- pyBer/gui_postprocessing.py | 34 ++ pyBer/main.py | 301 ++++++++++++ pyBer/onboarding.py | 910 ++++++++++++++++++++++++++++++++++++ pyBer/temporal_modeling.py | 51 ++ 4 files changed, 1296 insertions(+) create mode 100644 pyBer/onboarding.py diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py index 91995e7..1ca5039 100644 --- a/pyBer/gui_postprocessing.py +++ b/pyBer/gui_postprocessing.py @@ -3565,10 +3565,44 @@ def _on_visual_mode_changed(self, index: int) -> None: is_individual = index == 0 self.combo_individual_file.setVisible(is_individual) self._rerender_visual_from_cache() + # Mirror to the Temporal Modeling scope: Individual -> Active file, Group -> Per-file batch. + try: + section = getattr(self, "section_temporal", None) + if section is not None and hasattr(section, "combo_fit_scope"): + target = "active" if is_individual else "batch" + idx = section.combo_fit_scope.findData(target) + if idx >= 0 and section.combo_fit_scope.currentIndex() != idx: + section.combo_fit_scope.blockSignals(True) + section.combo_fit_scope.setCurrentIndex(idx) + section.combo_fit_scope.blockSignals(False) + section._fit_mode = target + section._update_fit_state_label() + except Exception: + pass def _on_individual_file_changed(self, _index: int = 0) -> None: if self.tab_visual_mode.currentIndex() == 0: self._rerender_visual_from_cache() + # Push the new active file into the Temporal Modeling section so its + # scope strip stays in sync with the global Individual file picker. 
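The try-block that follows uses the standard Qt idiom for mirroring state
between two widgets without feedback loops; a minimal distilled sketch,
assuming the combo stores its key in the item's user data:

    from PySide6 import QtWidgets

    def set_combo_data_silently(combo: QtWidgets.QComboBox, data: str) -> None:
        # Select the item carrying `data` without emitting currentIndexChanged,
        # so the mirrored widget cannot echo the change straight back.
        idx = combo.findData(data)
        if idx >= 0 and combo.currentIndex() != idx:
            combo.blockSignals(True)
            try:
                combo.setCurrentIndex(idx)
            finally:
                combo.blockSignals(False)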
+ try: + file_id = self.combo_individual_file.currentText().strip() + section = getattr(self, "section_temporal", None) + if section is None or not file_id: + return + combo = getattr(section, "combo_active_file", None) + if combo is not None: + idx = combo.findData(file_id) + if idx < 0: + idx = combo.findText(file_id) + if idx >= 0 and combo.currentIndex() != idx: + combo.blockSignals(True) + combo.setCurrentIndex(idx) + combo.blockSignals(False) + section._active_file_id = file_id + section._on_active_file_changed() + except Exception: + pass def _rerender_visual_from_cache(self) -> None: visual_mode = self.tab_visual_mode.currentIndex() diff --git a/pyBer/main.py b/pyBer/main.py index cdfaa12..fe65d20 100644 --- a/pyBer/main.py +++ b/pyBer/main.py @@ -113,6 +113,17 @@ def _is_user_site_path(path: str) -> bool: ) from gui_postprocessing import PostProcessingPanel from numeric_controls import install_spinbox_scrubbers +from onboarding import ( + ToastManager, + TutorialOverlay, + PreferencesDialog, + register_global_shortcuts, + attach_dirty_title, + install_close_confirmation, + reset_focused_plot_view, + build_default_tutorial, + add_empty_state_hint, +) from styles import ( apply_app_palette, app_qss, @@ -630,6 +641,35 @@ def _build_ui(self) -> None: self._status_bar.addPermanentWidget(QtWidgets.QLabel("App theme")) self._status_bar.addPermanentWidget(self.btn_app_theme) + # Busy / cancel indicator (left of theme widgets). Hidden until something runs. + self._busy_widget = QtWidgets.QFrame() + self._busy_widget.setObjectName("pyberBusyWidget") + self._busy_widget.setStyleSheet( + "QFrame#pyberBusyWidget { background: #2a3045; border: 1px solid #46527a;" + " border-radius: 6px; padding: 0 8px; }" + "QFrame#pyberBusyWidget QLabel { background: transparent; color: #d7e0ee; }" + "QFrame#pyberBusyWidget QPushButton { background: #543035; color: #ffd6dc;" + " border: 1px solid #8a3949; border-radius: 4px; padding: 1px 8px; }" + "QFrame#pyberBusyWidget QPushButton:hover { background: #6b3a40; }" + ) + bl = QtWidgets.QHBoxLayout(self._busy_widget) + bl.setContentsMargins(6, 1, 6, 1) + bl.setSpacing(8) + self._busy_label = QtWidgets.QLabel("Busy...") + bl.addWidget(self._busy_label) + self._busy_cancel = QtWidgets.QPushButton("Cancel") + self._busy_cancel.setToolTip("Cancel the running batch operation (Esc).") + self._busy_cancel.clicked.connect(self._cancel_current_operation) + bl.addWidget(self._busy_cancel) + self._busy_widget.setVisible(False) + self._status_bar.addPermanentWidget(self._busy_widget) + + # Poll the temporal panel's progress bar to surface batch state in status bar. + self._busy_poll = QtCore.QTimer(self) + self._busy_poll.setInterval(250) + self._busy_poll.timeout.connect(self._update_busy_indicator) + self._busy_poll.start() + # Preprocessing tab self.pre_tab = QtWidgets.QWidget() self.tabs.addTab(self.pre_tab, "Preprocessing") @@ -962,6 +1002,41 @@ def _build_ui(self) -> None: self._update_plot_status() self.setAcceptDrops(True) + # ----- UX polish: toasts, dirty-title, global shortcuts, tutorial ----- + try: + self._toaster = ToastManager(self, max_visible=4) + except Exception: + self._toaster = None + + # Mirror status-bar messages to toasts (longer-lived, easier to spot). + try: + self.post_tab.statusUpdate.connect(self._toast_from_status) + except Exception: + pass + + # Dirty-title indicator: '*' suffix while postprocessing has unsaved changes. 
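+        # attach_dirty_title(window, base_title, is_dirty) returns a zero-arg
+        # refresher; a timer polls it below because the postprocessing panel
+        # only exposes a _project_dirty flag, not a change signal.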
+ def _is_dirty() -> bool: + try: + return bool(getattr(self.post_tab, "_project_dirty", False)) + except Exception: + return False + + self._refresh_dirty_title = attach_dirty_title( + self, "Pyber - Fiber Photometry", _is_dirty, + ) + self._dirty_poll = QtCore.QTimer(self) + self._dirty_poll.setInterval(800) + self._dirty_poll.timeout.connect(self._refresh_dirty_title) + self._dirty_poll.start() + + install_close_confirmation(self, _is_dirty, save_callback=self._save_post_project_for_close) + + # Register the global shortcut bundle. Methods that don't exist become no-ops. + register_global_shortcuts(self) + + # First-run tutorial. + QtCore.QTimer.singleShot(450, self._maybe_show_first_run_tutorial) + def _setup_section_popups(self) -> None: """Create preprocessing section panels using DockArea or legacy floating docks.""" if self._use_pg_dockarea_pre_layout and self._pre_dockarea_docks: @@ -2027,6 +2102,232 @@ def _init_shortcuts(self) -> None: self._bind_shortcut("S", self._assign_pending_box_to_section, require_non_text_focus=True) self._bind_shortcut("Escape", self._close_focused_popup, require_non_text_focus=True) + # ---------------------------------------------------------------------- + # UX polish: tutorial / toasts / global shortcut callbacks + # ---------------------------------------------------------------------- + + def _toast_from_status(self, message: str, timeout_ms: int = 0) -> None: + if not getattr(self, "_toaster", None) or not message: + return + text = str(message) + lower = text.lower() + sev = "info" + if "fail" in lower or "error" in lower or "could not" in lower: + sev = "error" + elif "warn" in lower or "dropped" in lower or "skipped" in lower: + sev = "warn" + elif "complete" in lower or "saved" in lower or "loaded" in lower or " ok" in lower: + sev = "ok" + timeout = int(timeout_ms) if timeout_ms and timeout_ms > 0 else int(self.settings.value("ui/toast_timeout_ms", 5000) or 5000) + self._toaster.post(text, sev, timeout) + + def _maybe_show_first_run_tutorial(self) -> None: + try: + seen = self.settings.value("onboarding/first_run_completed", False) + show_pref = self.settings.value("onboarding/show_on_startup", True) + from onboarding import _to_bool + if _to_bool(seen, False) and not _to_bool(show_pref, True): + return + except Exception: + pass + self._show_tutorial_again() + + def _show_tutorial_again(self) -> None: + try: + steps = build_default_tutorial(self) + overlay = TutorialOverlay(self, steps) + overlay.finished.connect(lambda: self.settings.setValue("onboarding/first_run_completed", True)) + overlay.start() + except Exception: + pass + + def _show_keyboard_cheatsheet(self) -> None: + # Open the Preferences dialog directly on the Keyboard tab. + try: + dlg = PreferencesDialog(self, self.settings) + for i in range(dlg.findChild(QtWidgets.QTabWidget).count()): + tabs = dlg.findChild(QtWidgets.QTabWidget) + if tabs.tabText(i).lower().startswith("keyboard"): + tabs.setCurrentIndex(i) + break + dlg.exec() + except Exception: + pass + + def _open_preferences(self) -> None: + try: + dlg = PreferencesDialog(self, self.settings) + if dlg.exec() == QtWidgets.QDialog.DialogCode.Accepted: + # Apply theme right away if it changed. 
+ desired = str(self.settings.value("app/theme", "dark") or "dark").lower() + current = getattr(self, "_app_theme_mode", "dark") + if desired != current: + if desired == "light": + self.act_app_theme_light.setChecked(True) + else: + self.act_app_theme_dark.setChecked(True) + self._on_app_theme_changed() + if getattr(self, "_toaster", None): + self._toaster.ok("Preferences saved.") + except Exception: + pass + + # --- tab navigation --- + + def _focus_pre_tab(self) -> None: + try: + self.tabs.setCurrentIndex(0) + except Exception: + pass + + def _focus_post_tab(self) -> None: + try: + self.tabs.setCurrentIndex(1) + except Exception: + pass + + def _cycle_main_tab(self) -> None: + try: + n = self.tabs.count() + if n <= 1: + return + self.tabs.setCurrentIndex((self.tabs.currentIndex() + 1) % n) + except Exception: + pass + + # --- file navigation in postprocessing/temporal --- + + def _step_active_file_next(self) -> None: + self._step_active_file(+1) + + def _step_active_file_prev(self) -> None: + self._step_active_file(-1) + + def _step_active_file(self, delta: int) -> None: + # Try the temporal panel's combo (covers the GLM scope strip). + try: + section = getattr(self.post_tab, "section_temporal", None) + if section is not None and hasattr(section, "_step_active_file"): + section._step_active_file(delta) + return + except Exception: + pass + # Fall back to postprocessing's own file combo, if any. + try: + combo = getattr(self.post_tab, "combo_individual_file", None) + if combo is not None: + idx = max(0, min(combo.count() - 1, combo.currentIndex() + int(delta))) + combo.setCurrentIndex(idx) + except Exception: + pass + + def _toggle_individual_group(self) -> None: + try: + bar = getattr(self.post_tab, "tab_visual_mode", None) + if bar is None: + return + cur = bar.currentIndex() + bar.setCurrentIndex((cur + 1) % bar.count()) + except Exception: + pass + + # --- temporal modeling --- + + def _fit_temporal_model(self) -> None: + try: + section = getattr(self.post_tab, "section_temporal", None) + if section is None: + return + self.tabs.setCurrentIndex(1) + section._on_fit_clicked() + except Exception: + pass + + def _fit_temporal_all_files(self) -> None: + try: + section = getattr(self.post_tab, "section_temporal", None) + if section is None: + return + self.tabs.setCurrentIndex(1) + section._on_fit_all_files_clicked() + except Exception: + pass + + def _recompute_psth(self) -> None: + try: + self.tabs.setCurrentIndex(1) + fn = getattr(self.post_tab, "_compute_psth", None) + if callable(fn): + fn() + except Exception: + pass + + def _run_postprocess_export(self) -> None: + try: + self.tabs.setCurrentIndex(1) + for name in ("_run_export", "_export_current", "_on_run_export_clicked", "run_export"): + fn = getattr(self.post_tab, name, None) + if callable(fn): + fn() + return + except Exception: + pass + + def _cancel_current_operation(self) -> None: + # Temporal modeling batch + try: + section = getattr(self.post_tab, "section_temporal", None) + if section is not None: + section._batch_cancel_requested = True + except Exception: + pass + # Preprocessing has its own Esc handling for popups; let it through too. 
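For context, the cancel flag set above is purely cooperative; a minimal sketch
of the loop contract it relies on (hypothetical `BatchRunner` standing in for
the temporal panel's per-file batch fit):

    class BatchRunner:
        """Cooperative cancellation: the worker polls a flag between jobs."""

        def __init__(self):
            self.cancel_requested = False

        def run(self, jobs):
            done = []
            for job in jobs:
                if self.cancel_requested:  # flipped from the UI (Esc / Cancel)
                    break                  # stop at the next job boundary
                done.append(job())
            return done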
+ try: + self._close_focused_popup() + except Exception: + pass + + def _reset_focused_plot_view(self) -> None: + try: + reset_focused_plot_view(self) + if getattr(self, "_toaster", None): + self._toaster.info("Reset plot view.", timeout_ms=1800) + except Exception: + pass + + def _update_busy_indicator(self) -> None: + """Reflect any running batch op (currently: Temporal Modeling) in the status bar.""" + try: + section = getattr(self.post_tab, "section_temporal", None) + if section is None or not hasattr(section, "progress_model"): + self._busy_widget.setVisible(False) + return + progress = section.progress_model + if progress.isVisible(): + fmt = progress.format() or "Running..." + # Strip the trailing "%p%" placeholder for our compact label. + label = fmt.replace("%p%", "").strip(" :") + self._busy_label.setText(label or "Running...") + self._busy_widget.setVisible(True) + else: + self._busy_widget.setVisible(False) + except Exception: + self._busy_widget.setVisible(False) + + def _save_post_project_for_close(self) -> bool: + """ + Used by the close-confirmation handler. Returns True on success. + """ + try: + fn = getattr(self.post_tab, "_save_project_dialog", None) or getattr(self.post_tab, "_save_project", None) + if callable(fn): + fn() + return not bool(getattr(self.post_tab, "_project_dirty", False)) + except Exception: + pass + # No save handler available: let the user decide via Discard/Cancel. + return False + def _dock_area_from_settings( self, value: object, diff --git a/pyBer/onboarding.py b/pyBer/onboarding.py new file mode 100644 index 0000000..37a18fc --- /dev/null +++ b/pyBer/onboarding.py @@ -0,0 +1,910 @@ +# onboarding.py +""" +User-experience polish layer for pyBer: + +* TutorialOverlay - first-run guided walkthrough with highlight cutout, callouts, + Next/Back/Skip controls. Reopen via F1 / Help menu. +* ToastManager - lightweight non-blocking notifications (info/warn/error) + stacked top-right, click to dismiss. +* PreferencesDialog- consolidates theme, autosave, default kernel window, behavior + defaults, and shows the keyboard cheat sheet. +* register_shortcuts(window) - installs a wide set of global keyboard shortcuts. +* attach_dirty_title(window) - shows "*" in window title while project is dirty. + +This module is intentionally self-contained: it touches no analysis code +and never raises if the host window lacks an optional attribute. 
+""" +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Optional, Tuple + +from PySide6 import QtCore, QtGui, QtWidgets + + +# ============================================================================ +# Toast notifications +# ============================================================================ + +_TOAST_QSS = { + "info": "background: #1f2a3a; color: #e9f0fb; border: 1px solid #355080;", + "warn": "background: #3a2d1d; color: #fde6c8; border: 1px solid #8a6a3a;", + "error": "background: #3b1f25; color: #ffd6dc; border: 1px solid #8a3949;", + "ok": "background: #1c2e22; color: #d4f4dc; border: 1px solid #2f7a4a;", +} + + +class _Toast(QtWidgets.QFrame): + closed = QtCore.Signal(object) + + def __init__(self, parent: QtWidgets.QWidget, text: str, severity: str, timeout_ms: int): + super().__init__(parent) + self.setObjectName("pyberToast") + self.setAttribute(QtCore.Qt.WidgetAttribute.WA_DeleteOnClose, True) + self.setStyleSheet( + "QFrame#pyberToast { %s border-radius: 8px; padding: 8px 12px; } " + "QFrame#pyberToast QLabel { background: transparent; }" + % _TOAST_QSS.get(severity, _TOAST_QSS["info"]) + ) + lay = QtWidgets.QHBoxLayout(self) + lay.setContentsMargins(10, 7, 10, 7) + lay.setSpacing(8) + icon = {"info": "i", "warn": "!", "error": "x", "ok": "v"}.get(severity, "i") + badge = QtWidgets.QLabel(icon) + badge.setFixedSize(20, 20) + badge.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter) + badge.setStyleSheet("background: rgba(255,255,255,0.08); border-radius: 10px; font-weight: 700;") + lay.addWidget(badge) + self._label = QtWidgets.QLabel(str(text)) + self._label.setWordWrap(True) + self._label.setMaximumWidth(360) + lay.addWidget(self._label, 1) + self.setMinimumWidth(220) + self.setMaximumWidth(420) + self.setCursor(QtCore.Qt.CursorShape.PointingHandCursor) + if timeout_ms > 0: + QtCore.QTimer.singleShot(int(timeout_ms), self._dismiss) + + def mousePressEvent(self, event: QtGui.QMouseEvent) -> None: + self._dismiss() + super().mousePressEvent(event) + + def _dismiss(self) -> None: + try: + self.closed.emit(self) + finally: + self.close() + + +class ToastManager(QtCore.QObject): + """ + Stacked top-right toast queue that follows the host window. Up to + `max_visible` toasts; older ones drop off when more arrive. + """ + + def __init__(self, window: QtWidgets.QMainWindow, max_visible: int = 4): + super().__init__(window) + self._window = window + self._max_visible = int(max_visible) + self._toasts: List[_Toast] = [] + self._margin = 14 + self._spacing = 8 + window.installEventFilter(self) + + def post(self, text: str, severity: str = "info", timeout_ms: int = 5000) -> None: + if not text: + return + toast = _Toast(self._window, text, severity, timeout_ms) + toast.closed.connect(self._on_closed) + self._toasts.append(toast) + # Cap visible count. + while len(self._toasts) > self._max_visible: + old = self._toasts.pop(0) + old.close() + toast.show() + toast.adjustSize() + self._reflow() + + # Convenience wrappers. 
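A minimal usage sketch for the manager and the wrappers defined below
(hypothetical host window; `ToastManager` is the class above):

    from PySide6 import QtWidgets

    app = QtWidgets.QApplication([])
    win = QtWidgets.QMainWindow()
    win.show()

    toaster = ToastManager(win, max_visible=4)
    toaster.info("Loading project...")
    toaster.ok("Project loaded.")
    toaster.warn("2 trials skipped during import.")
    app.exec()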
+ def info(self, text: str, timeout_ms: int = 4000) -> None: + self.post(text, "info", timeout_ms) + + def ok(self, text: str, timeout_ms: int = 3500) -> None: + self.post(text, "ok", timeout_ms) + + def warn(self, text: str, timeout_ms: int = 6500) -> None: + self.post(text, "warn", timeout_ms) + + def error(self, text: str, timeout_ms: int = 9000) -> None: + self.post(text, "error", timeout_ms) + + def _on_closed(self, toast: _Toast) -> None: + try: + self._toasts.remove(toast) + except ValueError: + pass + self._reflow() + + def _reflow(self) -> None: + if not self._window.isVisible(): + return + rect = self._window.rect() + x_right = rect.right() - self._margin + y = rect.top() + self._margin + # Status bar/toolbar offset + try: + mw = self._window + if isinstance(mw, QtWidgets.QMainWindow) and mw.menuBar() is not None: + y += mw.menuBar().height() + 4 + except Exception: + pass + for toast in self._toasts: + toast.adjustSize() + w = toast.width() + toast.move(x_right - w, y) + toast.raise_() + y += toast.height() + self._spacing + + def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool: + if obj is self._window and event.type() in ( + QtCore.QEvent.Type.Resize, + QtCore.QEvent.Type.Move, + QtCore.QEvent.Type.WindowStateChange, + ): + self._reflow() + return False + + +# ============================================================================ +# Tutorial overlay +# ============================================================================ + + +class TutorialStep: + """One step of the guided tutorial.""" + + def __init__( + self, + title: str, + body: str, + target_resolver: Optional[Callable[[QtWidgets.QWidget], Optional[QtWidgets.QWidget]]] = None, + before: Optional[Callable[[QtWidgets.QWidget], None]] = None, + ): + self.title = title + self.body = body + self.target_resolver = target_resolver + self.before = before # called before showing the step (e.g. switch tab) + + +class TutorialOverlay(QtWidgets.QWidget): + """ + Full-window translucent overlay. A "spotlight" cutout illuminates the + target widget, and a styled callout near it shows step text and + Back / Next / Skip controls. Esc skips. Arrow keys page through. 
+ """ + + finished = QtCore.Signal() + + def __init__(self, host: QtWidgets.QMainWindow, steps: List[TutorialStep]): + super().__init__(host) + self._host = host + self._steps = list(steps) + self._index = 0 + self.setAttribute(QtCore.Qt.WidgetAttribute.WA_TransparentForMouseEvents, False) + self.setAttribute(QtCore.Qt.WidgetAttribute.WA_StyledBackground, True) + self.setAttribute(QtCore.Qt.WidgetAttribute.WA_NoSystemBackground, True) + self.setFocusPolicy(QtCore.Qt.FocusPolicy.StrongFocus) + self.setMouseTracking(True) + self._target_rect = QtCore.QRect() + self._build_callout() + host.installEventFilter(self) + + # --- internal UI --- + + def _build_callout(self) -> None: + self._callout = QtWidgets.QFrame(self) + self._callout.setObjectName("tutorialCallout") + self._callout.setStyleSheet( + "QFrame#tutorialCallout { background: #14202f; color: #e9f0fb; " + "border: 1px solid #2f8cff; border-radius: 12px; }" + "QFrame#tutorialCallout QLabel { background: transparent; color: #e9f0fb; }" + "QFrame#tutorialCallout QLabel#tutTitle { font-weight: 700; font-size: 12pt; }" + "QFrame#tutorialCallout QLabel#tutStep { color: #95a5c2; font-size: 8.5pt; }" + "QFrame#tutorialCallout QPushButton { background: #1c2a3e; color: #e9f0fb; " + "border: 1px solid #355080; border-radius: 6px; padding: 5px 12px; }" + "QFrame#tutorialCallout QPushButton:hover { background: #233553; }" + "QFrame#tutorialCallout QPushButton#tutPrimary { background: #2f8cff; " + "border: 1px solid #46a0ff; color: white; font-weight: 700; }" + "QFrame#tutorialCallout QPushButton#tutPrimary:hover { background: #46a0ff; }" + "QFrame#tutorialCallout QPushButton#tutSkip { color: #95a5c2; border-color: transparent; }" + ) + lay = QtWidgets.QVBoxLayout(self._callout) + lay.setContentsMargins(18, 16, 18, 14) + lay.setSpacing(8) + + self._lbl_step = QtWidgets.QLabel("Step 1 / N") + self._lbl_step.setObjectName("tutStep") + lay.addWidget(self._lbl_step) + self._lbl_title = QtWidgets.QLabel("Title") + self._lbl_title.setObjectName("tutTitle") + self._lbl_title.setWordWrap(True) + lay.addWidget(self._lbl_title) + self._lbl_body = QtWidgets.QLabel("Body") + self._lbl_body.setWordWrap(True) + self._lbl_body.setMinimumWidth(340) + self._lbl_body.setMaximumWidth(420) + lay.addWidget(self._lbl_body) + + btn_row = QtWidgets.QHBoxLayout() + btn_row.setSpacing(6) + self._btn_skip = QtWidgets.QPushButton("Skip tutorial") + self._btn_skip.setObjectName("tutSkip") + self._btn_skip.clicked.connect(self._end) + btn_row.addWidget(self._btn_skip) + btn_row.addStretch(1) + self._btn_back = QtWidgets.QPushButton("◀ Back") + self._btn_back.clicked.connect(self._prev) + btn_row.addWidget(self._btn_back) + self._btn_next = QtWidgets.QPushButton("Next ▶") + self._btn_next.setObjectName("tutPrimary") + self._btn_next.setDefault(True) + self._btn_next.clicked.connect(self._next) + btn_row.addWidget(self._btn_next) + lay.addLayout(btn_row) + self._callout.adjustSize() + + # --- lifecycle --- + + def start(self) -> None: + self.setGeometry(self._host.rect()) + self.show() + self.raise_() + self.setFocus(QtCore.Qt.FocusReason.OtherFocusReason) + self._render_step() + + def _end(self) -> None: + self._host.removeEventFilter(self) + self.finished.emit() + self.close() + self.deleteLater() + + def _next(self) -> None: + if self._index >= len(self._steps) - 1: + self._end() + return + self._index += 1 + self._render_step() + + def _prev(self) -> None: + if self._index <= 0: + return + self._index -= 1 + self._render_step() + + # --- rendering --- + + def 
_render_step(self) -> None: + step = self._steps[self._index] + if step.before is not None: + try: + step.before(self._host) + except Exception: + pass + n = len(self._steps) + self._lbl_step.setText(f"Step {self._index + 1} / {n}") + self._lbl_title.setText(step.title) + self._lbl_body.setText(step.body) + self._btn_back.setEnabled(self._index > 0) + self._btn_next.setText("Got it!" if self._index == n - 1 else "Next ▶") + + target = None + if step.target_resolver is not None: + try: + target = step.target_resolver(self._host) + except Exception: + target = None + self._target_rect = self._compute_target_rect(target) + self._position_callout() + self.update() + + def _compute_target_rect(self, target: Optional[QtWidgets.QWidget]) -> QtCore.QRect: + if target is None or not target.isVisible(): + return QtCore.QRect() + # Map the target widget rect into the overlay's coordinate space. + top_left = target.mapTo(self._host, QtCore.QPoint(0, 0)) + rect = QtCore.QRect(top_left, target.size()) + rect = rect.adjusted(-6, -6, 6, 6) + return rect + + def _position_callout(self) -> None: + margin = 16 + host_rect = self.rect() + self._callout.adjustSize() + co_w = self._callout.width() + co_h = self._callout.height() + if self._target_rect.isEmpty(): + x = (host_rect.width() - co_w) // 2 + y = (host_rect.height() - co_h) // 2 + self._callout.move(x, y) + return + # Try to place to the right of the target, fallback below, then left, then above. + candidates = [ + QtCore.QPoint(self._target_rect.right() + margin, self._target_rect.top()), + QtCore.QPoint(self._target_rect.left(), self._target_rect.bottom() + margin), + QtCore.QPoint(self._target_rect.left() - co_w - margin, self._target_rect.top()), + QtCore.QPoint(self._target_rect.left(), self._target_rect.top() - co_h - margin), + ] + chosen = candidates[0] + for cand in candidates: + if ( + cand.x() >= margin and cand.y() >= margin + and cand.x() + co_w + margin <= host_rect.width() + and cand.y() + co_h + margin <= host_rect.height() + ): + chosen = cand + break + chosen.setX(max(margin, min(host_rect.width() - co_w - margin, chosen.x()))) + chosen.setY(max(margin, min(host_rect.height() - co_h - margin, chosen.y()))) + self._callout.move(chosen) + + def paintEvent(self, event: QtGui.QPaintEvent) -> None: + painter = QtGui.QPainter(self) + painter.setRenderHint(QtGui.QPainter.RenderHint.Antialiasing, True) + # Backdrop + painter.fillRect(self.rect(), QtGui.QColor(8, 12, 22, 190)) + if not self._target_rect.isEmpty(): + # Punch a soft "spotlight" on the target. + path = QtGui.QPainterPath() + path.addRect(QtCore.QRectF(self.rect())) + spot = QtGui.QPainterPath() + spot.addRoundedRect(QtCore.QRectF(self._target_rect), 8, 8) + path = path.subtracted(spot) + painter.fillPath(path, QtGui.QColor(8, 12, 22, 215)) + # Outline the target. + pen = QtGui.QPen(QtGui.QColor("#2f8cff")) + pen.setWidth(2) + painter.setPen(pen) + painter.setBrush(QtCore.Qt.BrushStyle.NoBrush) + painter.drawRoundedRect(self._target_rect, 8, 8) + # Optional: dashed glow inside. 
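+            # (Both effects ride on QPainterPath: the cutout earlier in this
+            #  method is boolean subtraction of a rounded rect from the full
+            #  backdrop rect; the glow here is just a dashed inset outline.)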
+ glow = QtGui.QPen(QtGui.QColor(47, 140, 255, 120)) + glow.setWidth(1) + glow.setStyle(QtCore.Qt.PenStyle.DashLine) + painter.setPen(glow) + painter.drawRoundedRect(self._target_rect.adjusted(3, 3, -3, -3), 6, 6) + + def keyPressEvent(self, event: QtGui.QKeyEvent) -> None: + key = event.key() + if key in (QtCore.Qt.Key.Key_Escape,): + self._end() + elif key in (QtCore.Qt.Key.Key_Right, QtCore.Qt.Key.Key_PageDown, QtCore.Qt.Key.Key_Space, QtCore.Qt.Key.Key_Return, QtCore.Qt.Key.Key_Enter): + self._next() + elif key in (QtCore.Qt.Key.Key_Left, QtCore.Qt.Key.Key_PageUp, QtCore.Qt.Key.Key_Backspace): + self._prev() + else: + super().keyPressEvent(event) + + def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool: + if obj is self._host and event.type() in ( + QtCore.QEvent.Type.Resize, + QtCore.QEvent.Type.Move, + ): + self.setGeometry(self._host.rect()) + self._render_step() + return False + + +def build_default_tutorial(window: QtWidgets.QMainWindow) -> List[TutorialStep]: + """Default first-run tutorial covering the main workflow regions.""" + + def _switch_pre(_w: QtWidgets.QMainWindow) -> None: + try: + window.tabs.setCurrentIndex(0) + except Exception: + pass + + def _switch_post(_w: QtWidgets.QMainWindow) -> None: + try: + window.tabs.setCurrentIndex(1) + except Exception: + pass + + def _resolve_tabs(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]: + return getattr(window, "tabs", None) + + def _resolve_pre_files(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]: + return getattr(window, "file_panel", None) + + def _resolve_status(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]: + return getattr(window, "_status_bar", None) + + def _resolve_post(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]: + return getattr(window, "post_tab", None) + + def _resolve_temporal(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]: + post = getattr(window, "post_tab", None) + if post is None: + return None + return getattr(post, "section_temporal", None) + + return [ + TutorialStep( + "Welcome to pyBer", + "pyBer takes raw fiber-photometry recordings (Doric / CSV) all the way to " + "PSTH, behavior alignment, and GLM/FLMM analysis.\n\n" + "Use ◀ ▶ or arrow keys to step through. Press F1 anytime to reopen this tour.", + target_resolver=None, + ), + TutorialStep( + "Two main tabs: Preprocessing -> Postprocessing", + "Preprocessing handles raw signal cleanup, artifact removal and export to " + "processed traces. Postprocessing consumes those traces for PSTH, peak/event " + "metrics, behavior alignment, and modeling. The flow is left to right.", + target_resolver=_resolve_tabs, + before=_switch_pre, + ), + TutorialStep( + "Step 1 - Drop your files here", + "Drag .doric, .csv or .h5 files (or a whole folder) onto the file queue. " + "Use Ctrl+O to browse, Ctrl+Shift+O for a folder, Delete to remove a selection.", + target_resolver=_resolve_pre_files, + before=_switch_pre, + ), + TutorialStep( + "Step 2 - Run quality check + export", + "Ctrl+Q runs QC on the active recording, Ctrl+Shift+Q does a batch QC, " + "Ctrl+E exports the current selection. Ctrl+Z / Ctrl+Y undo/redo " + "preprocessing actions.", + target_resolver=_resolve_status, + before=_switch_pre, + ), + TutorialStep( + "Step 3 - Switch to Postprocessing", + "Once you've exported processed traces, hop to the Postprocessing tab. 
" + "Drag the .csv/.h5 outputs onto its file dropzone, or open a saved project.", + target_resolver=_resolve_post, + before=_switch_post, + ), + TutorialStep( + "Active file vs Group", + "Postprocessing shows a single recording in 'Individual' mode or a group of " + "animals in 'Group' mode. Use Ctrl+G to toggle. Ctrl+Left / Ctrl+Right step " + "through loaded animals.", + target_resolver=_resolve_post, + before=_switch_post, + ), + TutorialStep( + "Temporal Modeling", + "The 'T' panel fits a Continuous GLM or trial-level FLMM. Choose a Scope:\n" + "- Active file (single)\n" + "- All loaded (concatenated)\n" + "- Per-file batch + group\n\n" + "Press Ctrl+Shift+F to fit. The Group tab aggregates per-file kernels " + "with mean +/- SEM across animals.", + target_resolver=_resolve_temporal, + before=_switch_post, + ), + TutorialStep( + "Keyboard shortcuts", + "Press Ctrl+/ anytime for the full cheat sheet. Highlights:\n" + "- F1 Help / replay tour Ctrl+, Preferences\n" + "- Ctrl+S Save project Ctrl+Shift+S Save as\n" + "- Ctrl+1 Preprocessing Ctrl+2 Postprocessing\n" + "- Ctrl+0 Reset focused plot view\n" + "- Esc Cancel current operation", + target_resolver=None, + ), + TutorialStep( + "You're set", + "That's the whirlwind tour. Tooltips on every input fill in the rest. " + "If something fails the toast in the corner shows what happened - click " + "it to dismiss.\n\nHappy analyzing.", + target_resolver=None, + ), + ] + + +# ============================================================================ +# Preferences dialog +# ============================================================================ + + +class PreferencesDialog(QtWidgets.QDialog): + """ + Compact preferences dialog. Reads/writes via QSettings so the values + survive restarts. Apply emits no signal; consumers (theme button, etc.) + pick up changes the next time they read the corresponding key. 
+ """ + + KEYS = { + "theme": "app/theme", # "dark" | "light" + "autosave": "app/autosave_enabled", # bool + "autosave_min": "app/autosave_minutes", # int + "kernel_pre": "temporal_modeling/kernel_pre", + "kernel_post": "temporal_modeling/kernel_post", + "show_tutorial": "onboarding/show_on_startup", + "toast_timeout": "ui/toast_timeout_ms", + } + + def __init__(self, parent: QtWidgets.QWidget, settings: QtCore.QSettings): + super().__init__(parent) + self.setWindowTitle("Preferences") + self.setModal(True) + self._settings = settings + self.resize(540, 420) + + tabs = QtWidgets.QTabWidget(self) + + # ---- Appearance ---- + appearance = QtWidgets.QWidget() + a = QtWidgets.QFormLayout(appearance) + a.setContentsMargins(16, 16, 16, 16) + self.combo_theme = QtWidgets.QComboBox() + self.combo_theme.addItem("Dark", "dark") + self.combo_theme.addItem("Light", "light") + a.addRow("Theme", self.combo_theme) + self.spin_toast = QtWidgets.QSpinBox() + self.spin_toast.setRange(1500, 30000) + self.spin_toast.setSingleStep(500) + self.spin_toast.setSuffix(" ms") + a.addRow("Toast default duration", self.spin_toast) + self.chk_show_tutorial = QtWidgets.QCheckBox( + "Show first-run tutorial on next launch" + ) + a.addRow("Onboarding", self.chk_show_tutorial) + tabs.addTab(appearance, "Appearance") + + # ---- Defaults ---- + defaults = QtWidgets.QWidget() + d = QtWidgets.QFormLayout(defaults) + d.setContentsMargins(16, 16, 16, 16) + self.spin_kernel_pre = QtWidgets.QDoubleSpinBox() + self.spin_kernel_pre.setRange(-30.0, 0.0) + self.spin_kernel_pre.setDecimals(2) + self.spin_kernel_pre.setSuffix(" s") + d.addRow("Default GLM kernel pre", self.spin_kernel_pre) + self.spin_kernel_post = QtWidgets.QDoubleSpinBox() + self.spin_kernel_post.setRange(0.1, 60.0) + self.spin_kernel_post.setDecimals(2) + self.spin_kernel_post.setSuffix(" s") + d.addRow("Default GLM kernel post", self.spin_kernel_post) + self.chk_autosave = QtWidgets.QCheckBox("Enable project autosave") + d.addRow("Autosave", self.chk_autosave) + self.spin_autosave = QtWidgets.QSpinBox() + self.spin_autosave.setRange(1, 60) + self.spin_autosave.setSuffix(" min") + d.addRow("Autosave interval", self.spin_autosave) + tabs.addTab(defaults, "Defaults") + + # ---- Keyboard cheat sheet ---- + keyboard = QtWidgets.QWidget() + k = QtWidgets.QVBoxLayout(keyboard) + k.setContentsMargins(16, 16, 16, 16) + cheat = QtWidgets.QTextBrowser() + cheat.setReadOnly(True) + cheat.setOpenExternalLinks(False) + cheat.setHtml(_keyboard_cheatsheet_html()) + k.addWidget(cheat, 1) + tabs.addTab(keyboard, "Keyboard") + + # ---- About ---- + about = QtWidgets.QWidget() + ab = QtWidgets.QVBoxLayout(about) + ab.setContentsMargins(16, 16, 16, 16) + about_lbl = QtWidgets.QLabel( + "
<h3>pyBer - Fiber Photometry</h3>"
+            "<p>Pipeline for raw photometry preprocessing, PSTH, behavior alignment, "
+            "and GLM/FLMM modeling.</p>"
+            "<p>Bellone Lab toolkit.</p>
" + ) + about_lbl.setWordWrap(True) + ab.addWidget(about_lbl) + ab.addStretch(1) + tabs.addTab(about, "About") + + # ---- Buttons ---- + button_box = QtWidgets.QDialogButtonBox( + QtWidgets.QDialogButtonBox.StandardButton.Ok + | QtWidgets.QDialogButtonBox.StandardButton.Cancel + | QtWidgets.QDialogButtonBox.StandardButton.Apply + ) + button_box.accepted.connect(self._on_accept) + button_box.rejected.connect(self.reject) + button_box.button(QtWidgets.QDialogButtonBox.StandardButton.Apply).clicked.connect(self._apply) + + layout = QtWidgets.QVBoxLayout(self) + layout.setContentsMargins(10, 10, 10, 10) + layout.addWidget(tabs) + layout.addWidget(button_box) + + self._load() + + # --- internals --- + + def _load(self) -> None: + s = self._settings + theme = str(s.value(self.KEYS["theme"], "dark") or "dark").lower() + idx = self.combo_theme.findData(theme) + if idx >= 0: + self.combo_theme.setCurrentIndex(idx) + self.spin_toast.setValue(int(s.value(self.KEYS["toast_timeout"], 5000) or 5000)) + self.chk_show_tutorial.setChecked(_to_bool(s.value(self.KEYS["show_tutorial"], True), True)) + self.spin_kernel_pre.setValue(float(s.value(self.KEYS["kernel_pre"], -1.0) or -1.0)) + self.spin_kernel_post.setValue(float(s.value(self.KEYS["kernel_post"], 3.0) or 3.0)) + self.chk_autosave.setChecked(_to_bool(s.value(self.KEYS["autosave"], True), True)) + self.spin_autosave.setValue(int(s.value(self.KEYS["autosave_min"], 5) or 5)) + + def _apply(self) -> None: + s = self._settings + s.setValue(self.KEYS["theme"], self.combo_theme.currentData()) + s.setValue(self.KEYS["toast_timeout"], int(self.spin_toast.value())) + s.setValue(self.KEYS["show_tutorial"], bool(self.chk_show_tutorial.isChecked())) + s.setValue(self.KEYS["kernel_pre"], float(self.spin_kernel_pre.value())) + s.setValue(self.KEYS["kernel_post"], float(self.spin_kernel_post.value())) + s.setValue(self.KEYS["autosave"], bool(self.chk_autosave.isChecked())) + s.setValue(self.KEYS["autosave_min"], int(self.spin_autosave.value())) + + def _on_accept(self) -> None: + self._apply() + self.accept() + + +def _to_bool(value: object, default: bool = False) -> bool: + if isinstance(value, bool): + return value + text = str(value or "").strip().lower() + if text in {"1", "true", "yes", "on"}: + return True + if text in {"0", "false", "no", "off"}: + return False + return default + + +# ============================================================================ +# Keyboard shortcuts +# ============================================================================ + + +_GLOBAL_SHORTCUTS: List[Tuple[str, str, str]] = [ + # (sequence, callback_attr_or_method, description) + ("F1", "_show_tutorial_again", "Show / replay the tutorial"), + ("Ctrl+/", "_show_keyboard_cheatsheet", "Open the keyboard cheat sheet"), + ("Ctrl+,", "_open_preferences", "Open Preferences"), + ("Ctrl+1", "_focus_pre_tab", "Switch to Preprocessing tab"), + ("Ctrl+2", "_focus_post_tab", "Switch to Postprocessing tab"), + ("Ctrl+Tab", "_cycle_main_tab", "Cycle main tabs"), + ("Ctrl+0", "_reset_focused_plot_view", "Reset focused plot view"), + ("Ctrl+Right", "_step_active_file_next", "Next loaded file"), + ("Ctrl+Left", "_step_active_file_prev", "Previous loaded file"), + ("Ctrl+G", "_toggle_individual_group", "Toggle Individual / Group"), + ("Ctrl+Shift+F", "_fit_temporal_model", "Fit temporal model"), + ("Ctrl+Shift+B", "_fit_temporal_all_files", "Fit GLM on every file (batch)"), + ("F5", "_recompute_psth", "Recompute PSTH"), + ("Ctrl+Shift+E", "_run_postprocess_export", "Run Postprocessing 
export"), + ("Esc", "_cancel_current_operation", "Cancel current operation"), +] + + +def register_global_shortcuts(window: QtWidgets.QMainWindow) -> Dict[str, str]: + """ + Bind every shortcut in `_GLOBAL_SHORTCUTS` to `window` with application-wide + context. Resolves the callback by attribute name; if the host doesn't define + the attribute, the shortcut becomes a no-op (so partial wiring is safe). + + Returns a {sequence: description} dict so the cheat sheet can stay accurate. + """ + descriptions: Dict[str, str] = {} + for seq, attr, desc in _GLOBAL_SHORTCUTS: + descriptions[seq] = desc + sc = QtGui.QShortcut(QtGui.QKeySequence(seq), window) + sc.setContext(QtCore.Qt.ShortcutContext.ApplicationShortcut) + + def _make_handler(attr_name: str) -> Callable[[], None]: + def _handler() -> None: + # Don't fire when typing in a line edit / spin box. + w = QtWidgets.QApplication.focusWidget() + if isinstance(w, (QtWidgets.QLineEdit, QtWidgets.QPlainTextEdit, QtWidgets.QTextEdit)): + return + # QSpinBox / QDoubleSpinBox: allow Esc, Ctrl+Right, etc. to pass. + fn = getattr(window, attr_name, None) + if not callable(fn): + return + try: + fn() + except Exception: + pass + + return _handler + + sc.activated.connect(_make_handler(attr)) + # Keep a reference so it isn't GC'd. + if not hasattr(window, "_pyber_global_shortcuts"): + window._pyber_global_shortcuts = [] + window._pyber_global_shortcuts.append(sc) + return descriptions + + +def _keyboard_cheatsheet_html() -> str: + rows = [ + ("Application", []), + (None, [ + ("F1", "Help / replay tutorial"), + ("Ctrl+/", "Open this cheat sheet"), + ("Ctrl+,", "Preferences"), + ("Ctrl+1 / Ctrl+2", "Switch to Preprocessing / Postprocessing"), + ("Ctrl+Tab", "Cycle main tabs"), + ("Esc", "Cancel current operation"), + ]), + ("Files / project", []), + (None, [ + ("Ctrl+O", "Open files (Preprocessing)"), + ("Ctrl+Shift+O", "Open folder (Preprocessing)"), + ("Ctrl+S", "Save project / config"), + ("Ctrl+Shift+S", "Save as"), + ("Ctrl+L", "Load preprocessing config"), + ("Delete", "Remove selected files"), + ("Ctrl+Right / Ctrl+Left", "Next / previous loaded file"), + ("Ctrl+G", "Toggle Individual / Group"), + ]), + ("Preprocessing", []), + (None, [ + ("Ctrl+Q", "Run QC on active file"), + ("Ctrl+Shift+Q", "Batch QC"), + ("Ctrl+E", "Export current selection"), + ("Ctrl+K", "Toggle Artifacts panel"), + ("Ctrl+F", "Toggle Filtering panel"), + ("Ctrl+B", "Toggle Baseline panel"), + ("Ctrl+M", "Toggle Output panel"), + ("Ctrl+D", "Toggle Data panel"), + ("Ctrl+P", "Toggle parameter popups"), + ("Ctrl+Enter", "Trigger preview"), + ("A / C / S", "Assign pending box -> Artifact / Cut / Section"), + ]), + ("Postprocessing / Modeling", []), + (None, [ + ("F5", "Recompute PSTH"), + ("Ctrl+Shift+E", "Run postprocessing export"), + ("Ctrl+Shift+F", "Fit temporal model (current scope)"), + ("Ctrl+Shift+B", "Fit GLM on every file (batch)"), + ("Ctrl+0", "Reset focused plot view"), + ]), + ] + parts = [""] + for header, items in rows: + if header is not None: + parts.append(f"
<h4>{header}</h4>")
+        if items:
+            parts.append("<table>")
+            for keys, desc in items:
+                key_html = " ".join(f"<code>{k.strip()}</code>" for k in keys.split("/"))
+                parts.append(f"<tr><td>{key_html}</td><td>{desc}</td></tr>")
+            parts.append("</table>
") + return "".join(parts) + + +# ============================================================================ +# Window helpers +# ============================================================================ + + +def attach_dirty_title( + window: QtWidgets.QMainWindow, + base_title: str, + is_dirty_callback: Callable[[], bool], +) -> Callable[[], None]: + """ + Returns a `refresh()` function. Call it whenever the dirty state may have + changed and the title bar will gain/lose its trailing '*'. + """ + + def refresh() -> None: + try: + dirty = bool(is_dirty_callback()) + except Exception: + dirty = False + suffix = " *" if dirty else "" + window.setWindowTitle(f"{base_title}{suffix}") + + refresh() + return refresh + + +def install_close_confirmation( + window: QtWidgets.QMainWindow, + is_dirty_callback: Callable[[], bool], + save_callback: Optional[Callable[[], bool]] = None, +) -> None: + """ + Wraps the window's closeEvent to prompt when there is unsaved work. + `save_callback` (if provided) should perform the save and return True + on success. + """ + original = window.closeEvent + + def closeEvent(event: QtGui.QCloseEvent) -> None: + try: + dirty = bool(is_dirty_callback()) + except Exception: + dirty = False + if not dirty: + original(event) + return + choice = QtWidgets.QMessageBox.question( + window, + "Unsaved changes", + "You have unsaved postprocessing changes. Save before exiting?", + QtWidgets.QMessageBox.StandardButton.Save + | QtWidgets.QMessageBox.StandardButton.Discard + | QtWidgets.QMessageBox.StandardButton.Cancel, + QtWidgets.QMessageBox.StandardButton.Save, + ) + if choice == QtWidgets.QMessageBox.StandardButton.Cancel: + event.ignore() + return + if choice == QtWidgets.QMessageBox.StandardButton.Save and save_callback is not None: + try: + if not save_callback(): + event.ignore() + return + except Exception: + event.ignore() + return + original(event) + + window.closeEvent = closeEvent # type: ignore[assignment] + + +# ============================================================================ +# Plot helpers +# ============================================================================ + + +def reset_focused_plot_view(window: QtWidgets.QWidget) -> None: + """ + Walk up from the focus widget looking for a pyqtgraph PlotWidget; + if found, autorange. + """ + try: + import pyqtgraph as pg + except Exception: + return + candidate = QtWidgets.QApplication.focusWidget() or window + while candidate is not None: + if isinstance(candidate, pg.PlotWidget): + try: + candidate.getPlotItem().enableAutoRange() + except Exception: + pass + return + candidate = candidate.parent() if hasattr(candidate, "parent") else None + # Fallback: autorange all visible PlotWidgets on the window. + for pw in window.findChildren(pg.PlotWidget): + try: + pw.getPlotItem().enableAutoRange() + except Exception: + pass + + +def add_empty_state_hint( + plot, # type: ignore[no-untyped-def] pg.PlotWidget + text: str, + color: str = "#7d8aa1", +) -> Optional[Any]: + """ + Add a non-interactive TextItem at (0,0) on the plot with the given hint + text. Returns the item so callers can hide() / setVisible(False) once + real data is plotted. Safe no-op if pyqtgraph isn't importable. 
+ """ + try: + import pyqtgraph as pg + except Exception: + return None + try: + item = pg.TextItem(str(text), color=color, anchor=(0.5, 0.5)) + item.setZValue(100) + plot.addItem(item) + item.setPos(0, 0) + return item + except Exception: + return None diff --git a/pyBer/temporal_modeling.py b/pyBer/temporal_modeling.py index 30df44d..a78bec3 100644 --- a/pyBer/temporal_modeling.py +++ b/pyBer/temporal_modeling.py @@ -1334,6 +1334,7 @@ def _build_compact_ui(self): self._build_files_page() self._build_fit_page() self._build_workspace_pages() + self._install_empty_state_hints() self.btn_nav_model.setChecked(True) self.btn_nav_model.clicked.connect(lambda: self._select_control_page(0)) @@ -1714,6 +1715,39 @@ def _select_control_page(self, index: int) -> None: for i, btn in enumerate(buttons): btn.setChecked(i == index) + def _install_empty_state_hints(self) -> None: + """Attach hint TextItems to the workspace plots; cleared on first fit.""" + self._empty_state_items = [] + plots_and_text = [ + (getattr(self, "plot_kernel", None), "No GLM fit yet.\nPick predictors and press Fit (Ctrl+Shift+F)."), + (getattr(self, "plot_prediction", None), "Predicted vs actual will appear after fitting."), + (getattr(self, "plot_residuals", None), "Residuals appear after a successful fit."), + (getattr(self, "plot_importance", None), "Leave-one-out feature contribution.\nFit a model first."), + (getattr(self, "plot_illustration", None), "Signal + selected feature contribution.\nFit, then pick a feature."), + (getattr(self, "plot_coeff", None), "FLMM coefficient curves appear after a trial-level fit."), + (getattr(self, "plot_group_kernels", None), "Group view\n\nRun a Per-file batch fit to populate this tab."), + (getattr(self, "plot_group_importance", None), "Group leave-one-out contribution\n\nRun a Per-file batch fit to populate this tab."), + ] + for plot, text in plots_and_text: + if plot is None: + continue + try: + item = pg.TextItem(text, color="#6f7d95", anchor=(0.5, 0.5)) + item.setZValue(100) + plot.addItem(item) + item.setPos(0, 0) + self._empty_state_items.append(item) + except Exception: + pass + + def _clear_empty_state_hints(self) -> None: + for item in getattr(self, "_empty_state_items", []) or []: + try: + item.setVisible(False) + except Exception: + pass + self._empty_state_items = [] + def _style_plot(self, plot: pg.PlotWidget) -> None: plot.setMinimumHeight(360) plot.setBackground("#05080d") @@ -3438,6 +3472,23 @@ def _on_active_file_changed(self, *_): self.list_files.setCurrentRow(i) self.list_files.blockSignals(False) break + # Push selection back into the host postprocessing panel so the + # PSTH file picker and the Temporal scope file always agree. + try: + host = self.parent() + while host is not None and not hasattr(host, "combo_individual_file"): + host = host.parent() if hasattr(host, "parent") else None + combo = getattr(host, "combo_individual_file", None) if host is not None else None + if combo is not None: + idx = combo.findText(data) + if idx >= 0 and combo.currentIndex() != idx: + combo.blockSignals(True) + combo.setCurrentIndex(idx) + combo.blockSignals(False) + if hasattr(host, "_rerender_visual_from_cache"): + host._rerender_visual_from_cache() + except Exception: + pass # If we are in Active scope and a fit is cached for this file, render it. if self._fit_mode == "active": cached = self._glm_results_by_file.get(self._active_file_id)
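
A condensed wiring sketch for the window helpers above. `main_window`,
`_is_dirty`, and `_save_project` are hypothetical host-side names, and the
import path is whichever module this patch places the helpers in:

    descriptions = register_global_shortcuts(main_window)  # {sequence: description}
    refresh_title = attach_dirty_title(main_window, "pyBer", main_window._is_dirty)
    install_close_confirmation(main_window, main_window._is_dirty,
                               save_callback=main_window._save_project)
    refresh_title()  # call again whenever the dirty flag may have changed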