diff --git a/README.md b/README.md
index 92a0a7f..725f4f2 100644
--- a/README.md
+++ b/README.md
@@ -1,202 +1,92 @@
-# Fiber Photometry Processing GUI
-
-
-
-
-A desktop GUI for visualizing photometry recordings, cleaning artifacts, filtering/resampling, baseline estimation, motion-correction, and exporting processed traces for downstream analysis.
-
-This project is designed for efficient exploratory QC (preview in the GUI) while keeping processing logic deterministic and scriptable (core functions live in `analysis_core.py`).
-
----
-
-## Key Features
-
-### Data IO
-- Supports raw data in .doric, .h5 or .csv.
-- Multi-channel support (e.g., analogue, DIO channels …).
-- Optional alignment of analog traces to the DigitalIO timebase when a DIO is selected.
-
-### Artifact Handling
-- Artifact detection on raw 465 using derivative thresholding (`dx`) and MAD:
- - **Global MAD (dx)**: one threshold for the full trace
- - **Adaptive MAD (windowed)**: windowed thresholds for nonstationary noise
-- Optional **padding** around detected artifacts to remove spillover.
-- Manual artifact masking by user-defined time regions.
-- Masked samples are replaced via **linear interpolation** to preserve time alignment.
-
-### Signal Conditioning
-- Low-pass filtering.
-- Decimation/resampling.
-
-### Baseline Estimation
-Baseline is computed after filtering and resampling using **pybaselines**:
-- `asls`, `arpls`, `airpls`
-- tunable parameters (lambda, diff order, iterations, tolerance)
-
-### Output Modes (7)
-The GUI exposes seven explicit output definitions:
-
-1. **dFF (non motion corrected)**
- `dFF = (signal_filtered - signal_baseline) / signal_baseline`
-
-2. **zscore (non motion corrected)**
- `zscore(dFF_nonMC)`
-
-3. **dFF (motion corrected via subtraction)**
- `dFF_mc = dFF_signal - dFF_ref`
- where each dFF uses its own baseline.
-
-4. **zscore (motion corrected via subtraction)**
- `zscore(dFF_signal - dFF_ref)`
-
-5. **zscore (subtractions)**
- `zscore(dFF_signal) - zscore(dFF_ref)`
-
-6. **dFF (motion corrected with fitted ref)**
- Fit the isosbestic/reference channel to the signal:
- `fitted_ref = a * ref_filtered + b`
- then compute:
- `dFF = (signal_filtered - fitted_ref) / fitted_ref`
-
-7. **zscore (motion corrected with fitted ref)**
- `zscore( (signal_filtered - fitted_ref) / fitted_ref )`
-
-### Reference Fitting Methods (for “fitted ref” modes)
-- **OLS (recommended)**: fast and stable
-- **Lasso**: sparse regression (requires `scikit-learn`)
-- **RLM (HuberT)**: robust linear model via IRLS + Huber weighting (no extra dependency)
-
-### Export
-- Export processed output to:
- - CSV with configurable fields (`time` always included; raw/isobestic/output/DIO selectable)
- - HDF5 with configurable raw/output/DIO/baseline datasets plus metadata
-- Export field selection is saved and restored through the preprocessing configuration file.
-- Drag-and-drop support for preprocessing and post-processing files.
-
----
-
-## Repository Structure (typical)
-
-- `analysis_core.py`
- Processing pipeline (loading, filtering, baselines, outputs, export helpers)
-- `main.py` (or similar)
- PySide6 GUI entry point and UI wiring
-- `requirements.yml`
- Conda environment definition
-
----
-
-## Installation
-
-1. Create the environment:
- ```bash
- conda env create -f environment.yml
-## Run
- cd .\pyBer\
- python main.py
-
-## Usage workflow
-
-Open a Doric .h5 file
-
-Choose a channel (e.g., AIN01).
-
-Optionally select a DigitalIO line to overlay events.
-
-### QC & artifact removal
-
-Choose Global MAD (dx) or Adaptive MAD (windowed).
-
-Tune mad_k, window size, and padding.
-
-Add manual mask regions if needed.
-
-### Filtering & resampling
-
-Set low-pass cutoff (Hz) and filter order.
-
-Set a target sampling rate (Hz) for consistent downstream analysis.
-
-### Baseline estimation
-
-Choose asls, arpls, or airpls.
-
-Tune lambda and other parameters to avoid baseline leakage into fast transients.
-
-### Select output
-
-Pick one of the 7 output modes.
-
-For “fitted ref” modes, choose the fit method (OLS/Lasso/RLM-HuberT).
-
-
-### Export
-
-Export CSV/H5 for analysis in Python/MATLAB/R.
-
----
-
-## Preprocessing: Advanced Options
-
-The preprocessing panel includes an **Advanced options** button with two features:
-
-1) **Cut out regions (NaN)**
-Define start/end ranges to exclude parts of the trace from downstream analysis. Cutout regions are filled with NaN in the output and can be exported as-is.
-
-2) **Sections (per-section processing)**
-Define multiple start/end sections and **assign per-section processing parameters**. Each section can be exported independently (one CSV/H5 per section).
-
-### Time window
-You can optionally set a start time, end time, or both:
-- Start only: process from start → end of recording
-- End only: process from 0 → end
-- Start + end: process that window only
-- Empty: process full trace
-
----
-
-## Post-Processing
-
-### Align sources
-- **DIO**: choose DIO channel, polarity (0→1 or 1→0), and align to onset/offset
-- **Behavior (CSV/XLSX)**: load a behavior file and select a column (binary 0/1)
- - Align to onset or offset
- - **Transitions**: align to A→B transitions with a max gap threshold
-
-### Behavior file formats
-- **CSV**: must include a time column and one or more binary behavior columns
-- **Ethovision XLSX**: the loader preprocess and clean the file, the user has to select the sheet when loading
-
-### PSTH/Heatmap
-- Heatmap and PSTH refresh automatically when alignment settings change
-- Event lines are overlaid on the trace preview
-- Event duration histogram appears to the right of the heatmap
-- Metrics bar plot (pre vs post) appears to the right of the PSTH
-
-### Metrics
-Choose AUC or mean z-score and define pre/post windows in seconds.
-
-### Export results
-Export any combination of:
-- Heatmap matrix
-- Average PSTH + SEM
-- Event times
-- Event durations
-- Metrics table
-
----
-
-## Group Mode (Multiple Animals)
-
-Use the **Group** tab in Post-Processing to load multiple processed files (CSV/H5). Each file should represent a single animal, and matching behavior files should share the same base name. In group mode:
-- Heatmap rows represent animals (not trials)
-- PSTH is averaged across animals
-
----
-
-
-
-
-
-
-
+# pyBer
+
+pyBer is a desktop application for fiber photometry analysis. It helps you load
+Doric, HDF5, or CSV recordings, clean artifacts, preprocess traces, align signals
+to behavior or DIO events, inspect PSTHs and heatmaps, detect transients, and
+export results for Python, MATLAB, R, or Prism.
+
+The app is built for users who want an interactive workflow first, with
+deterministic processing code underneath.
+
+
+
+## Quick Install
+
+Install Miniforge or Anaconda first, then run:
+
+```powershell
+cd C:\Analysis\app_project\pyBer
+conda env create -f environment.yml
+conda activate pyBer
+Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')"
+python .\pyBer\main.py
+```
+
+The `fastFMM` step is only needed for the FLMM temporal modeling panel. The rest
+of pyBer works without it.
+
+## Launch From VS Code
+
+1. Open the repository folder in VS Code.
+2. Select the interpreter from the `pyBer` conda environment.
+3. Open `pyBer/main.py`.
+4. Press Run, or use:
+
+```powershell
+conda activate pyBer
+python .\pyBer\main.py
+```
+
+If VS Code launches the wrong Python, run `Python: Select Interpreter` and choose
+the environment created from `environment.yml`.
+
+## What You Can Do
+
+- Preprocess raw photometry traces with filtering, resampling, baseline
+ correction, motion correction, and artifact handling.
+- Detect and inspect artifacts with interpolation, cutout, local low-pass
+ filtering, or no-op handling.
+- Export processed CSV or HDF5 files with selectable fields and metadata.
+- Align processed signals to DIO, behavior states, behavior onsets, or behavior
+ transitions.
+- Detect signal events and compare transient amplitude with baseline-prominence
+ normalized metrics.
+- Build individual or group PSTHs, heatmaps, event duration plots, and metrics.
+- Fit temporal models with continuous GLM or trial-level FLMM.
+- Rank GLM/FLMM feature contributions with leave-one-feature-out summaries.
+
+## Documentation
+
+The full user guide is here:
+
+- [pyBer Documentation](docs/index.md)
+
+It includes installation, first launch, preprocessing, postprocessing, transient
+detection, temporal modeling, group workflows, export, and troubleshooting.
+
+## Repository Layout
+
+- `pyBer/main.py`: application entry point.
+- `pyBer/analysis_core.py`: preprocessing and signal processing backend.
+- `pyBer/gui_preprocessing.py`: preprocessing panels.
+- `pyBer/gui_postprocessing.py`: postprocessing, PSTH, metrics, and export panels.
+- `pyBer/temporal_modeling.py`: GLM and FLMM modeling panel.
+- `environment.yml`: conda environment for development and user installs.
+- `pyBer.spec`: PyInstaller build configuration.
+
+## Build The Executable
+
+From an activated environment:
+
+```powershell
+conda activate pyBer
+python -m PyInstaller --noconfirm --clean pyBer.spec
+```
+
+The executable is written to `dist/pyBer.exe`.
+
+## Notes
+
+pyBer sets `PYTHONNOUSERSITE=1` in the environment so stale packages in the user
+site-packages directory do not shadow the conda environment. This matters for
+Qt, pyqtgraph, numpy, and rpy2 stability on Windows.
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..3027b22
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,285 @@
+# pyBer Documentation
+
+This page is a practical guide for installing and using pyBer. It is written for
+lab users who want to process recordings without reading the code first.
+
+## 1. Install pyBer
+
+### Recommended Windows install
+
+Install Miniforge or Anaconda, open an Anaconda/Miniforge Prompt, then run:
+
+```powershell
+cd C:\Analysis\app_project\pyBer
+conda env create -f environment.yml
+conda activate pyBer
+Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')"
+python .\pyBer\main.py
+```
+
+The `fastFMM` command installs the R package used by the FLMM temporal modeling
+panel. It can take a few minutes the first time because R downloads dependencies.
+
+### Update an existing environment
+
+If the environment already exists:
+
+```powershell
+conda activate pyBer
+conda env update -f environment.yml --prune
+Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')"
+```
+
+### Test the install
+
+Run:
+
+```powershell
+conda activate pyBer
+python .\pyBer\main.py
+```
+
+The app should open with Preprocessing and Postprocessing tabs.
+
+## 2. Launch From VS Code
+
+1. Open the pyBer repository folder.
+2. Press `Ctrl+Shift+P`.
+3. Choose `Python: Select Interpreter`.
+4. Select the `pyBer` conda environment.
+5. Open `pyBer/main.py`.
+6. Press Run.
+
+If the Run button fails, use the terminal:
+
+```powershell
+conda activate pyBer
+python .\pyBer\main.py
+```
+
+## 3. Preprocessing Workflow
+
+Use Preprocessing when you want to clean and export photometry traces.
+
+1. Load one or more raw files.
+2. Select the signal channel and optional reference channel.
+3. Choose artifact detection and handling.
+4. Set filtering and resampling.
+5. Choose baseline correction.
+6. Choose the output signal definition.
+7. Preview the result.
+8. Export CSV or HDF5.
+
+### Output definitions
+
+pyBer exposes explicit output modes so exported traces are reproducible:
+
+- dFF without motion correction.
+- z-score without motion correction.
+- dFF with motion correction by subtraction.
+- z-score with motion correction by subtraction.
+- z-score signal minus z-score reference.
+- dFF with fitted reference.
+- z-score with fitted reference.
+
+For fitted-reference modes, pyBer fits the reference channel to the signal before
+computing dFF. The usual choice is OLS. Lasso and robust Huber fitting are also
+available.
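+
+For intuition, the fitted-reference dFF can be sketched in a few lines of NumPy.
+This is a minimal illustration of the OLS case only, not pyBer's exact
+implementation:
+
+```python
+import numpy as np
+
+def dff_fitted_ref(signal, ref):
+    """OLS-fit the reference to the signal, then compute dFF against the fit."""
+    a, b = np.polyfit(ref, signal, deg=1)  # fitted_ref = a * ref + b
+    fitted_ref = a * ref + b
+    return (signal - fitted_ref) / fitted_ref
+```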
+
+## 4. Artifact Handling
+
+Artifact settings let you choose how masked windows are handled:
+
+- Interpolation: replace artifact samples by linear interpolation.
+- Cut: keep artifact samples as NaN so downstream analysis ignores them.
+- Strong local low-pass filtering: smooth only inside the artifact window.
+- Do nothing: detect or mark artifacts without changing the trace.
+
+Use interpolation when you need continuous traces. Use cut when the artifact
+window should not contribute to statistics.
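+
+Interpolation keeps the timebase intact by bridging masked samples linearly. A
+minimal sketch of the idea (pyBer's internal helper additionally handles edge
+cases at the start and end of the trace):
+
+```python
+import numpy as np
+
+def interpolate_nans(x):
+    """Replace NaN samples by linear interpolation over sample index."""
+    x = np.asarray(x, float).copy()
+    nans = np.isnan(x)
+    if nans.any() and not nans.all():
+        idx = np.arange(x.size)
+        x[nans] = np.interp(idx[nans], idx[~nans], x[~nans])
+    return x
+```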
+
+## 5. Postprocessing Workflow
+
+Use Postprocessing when you want to align processed traces to events or behavior.
+
+1. Load processed files from preprocessing.
+2. Load behavior files if needed.
+3. Choose the alignment source.
+4. Click `Compute PSTH`.
+5. Inspect the trace preview, heatmap, average PSTH, duration plot, and metrics.
+6. Export matrices, event times, metrics, and figures.
+
+### Alignment sources
+
+pyBer can align to:
+
+- DIO onset or offset.
+- Behavior onset or offset from CSV or XLSX files.
+- Binary behavior state columns.
+- Behavior transitions.
+- Signal events detected from the processed trace.
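+
+Whatever the source, alignment reduces to a list of event times plus a fixed
+window of samples around each one. A hedged sketch of that extraction (function
+and argument names are illustrative, not pyBer's internal API):
+
+```python
+import numpy as np
+
+def peri_event_matrix(t, y, event_times, pre_s=2.0, post_s=5.0):
+    """Stack one row per event covering [event - pre_s, event + post_s]."""
+    fs = 1.0 / np.median(np.diff(t))
+    n_pre, n_post = int(round(pre_s * fs)), int(round(post_s * fs))
+    rows = []
+    for ev in event_times:
+        i = int(np.searchsorted(t, ev))
+        if i - n_pre >= 0 and i + n_post <= y.size:
+            rows.append(y[i - n_pre : i + n_post])
+    return np.vstack(rows) if rows else np.empty((0, n_pre + n_post))
+```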
+
+### Group mode
+
+Use Group mode when each processed file represents one animal. pyBer keeps
+per-file trial matrices for temporal modeling and can also display animal-level
+group summaries.
+
+For best GLM and FLMM results, load matching behavior files whose base names
+match the processed files.
+
+## 6. Signal Event Analyzer
+
+The Signal Event Analyzer detects transients and reports metrics. Useful options:
+
+- Auto MAD noise thresholding for transient detection.
+- Min prominence, min height, min distance, and smoothing.
+- Optional detected-peak overlay on the trace.
+- Optional noise trace overlay.
+- Baseline-prominence normalized amplitude for comparing recordings.
+
+Baseline-prominence normalized amplitude is useful when recordings differ in
+baseline level or noise scale. It normalizes each detected transient relative to
+its local baseline/prominence context.
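+
+The MAD-based detection idea can be sketched as follows; the multiplier `k` and
+the use of `scipy.signal.find_peaks` are illustrative assumptions, not the
+exact pyBer settings:
+
+```python
+import numpy as np
+from scipy.signal import find_peaks
+
+def mad_transients(y, k=3.0):
+    """Flag peaks whose prominence exceeds k times a MAD-based noise scale."""
+    noise_sd = 1.4826 * np.median(np.abs(y - np.median(y)))  # MAD -> SD
+    peaks, props = find_peaks(y, prominence=k * noise_sd)
+    return peaks, props["prominences"]
+```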
+
+## 7. Temporal Modeling
+
+The Temporal Modeling panel supports two approaches.
+
+### Continuous GLM
+
+Use GLM when you want to model the continuous photometry trace from event and
+behavior predictors.
+
+Typical predictors:
+
+- DIO events.
+- Behavior onsets.
+- Behavior states.
+- Numeric behavior columns.
+- Signal event times.
+
+The GLM output includes:
+
+- R-squared.
+- RMSE, MAE, MSE, residual SD, and actual/predicted correlation.
+- Estimated kernels for each predictor.
+- Actual vs predicted signal.
+- Residual trace.
+- Leave-one-predictor-out feature contribution.
+
+The leave-one-predictor-out ranking refits the model after removing each
+predictor. Larger `delta R^2` means that predictor explains more of the signal.
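+
+As a simplified stand-in for the actual GLM, the ranking can be sketched with
+ordinary least squares:
+
+```python
+import numpy as np
+
+def loo_delta_r2(X, y):
+    """Per-predictor delta R^2: full-model R^2 minus R^2 without that column."""
+    def r2(design):
+        beta, *_ = np.linalg.lstsq(design, y, rcond=None)
+        return 1.0 - (y - design @ beta).var() / y.var()
+    full = r2(X)
+    return {j: full - r2(np.delete(X, j, axis=1)) for j in range(X.shape[1])}
+```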
+
+### Trial-level FLMM
+
+Use FLMM when you want trial-level functional modeling with random effects.
+
+Requirements:
+
+- R installed through the conda environment or available on the system.
+- Python package `rpy2`.
+- R package `fastFMM`.
+- Repeated rows per subject or animal.
+
+The environment installs R and rpy2. Install fastFMM with:
+
+```powershell
+conda activate pyBer
+Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')"
+```
+
+The FLMM output includes:
+
+- Fixed-effect coefficient curves.
+- Pointwise and joint confidence bands when available.
+- AIC summary.
+- Coefficient magnitude statistics.
+- Leave-one-feature-out AIC contribution when the reduced models are estimable.
+
+If a reduced FLMM cannot be estimated, pyBer still reports the coefficient-based
+contribution so the feature ranking remains usable.
+
+## 8. Export
+
+Preprocessing can export processed traces as:
+
+- CSV.
+- HDF5.
+
+Postprocessing can export:
+
+- Heatmap matrix.
+- Average PSTH and SEM.
+- Event times.
+- Event durations.
+- Metrics tables.
+- Group-level outputs.
+
+Use HDF5 when you want metadata and multiple arrays in one file. Use CSV when you
+want easy loading into spreadsheets or Prism.
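+
+Based on the dataset layout written by `export_processed_h5` (a `data` group
+with `time` and `output` datasets, plus an `outputs` subgroup when several
+modes are exported), an HDF5 export can be read back like this. This is a
+sketch; extra datasets such as raw traces and DIO are omitted:
+
+```python
+import h5py
+
+def load_pyber_h5(path):
+    """Read time, primary output, and any extra output modes from an export."""
+    with h5py.File(path, "r") as f:
+        g = f["data"]
+        result = {"time": g["time"][...], "output": g["output"][...]}
+        if "outputs" in g:  # present when multiple output modes were exported
+            for name, ds in g["outputs"].items():
+                result[str(ds.attrs.get("label", name))] = ds[...]
+    return result
+```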
+
+## 9. Troubleshooting
+
+### The app does not launch from VS Code
+
+Make sure VS Code is using the conda environment:
+
+```powershell
+conda activate pyBer
+python .\pyBer\main.py
+```
+
+If this works but the Run button fails, select the interpreter again in VS Code.
+
+### Dark mode or Qt styling looks broken
+
+This is usually a mixed Python environment. Recreate the environment and keep
+`PYTHONNOUSERSITE=1` enabled:
+
+```powershell
+conda env remove -n pyBer
+conda env create -f environment.yml
+conda activate pyBer
+```
+
+### FLMM says fastFMM is unavailable
+
+Run:
+
+```powershell
+conda activate pyBer
+Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')"
+```
+
+Then restart pyBer.
+
+### FLMM says random effects cannot be estimated
+
+FLMM needs repeated rows per subject. In practice, each animal should have
+multiple trials. If you provide only a single averaged row per animal, the GLM
+panel is usually the better choice.
+
+### The heatmap looks wrong after switching Individual and Group
+
+Click `Compute PSTH` again after changing loaded files or behavior alignment.
+pyBer stores both per-file trial matrices and group matrices, but recomputing is
+the clearest way to refresh all derived views after a major setup change.
+
+## 10. Build A Windows Executable
+
+From the repository root:
+
+```powershell
+conda activate pyBer
+python -m PyInstaller --noconfirm --clean pyBer.spec
+```
+
+The app is written to:
+
+```text
+dist\pyBer.exe
+```
+
+When building with FLMM support, make sure `fastFMM` is installed before running
+PyInstaller.
diff --git a/environment.yml b/environment.yml
index bf2d4c2..895b999 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,3 +1,12 @@
+# Create with:
+# conda env create -f environment.yml
+# conda activate pyBer
+#
+# Optional FLMM backend:
+# Rscript -e "install.packages('fastFMM', repos='https://cloud.r-project.org')"
+#
+# The R package is installed after environment creation because fastFMM is a
+# CRAN package, not a conda package.
name: pyBer
channels:
- conda-forge
@@ -26,6 +35,14 @@ dependencies:
- pyqtgraph>=0.13
- matplotlib>=3.8
+ # FLMM backend via R fastFMM/rpy2.
+ - r-base>=4.4
+ - rpy2>=3.6
+
+ # Windows R/rpy2 startup uses sh/make when validating R's config.
+ - m2-base
+ - m2-make
+
# Quality-of-life / packaging
- pip
- setuptools
diff --git a/pyBer.spec b/pyBer.spec
index 4c50bed..adbfc2e 100644
--- a/pyBer.spec
+++ b/pyBer.spec
@@ -39,7 +39,12 @@ a = Analysis(
'freetype.dll',
]),
datas=[('assets/pyBer_logo_big.png', 'assets'), ('assets/pyBer.ico', 'assets')],
- hiddenimports=[],
+ hiddenimports=[
+ 'rpy2.rinterface',
+ 'rpy2.rinterface_lib',
+ 'rpy2.robjects',
+ 'rpy2.robjects.packages',
+ ],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
diff --git a/pyBer/analysis_core.py b/pyBer/analysis_core.py
index 38c883d..edea563 100644
--- a/pyBer/analysis_core.py
+++ b/pyBer/analysis_core.py
@@ -50,6 +50,13 @@
"Moving median",
]
+ARTIFACT_HANDLING_MODES = [
+ "Interpolate",
+ "Cut",
+ "Strong local low-pass",
+ "Do nothing",
+]
+
# Output modes
OUTPUT_MODES = [
# 1) dFF (non motion corrected)
@@ -80,6 +87,7 @@ class ProcessingParams:
# -------------------------
artifact_detection_enabled: bool = True
artifact_mode: str = "Global MAD (dx)" # or "Adaptive MAD (windowed)"
+ artifact_handling: str = "Interpolate"
mad_k: float = 8.0
adaptive_window_s: float = 5.0
artifact_pad_s: float = 0.25
@@ -285,6 +293,7 @@ class ProcessedTrial:
output: Optional[np.ndarray] = None
output_label: str = ""
output_context: str = ""
+ outputs: Dict[str, np.ndarray] = field(default_factory=dict)
artifact_regions_sec: Optional[List[Tuple[float, float]]] = None
artifact_regions_auto_sec: Optional[List[Tuple[float, float]]] = None
@@ -302,8 +311,9 @@ class ExportSelection:
dio: bool = True
baseline_sig: bool = True
baseline_ref: bool = True
+ output_modes: List[str] = field(default_factory=list)
- def to_dict(self) -> Dict[str, bool]:
+ def to_dict(self) -> Dict[str, Any]:
return {
"raw": bool(self.raw),
"isobestic": bool(self.isobestic),
@@ -311,12 +321,16 @@ def to_dict(self) -> Dict[str, bool]:
"dio": bool(self.dio),
"baseline_sig": bool(self.baseline_sig),
"baseline_ref": bool(self.baseline_ref),
+ "output_modes": list(self.output_modes or []),
}
@classmethod
def from_dict(cls, data: Optional[Dict[str, object]]) -> "ExportSelection":
if not isinstance(data, dict):
return cls()
+ modes = data.get("output_modes", [])
+ if not isinstance(modes, list):
+ modes = []
return cls(
raw=bool(data.get("raw", True)),
isobestic=bool(data.get("isobestic", True)),
@@ -324,6 +338,7 @@ def from_dict(cls, data: Optional[Dict[str, object]]) -> "ExportSelection":
dio=bool(data.get("dio", True)),
baseline_sig=bool(data.get("baseline_sig", True)),
baseline_ref=bool(data.get("baseline_ref", True)),
+ output_modes=[str(m).strip() for m in modes if str(m or "").strip()],
)
@@ -367,6 +382,67 @@ def output_label_type(label: str) -> str:
return "output"
+def output_label_key(label: str) -> str:
+ """Return a stable, readable key for one output mode in multi-output exports."""
+ lab = (label or "").strip()
+ key = re.sub(r"[^A-Za-z0-9]+", "_", lab).strip("_").lower()
+ replacements = {
+ "dff_motion_corrected_with_fitted_ref": "dff_mc_fitted_ref",
+ "zscore_motion_corrected_with_fitted_ref": "zscore_mc_fitted_ref",
+ "prominence_normalized_motion_corrected_with_fitted_ref": "prominence_mc_fitted_ref",
+ "dff_motion_corrected_via_subtraction": "dff_mc_subtraction",
+ "zscore_motion_corrected_via_subtraction": "zscore_mc_subtraction",
+ "dff_non_motion_corrected": "dff_non_mc",
+ "zscore_non_motion_corrected": "zscore_non_mc",
+ "zscore_subtractions": "zscore_subtractions",
+ "raw_signal_465": "raw_signal_465",
+ }
+ return replacements.get(key, key or "output")
+
+
+def _unique_export_name(name: str, used: set) -> str:
+ base = str(name or "output").strip() or "output"
+ out = base
+ i = 2
+ while out in used:
+ out = f"{base}_{i}"
+ i += 1
+ used.add(out)
+ return out
+
+
+def _output_items_for_export(
+ processed: ProcessedTrial,
+ selection: ExportSelection,
+) -> List[Tuple[str, np.ndarray]]:
+ """Return output traces in requested export order, falling back to the selected output."""
+ source = getattr(processed, "outputs", None) or {}
+ items: List[Tuple[str, np.ndarray]] = []
+
+ if source:
+ requested = [str(m).strip() for m in (selection.output_modes or []) if str(m or "").strip()]
+ seen = set()
+ for label in requested:
+ if label in source and label not in seen:
+ items.append((label, np.asarray(source[label], float)))
+ seen.add(label)
+ for label, values in source.items():
+ if label not in seen:
+ items.append((str(label), np.asarray(values, float)))
+ seen.add(label)
+
+ if not items and processed.output is not None:
+ label = str(processed.output_label or "output")
+ items.append((label, np.asarray(processed.output, float)))
+
+ return items
+
+
+def _csv_output_values(label: str, values: np.ndarray, t: np.ndarray) -> np.ndarray:
+ values = values if values.size == t.size else np.full_like(t, np.nan)
+ return np.asarray(values, float)
+
+
def export_processed_csv(
path: str,
processed: ProcessedTrial,
@@ -389,10 +465,15 @@ def export_processed_csv(
if selection.dio and processed.dio is not None and processed.dio.size == t.size:
dio = np.asarray(processed.dio, float)
- out_col = output_label_type(processed.output_label)
+ output_items = _output_items_for_export(processed, selection) if selection.output else []
with open(path, "w", newline="") as f:
w = csv.writer(f)
+ w.writerow([f"# output_label: {processed.output_label}"])
+ if processed.output_context:
+ w.writerow([f"# output_context: {processed.output_context}"])
+ if output_items:
+ w.writerow([f"# output_modes: {json.dumps([label for label, _ in output_items])}"])
if metadata:
for k, v in metadata.items():
w.writerow([f"# {k}: {v}"])
@@ -401,8 +482,20 @@ def export_processed_csv(
columns.append(("raw", raw))
if selection.isobestic:
columns.append(("isobestic", iso))
- if selection.output:
- columns.append((out_col, out))
+ if selection.output and output_items:
+ if len(output_items) == 1:
+ label, values = output_items[0]
+ values = _csv_output_values(label, values, t)
+ columns.append((output_label_type(label), values))
+ else:
+ used_names = {name for name, _ in columns}
+ primary_label, primary_values = output_items[0]
+ primary_values = _csv_output_values(primary_label, primary_values, t)
+ columns.append((_unique_export_name("output", used_names), primary_values))
+ for label, values in output_items:
+ values = _csv_output_values(label, values, t)
+ col = _unique_export_name(f"output__{output_label_key(label)}", used_names)
+ columns.append((col, values))
# Primary trigger name (e.g. "DIO01")
primary_name = str(processed.dio_name) if processed.dio_name else "dio"
@@ -435,16 +528,32 @@ def export_processed_h5(
g = f.create_group("data")
g.create_dataset("time", data=np.asarray(processed.time, float), compression="gzip")
out_type = output_label_type(processed.output_label)
+ output_items = _output_items_for_export(processed, selection) if selection.output else []
g.attrs["output_label"] = str(processed.output_label)
g.attrs["output_context"] = str(processed.output_context)
g.attrs["output_type"] = str(out_type)
+ if output_items:
+ g.attrs["output_modes"] = json.dumps([label for label, _ in output_items])
g.attrs["fs_actual"] = float(processed.fs_actual)
g.attrs["fs_used"] = float(processed.fs_used)
g.attrs["fs_target"] = float(processed.fs_target)
g.attrs["export_selection"] = json.dumps(selection.to_dict())
- if selection.output:
- g.create_dataset("output", data=np.asarray(processed.output, float), compression="gzip")
+ if selection.output and output_items:
+ t = np.asarray(processed.time, float)
+ _primary_label, primary_output = output_items[0]
+ primary_output = primary_output if primary_output.size == t.size else np.full_like(t, np.nan)
+ g.create_dataset("output", data=np.asarray(primary_output, float), compression="gzip")
+
+ if len(output_items) > 1:
+ out_group = g.create_group("outputs")
+ used = set()
+ for label, values in output_items:
+ values = values if values.size == t.size else np.full_like(t, np.nan)
+ ds_name = _unique_export_name(output_label_key(label), used)
+ ds = out_group.create_dataset(ds_name, data=np.asarray(values, float), compression="gzip")
+ ds.attrs["label"] = str(label)
+ ds.attrs["output_type"] = str(output_label_type(label))
raw_sig = np.asarray(processed.raw_signal if processed.raw_signal is not None else np.full_like(processed.time, np.nan), float)
raw_ref = np.asarray(processed.raw_reference if processed.raw_reference is not None else np.full_like(processed.time, np.nan), float)
@@ -569,6 +678,82 @@ def _lowpass_sos(x: np.ndarray, fs: float, cutoff: float, order: int) -> np.ndar
return np.asarray(sosfiltfilt(sos, y), float)
+def _normalize_artifact_handling(value: object) -> str:
+ text = str(value or "").strip().lower()
+ if text.startswith("cut"):
+ return "Cut"
+ if "low" in text and "pass" in text:
+ return "Strong local low-pass"
+ if text.startswith("do") or text in {"none", "nothing", "off", "ignore"}:
+ return "Do nothing"
+ return "Interpolate"
+
+
+def _strong_local_lowpass_artifacts(
+ x: np.ndarray,
+ mask: np.ndarray,
+ fs: float,
+ base_cutoff_hz: float,
+ filter_order: int,
+) -> np.ndarray:
+ y = np.asarray(x, float).copy()
+ m = np.asarray(mask, bool)
+ if y.size == 0 or m.size != y.size or not np.any(m):
+ return y
+ bridged = y.copy()
+ bridged[m] = np.nan
+ bridged = interpolate_nans(bridged)
+ if np.any(~np.isfinite(bridged)):
+ return y
+
+ # Use a clearly stronger local cutoff than the main anti-aliasing filter.
+ try:
+ base_cutoff = float(base_cutoff_hz)
+ except Exception:
+ base_cutoff = 12.0
+ if not np.isfinite(base_cutoff) or base_cutoff <= 0:
+ base_cutoff = 12.0
+ cutoff = min(2.0, max(0.05, 0.25 * base_cutoff))
+ order = int(max(3, min(6, int(filter_order) + 1)))
+ try:
+ replacement = _lowpass_sos(bridged, fs, cutoff, order)
+ except Exception:
+ replacement = bridged
+ y[m] = replacement[m]
+ return y
+
+
+def _apply_artifact_handling(
+ sig: np.ndarray,
+ ref: np.ndarray,
+ mask: np.ndarray,
+ fs: float,
+ params: ProcessingParams,
+) -> Tuple[np.ndarray, np.ndarray, str]:
+ handling = _normalize_artifact_handling(getattr(params, "artifact_handling", "Interpolate"))
+ sig_corr = np.asarray(sig, float).copy()
+ ref_corr = np.asarray(ref, float).copy()
+ m = np.asarray(mask, bool)
+ if sig_corr.size == 0 or ref_corr.size == 0 or m.size != sig_corr.size or not np.any(m):
+ return sig_corr, ref_corr, handling
+
+ if handling == "Do nothing" or handling == "Cut":
+ return sig_corr, ref_corr, handling
+
+ if handling == "Strong local low-pass":
+ cutoff = float(getattr(params, "lowpass_hz", 12.0))
+ order = int(getattr(params, "filter_order", 3))
+ return (
+ _strong_local_lowpass_artifacts(sig_corr, m, fs, cutoff, order),
+ _strong_local_lowpass_artifacts(ref_corr, m, fs, cutoff, order),
+ handling,
+ )
+
+ sig_corr[m] = np.nan
+ ref_corr[m] = np.nan
+ return interpolate_nans(sig_corr), interpolate_nans(ref_corr), handling
+
+
def _window_samples_from_seconds(
fs: float,
window_s: float,
@@ -1379,14 +1564,11 @@ def process_trial(
mask = apply_manual_regions(t, mask, manual_regions_sec or [])
# ---------------------------------------------------------------------
- # 4) Mask artifacts (set NaN) then interpolate (keeps timebase intact)
+ # 4) Apply selected artifact handling.
+ # Interpolate is the historical default and keeps the timebase intact.
+ # Cut is applied after resampling so all processed arrays stay aligned.
# ---------------------------------------------------------------------
- sig_corr = sig.copy()
- ref_corr = ref.copy()
- sig_corr[mask] = np.nan
- ref_corr[mask] = np.nan
- sig_corr = interpolate_nans(sig_corr)
- ref_corr = interpolate_nans(ref_corr)
+ sig_corr, ref_corr, artifact_handling = _apply_artifact_handling(sig, ref, mask, fs, params)
# ---------------------------------------------------------------------
# 5) Low-pass filter before decimation (anti-aliasing)
@@ -1409,6 +1591,16 @@ def process_trial(
# Resample the envelope for display (same timebase as processed)
_, hi2, lo2, _ = _resample_pair_to_target_fs(t, hi_raw, lo_raw, fs, target_fs)
+ if artifact_handling == "Cut" and np.any(mask):
+ mask2 = np.interp(t2, t, mask.astype(float), left=0.0, right=0.0) > 0.5
+ keep = ~mask2
+ if np.sum(keep) >= 3:
+ t2 = t2[keep]
+ sig2 = sig2[keep]
+ ref2 = ref2[keep]
+ hi2 = hi2[keep]
+ lo2 = lo2[keep]
+
# ---------------------------------------------------------------------
# 7) A/D overlay (if present): interpolate and binarize
# ---------------------------------------------------------------------
@@ -1511,8 +1703,10 @@ def process_trial(
out = dff_sig
context_parts = []
+ if np.any(mask):
+ context_parts.append(f"Artifacts: {artifact_handling}")
if mode == "Raw signal (465)":
- context_parts.append("Raw 465 after artifact interpolation, filtering, and resampling")
+ context_parts.append("Raw 465 after artifact handling, filtering, and resampling")
else:
baseline_desc = f"Baseline: {params.baseline_method} (lambda={float(params.baseline_lambda):.2e})"
if mode in (
@@ -1567,6 +1761,7 @@ def process_trial(
output=out,
output_label=mode,
output_context=output_context,
+ outputs={mode: np.asarray(out, float)} if out is not None else {},
artifact_regions_sec=regions_from_mask(t, mask),
artifact_regions_auto_sec=auto_regions,
fs_actual=float(fs),
diff --git a/pyBer/ethovision_process_gui.py b/pyBer/ethovision_process_gui.py
index af66169..8a139db 100644
--- a/pyBer/ethovision_process_gui.py
+++ b/pyBer/ethovision_process_gui.py
@@ -32,6 +32,8 @@
import pyqtgraph as pg
+from analysis_core import coerce_time_value
+
MISSING_MARKERS = {"", " ", "-", "NaN", "nan", "NAN", "n/a", "N/A", None}
@@ -208,7 +210,6 @@ def clean_sheet(
time_orig = df[time_col].copy()
df[time_col] = pd.to_numeric(df[time_col], errors="coerce")
if df[time_col].isna().all():
- from pyBer.analysis_core import coerce_time_value
df[time_col] = time_orig.astype(str).apply(coerce_time_value)
df = df.loc[df[time_col].notna()].copy()
df = df.sort_values(time_col).reset_index(drop=True)
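The `clean_sheet` change above keeps the two-stage time-column coercion: try `pd.to_numeric(..., errors="coerce")` first, and only if every value fails fall back to per-value parsing with `coerce_time_value` (now imported at module top rather than inside the function). A sketch of that fallback pattern; `_coerce_time_value` here is a hypothetical stand-in for the real `analysis_core.coerce_time_value`:

```python
import pandas as pd

def _coerce_time_value(text):
    """Hypothetical stand-in for analysis_core.coerce_time_value:
    parse 'h:mm:ss.fff' / 'mm:ss.fff' clock strings into seconds."""
    parts = str(text).strip().split(":")
    try:
        nums = [float(p) for p in parts]
    except ValueError:
        return None
    secs = 0.0
    for n in nums:          # fold h/m/s fields left-to-right, base 60
        secs = secs * 60.0 + n
    return secs

s = pd.Series(["0:01.500", "0:02.000", "1:00.250"])
t = pd.to_numeric(s, errors="coerce")   # clock-style strings all coerce to NaN
if t.isna().all():                      # so fall back to per-value parsing
    t = s.astype(str).apply(_coerce_time_value)
```

The all-NaN guard matters: a column that is mostly numeric with a few bad cells should keep the fast vectorized path and simply drop the bad rows, as `clean_sheet` does next.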
diff --git a/pyBer/gui_postprocessing.py b/pyBer/gui_postprocessing.py
index 15f78f5..1ca5039 100644
--- a/pyBer/gui_postprocessing.py
+++ b/pyBer/gui_postprocessing.py
@@ -4,6 +4,7 @@
import os
import re
import json
+import copy
import logging
from pathlib import Path
from dataclasses import dataclass
@@ -16,8 +17,9 @@
from pyqtgraph.dockarea import DockArea, Dock
import h5py
-from analysis_core import ProcessedTrial
+from analysis_core import ProcessedTrial, coerce_time_value
from ethovision_process_gui import clean_sheet
+from temporal_modeling import TemporalModelingWidget
_DOCK_STATE_VERSION = 3
_POST_DOCK_STATE_KEY = "post_main_dock_state_v4"
@@ -28,14 +30,15 @@
_PRE_DOCK_PREFIX = "pre."
_BEHAVIOR_PARSE_BINARY = "binary_columns"
_BEHAVIOR_PARSE_TIMESTAMPS = "timestamp_columns"
-_FIXED_POST_RIGHT_SECTIONS = frozenset({"setup", "spatial", "psth", "export"})
-_FIXED_POST_VISIBLE_SECTIONS = frozenset({"setup", "spatial", "psth", "export"})
-_FIXED_POST_RIGHT_TAB_ORDER = ("setup", "psth", "spatial", "export")
+_FIXED_POST_RIGHT_SECTIONS = frozenset({"setup", "spatial", "psth", "export", "temporal"})
+_FIXED_POST_VISIBLE_SECTIONS = frozenset({"setup", "spatial", "psth", "export", "temporal"})
+_FIXED_POST_RIGHT_TAB_ORDER = ("setup", "psth", "spatial", "temporal", "export")
_POST_RIGHT_PANEL_MIN_WIDTH = 420
_FIXED_POST_RIGHT_TAB_TITLES: Dict[str, str] = {
"setup": "Setup",
"psth": "PSTH",
"spatial": "Spatial",
+ "temporal": "Temporal",
"export": "Export",
}
_USE_PG_DOCKAREA_POST_LAYOUT = True
@@ -248,7 +251,6 @@ def _detect_time_column(df, fallback_to_first: bool = False) -> Optional[str]:
def _numeric_column_array(df, col_name: str) -> np.ndarray:
import pandas as pd
- from pyBer.analysis_core import coerce_time_value
col_key = None
for c in df.columns:
@@ -508,6 +510,7 @@ def __init__(self, parent=None) -> None:
self._event_labels: List[pg.TextItem] = []
self._event_regions: List[pg.LinearRegionItem] = []
self._signal_peak_lines: List[pg.InfiniteLine] = []
+ self._signal_noise_items: List[object] = []
self._pre_region: Optional[pg.LinearRegionItem] = None
self._post_region: Optional[pg.LinearRegionItem] = None
self._settings = QtCore.QSettings("FiberPhotometryApp", "DoricProcessor")
@@ -523,6 +526,7 @@ def __init__(self, parent=None) -> None:
"heatmap_cmap": "viridis",
"heatmap_min": None,
"heatmap_max": None,
+ "heatmap_levels_manual": False,
}
self._section_popups: Dict[str, QtWidgets.QDockWidget] = {}
self._section_scroll_hosts: Dict[str, QtWidgets.QScrollArea] = {}
@@ -557,6 +561,13 @@ def __init__(self, parent=None) -> None:
self._project_dirty: bool = False
self._autosave_restoring: bool = False
self._project_recovered_from_autosave: bool = False
+ self._suppress_heatmap_level_store: bool = False
+ self._history_undo: List[Dict[str, object]] = []
+ self._history_redo: List[Dict[str, object]] = []
+ self._history_current: Optional[Dict[str, object]] = None
+ self._history_key: str = ""
+ self._history_restoring: bool = False
+ self._history_limit: int = 60
try:
self._build_ui()
self._restore_settings()
@@ -570,6 +581,7 @@ def __init__(self, parent=None) -> None:
if app is not None:
app.aboutToQuit.connect(self._on_about_to_quit)
QtCore.QTimer.singleShot(0, self._restore_project_autosave_if_needed)
+ QtCore.QTimer.singleShot(0, self._reset_history_snapshot)
def _build_ui(self) -> None:
root = QtWidgets.QVBoxLayout(self)
@@ -624,6 +636,7 @@ def _build_ui(self) -> None:
vsrc.addWidget(self.btn_refresh_dio)
grp_align = QtWidgets.QGroupBox("Behavior / Events")
+ grp_align.setSizePolicy(QtWidgets.QSizePolicy.Policy.Preferred, QtWidgets.QSizePolicy.Policy.Expanding)
fal = QtWidgets.QFormLayout(grp_align)
fal.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows)
fal.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop)
@@ -685,12 +698,20 @@ def _build_ui(self) -> None:
# Preprocessed files list
self.list_preprocessed = FileDropList()
- self.list_preprocessed.setMaximumHeight(120)
+ self.list_preprocessed.setMinimumHeight(180)
+ self.list_preprocessed.setSizePolicy(
+ QtWidgets.QSizePolicy.Policy.Expanding,
+ QtWidgets.QSizePolicy.Policy.Expanding,
+ )
self.list_preprocessed.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.MultiSelection)
# Behaviors list
self.list_behaviors = FileDropList()
- self.list_behaviors.setMaximumHeight(120)
+ self.list_behaviors.setMinimumHeight(180)
+ self.list_behaviors.setSizePolicy(
+ QtWidgets.QSizePolicy.Policy.Expanding,
+ QtWidgets.QSizePolicy.Policy.Expanding,
+ )
self.list_behaviors.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.MultiSelection)
# Control buttons for ordering
@@ -724,6 +745,8 @@ def _build_ui(self) -> None:
beh_col = QtWidgets.QVBoxLayout()
beh_col.addWidget(self.list_behaviors)
beh_col.addWidget(self.btn_remove_beh)
+ pre_col.setStretch(0, 1)
+ beh_col.setStretch(0, 1)
lists_layout.addLayout(pre_col)
lists_layout.addLayout(beh_col)
@@ -764,16 +787,20 @@ def _build_ui(self) -> None:
fal.addRow(self.lbl_trans_gap, self.spin_transition_gap)
- grp_opt = QtWidgets.QGroupBox("PSTH Options")
- fopt = QtWidgets.QFormLayout(grp_opt)
- fopt.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows)
- fopt.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop)
+ # ── Shared QSS for PSTH subsection headers ──
+ _psth_section_qss = (
+ "QGroupBox { font-weight: 700; font-size: 11px; "
+ "border: 1px solid rgba(255,255,255,0.07); border-radius: 6px; "
+ "margin-top: 10px; padding: 14px 8px 8px 8px; }"
+ "QGroupBox::title { subcontrol-origin: margin; left: 10px; "
+ "padding: 0 6px; color: #8899b0; }"
+ )
+ # ── Widget creation (unchanged logic, reordered for sections) ──
self.spin_pre = QtWidgets.QDoubleSpinBox(); self.spin_pre.setRange(0.1, 60); self.spin_pre.setValue(2.0); self.spin_pre.setDecimals(2)
- self.spin_post= QtWidgets.QDoubleSpinBox(); self.spin_post.setRange(0.1, 120); self.spin_post.setValue(5.0); self.spin_post.setDecimals(2)
- self.spin_b0 = QtWidgets.QDoubleSpinBox(); self.spin_b0.setRange(-60, 0); self.spin_b0.setValue(-1.0); self.spin_b0.setDecimals(2)
- self.spin_b1 = QtWidgets.QDoubleSpinBox(); self.spin_b1.setRange(-60, 0); self.spin_b1.setValue(0.0); self.spin_b1.setDecimals(2)
-
+ self.spin_post = QtWidgets.QDoubleSpinBox(); self.spin_post.setRange(0.1, 120); self.spin_post.setValue(5.0); self.spin_post.setDecimals(2)
+ self.spin_b0 = QtWidgets.QDoubleSpinBox(); self.spin_b0.setRange(-60, 0); self.spin_b0.setValue(-1.0); self.spin_b0.setDecimals(2)
+ self.spin_b1 = QtWidgets.QDoubleSpinBox(); self.spin_b1.setRange(-60, 0); self.spin_b1.setValue(0.0); self.spin_b1.setDecimals(2)
self.spin_resample = QtWidgets.QDoubleSpinBox(); self.spin_resample.setRange(1, 1000); self.spin_resample.setValue(50); self.spin_resample.setDecimals(1)
self.spin_smooth = QtWidgets.QDoubleSpinBox(); self.spin_smooth.setRange(0, 5); self.spin_smooth.setValue(0.0); self.spin_smooth.setDecimals(2)
@@ -787,6 +814,7 @@ def _build_ui(self) -> None:
self.spin_group_window = QtWidgets.QDoubleSpinBox(); self.spin_group_window.setRange(0.0, 1e6); self.spin_group_window.setValue(0.0); self.spin_group_window.setDecimals(3)
self.spin_dur_min = QtWidgets.QDoubleSpinBox(); self.spin_dur_min.setRange(0, 1e6); self.spin_dur_min.setValue(0.0); self.spin_dur_min.setDecimals(2)
self.spin_dur_max = QtWidgets.QDoubleSpinBox(); self.spin_dur_max.setRange(0, 1e6); self.spin_dur_max.setValue(0.0); self.spin_dur_max.setDecimals(2)
+
self.cb_metrics = QtWidgets.QCheckBox("Enable PSTH metrics")
self.cb_metrics.setChecked(True)
self.btn_hide_metrics = QtWidgets.QToolButton()
@@ -800,163 +828,133 @@ def _build_ui(self) -> None:
self.spin_metric_post0 = QtWidgets.QDoubleSpinBox(); self.spin_metric_post0.setRange(0, 120); self.spin_metric_post0.setValue(0.0); self.spin_metric_post0.setDecimals(2)
self.spin_metric_post1 = QtWidgets.QDoubleSpinBox(); self.spin_metric_post1.setRange(0, 120); self.spin_metric_post1.setValue(1.0); self.spin_metric_post1.setDecimals(2)
+ self.cb_global_metrics = QtWidgets.QCheckBox("Enable global metrics")
+ self.cb_global_metrics.setChecked(True)
+ self.spin_global_start = QtWidgets.QDoubleSpinBox(); self.spin_global_start.setRange(-1e6, 1e6); self.spin_global_start.setValue(0.0); self.spin_global_start.setDecimals(2)
+ self.spin_global_end = QtWidgets.QDoubleSpinBox(); self.spin_global_end.setRange(-1e6, 1e6); self.spin_global_end.setValue(0.0); self.spin_global_end.setDecimals(2)
+ self.cb_global_amp = QtWidgets.QCheckBox("Peak amplitude")
+ self.cb_global_amp.setChecked(True)
+ self.cb_global_freq = QtWidgets.QCheckBox("Transient frequency")
+ self.cb_global_freq.setChecked(True)
+ self.lbl_global_metrics = QtWidgets.QLabel("Global metrics: -")
+ self.lbl_global_metrics.setProperty("class", "hint")
+
for w in (
self.spin_pre, self.spin_post, self.spin_b0, self.spin_b1,
self.spin_resample, self.spin_smooth,
- self.spin_event_start, self.spin_event_end, self.spin_group_window, self.spin_dur_min, self.spin_dur_max,
- self.spin_metric_pre0, self.spin_metric_pre1, self.spin_metric_post0, self.spin_metric_post1,
+ self.spin_event_start, self.spin_event_end, self.spin_group_window,
+ self.spin_dur_min, self.spin_dur_max,
+ self.spin_metric_pre0, self.spin_metric_pre1,
+ self.spin_metric_post0, self.spin_metric_post1,
+ self.spin_global_start, self.spin_global_end,
):
w.setMinimumWidth(60)
w.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
- win_row = QtWidgets.QGridLayout()
- win_row.setHorizontalSpacing(6)
- win_row.setContentsMargins(0, 0, 0, 0)
- win_pre = QtWidgets.QLabel("Pre:")
- win_post = QtWidgets.QLabel("Post:")
- win_pre.setMinimumWidth(35)
- win_post.setMinimumWidth(35)
- win_row.addWidget(win_pre, 0, 0)
- win_row.addWidget(self.spin_pre, 0, 1)
- win_row.addWidget(win_post, 0, 2)
- win_row.addWidget(self.spin_post, 0, 3)
- win_row.setColumnStretch(1, 1)
- win_row.setColumnStretch(3, 1)
- win_widget = QtWidgets.QWidget(); win_widget.setLayout(win_row)
-
- base_row = QtWidgets.QGridLayout()
- base_row.setHorizontalSpacing(6)
- base_row.setContentsMargins(0, 0, 0, 0)
- base_start = QtWidgets.QLabel("Start:")
- base_end = QtWidgets.QLabel("End:")
- base_start.setMinimumWidth(45)
- base_end.setMinimumWidth(35)
- base_row.addWidget(base_start, 0, 0)
- base_row.addWidget(self.spin_b0, 0, 1)
- base_row.addWidget(base_end, 0, 2)
- base_row.addWidget(self.spin_b1, 0, 3)
- base_row.setColumnStretch(1, 1)
- base_row.setColumnStretch(3, 1)
- base_widget = QtWidgets.QWidget(); base_widget.setLayout(base_row)
-
- metric_pre_row = QtWidgets.QGridLayout()
- metric_pre_row.setHorizontalSpacing(6)
- metric_pre_row.setContentsMargins(0, 0, 0, 0)
- metric_pre_start = QtWidgets.QLabel("Start:")
- metric_pre_end = QtWidgets.QLabel("End:")
- metric_pre_start.setMinimumWidth(45)
- metric_pre_end.setMinimumWidth(35)
- metric_pre_row.addWidget(metric_pre_start, 0, 0)
- metric_pre_row.addWidget(self.spin_metric_pre0, 0, 1)
- metric_pre_row.addWidget(metric_pre_end, 0, 2)
- metric_pre_row.addWidget(self.spin_metric_pre1, 0, 3)
- metric_pre_row.setColumnStretch(1, 1)
- metric_pre_row.setColumnStretch(3, 1)
- metric_pre_widget = QtWidgets.QWidget(); metric_pre_widget.setLayout(metric_pre_row)
-
- metric_post_row = QtWidgets.QGridLayout()
- metric_post_row.setHorizontalSpacing(6)
- metric_post_row.setContentsMargins(0, 0, 0, 0)
- metric_post_start = QtWidgets.QLabel("Start:")
- metric_post_end = QtWidgets.QLabel("End:")
- metric_post_start.setMinimumWidth(45)
- metric_post_end.setMinimumWidth(35)
- metric_post_row.addWidget(metric_post_start, 0, 0)
- metric_post_row.addWidget(self.spin_metric_post0, 0, 1)
- metric_post_row.addWidget(metric_post_end, 0, 2)
- metric_post_row.addWidget(self.spin_metric_post1, 0, 3)
- metric_post_row.setColumnStretch(1, 1)
- metric_post_row.setColumnStretch(3, 1)
- metric_post_widget = QtWidgets.QWidget(); metric_post_widget.setLayout(metric_post_row)
-
- fopt.addRow("Window (s)", win_widget)
- fopt.addRow("Baseline (s)", base_widget)
- fopt.addRow("Resample (Hz)", self.spin_resample)
+ # ── Helper: dual-spin row ──
+ def _dual_row(lbl_a: str, w_a, lbl_b: str, w_b):
+ g = QtWidgets.QGridLayout()
+ g.setHorizontalSpacing(6); g.setContentsMargins(0, 0, 0, 0)
+ la = QtWidgets.QLabel(lbl_a); la.setMinimumWidth(35)
+ lb = QtWidgets.QLabel(lbl_b); lb.setMinimumWidth(35)
+ g.addWidget(la, 0, 0); g.addWidget(w_a, 0, 1)
+ g.addWidget(lb, 0, 2); g.addWidget(w_b, 0, 3)
+ g.setColumnStretch(1, 1); g.setColumnStretch(3, 1)
+ w = QtWidgets.QWidget(); w.setLayout(g); return w
+
+ win_widget = _dual_row("Pre:", self.spin_pre, "Post:", self.spin_post)
+ base_widget = _dual_row("Start:", self.spin_b0, "End:", self.spin_b1)
+ metric_pre_widget = _dual_row("Start:", self.spin_metric_pre0, "End:", self.spin_metric_pre1)
+ metric_post_widget = _dual_row("Start:", self.spin_metric_post0, "End:", self.spin_metric_post1)
+ global_widget = _dual_row("Start:", self.spin_global_start, "End:", self.spin_global_end)
+
+ # ═══════════════════════════════════════════════════════
+ # Section 1 — Window & Baseline
+ # ═══════════════════════════════════════════════════════
+ grp_window = QtWidgets.QGroupBox("Window && baseline")
+ grp_window.setStyleSheet(_psth_section_qss)
+ fw = QtWidgets.QFormLayout(grp_window)
+ fw.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows)
+ fw.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop)
+ fw.addRow("Window (s)", win_widget)
+ fw.addRow("Baseline (s)", base_widget)
+ fw.addRow("Resample (Hz)", self.spin_resample)
+ fw.addRow("Smooth sigma (s)", self.spin_smooth)
+
+ # ═══════════════════════════════════════════════════════
+ # Section 2 — Event filters
+ # ═══════════════════════════════════════════════════════
+ grp_filt = QtWidgets.QGroupBox("Event filters")
+ grp_filt.setStyleSheet(_psth_section_qss)
+ ff = QtWidgets.QFormLayout(grp_filt)
+ ff.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows)
+ ff.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop)
filt_row = QtWidgets.QHBoxLayout()
- filt_row.setContentsMargins(0, 0, 0, 0)
- filt_row.setSpacing(6)
- filt_row.addWidget(self.cb_filter_events)
- filt_row.addStretch(1)
+ filt_row.setContentsMargins(0, 0, 0, 0); filt_row.setSpacing(6)
+ filt_row.addWidget(self.cb_filter_events); filt_row.addStretch(1)
filt_row.addWidget(self.btn_hide_filters)
filt_widget = QtWidgets.QWidget(); filt_widget.setLayout(filt_row)
- fopt.addRow(filt_widget)
- self.lbl_event_start = QtWidgets.QLabel("Event index start (1-based)")
- self.lbl_event_end = QtWidgets.QLabel("Event index end (0=all)")
- self.lbl_group_window = QtWidgets.QLabel("Group events within (s) (0=off)")
- self.lbl_dur_min = QtWidgets.QLabel("Event duration min (s)")
- self.lbl_dur_max = QtWidgets.QLabel("Event duration max (s)")
- fopt.addRow(self.lbl_event_start, self.spin_event_start)
- fopt.addRow(self.lbl_event_end, self.spin_event_end)
- fopt.addRow(self.lbl_group_window, self.spin_group_window)
- fopt.addRow(self.lbl_dur_min, self.spin_dur_min)
- fopt.addRow(self.lbl_dur_max, self.spin_dur_max)
- fopt.addRow("Gaussian smooth sigma (s)", self.spin_smooth)
+ ff.addRow(filt_widget)
+ self.lbl_event_start = QtWidgets.QLabel("Start index (1-based)")
+ self.lbl_event_end = QtWidgets.QLabel("End index (0 = all)")
+ self.lbl_group_window = QtWidgets.QLabel("Group within (s)")
+ self.lbl_dur_min = QtWidgets.QLabel("Duration min (s)")
+ self.lbl_dur_max = QtWidgets.QLabel("Duration max (s)")
+ ff.addRow(self.lbl_event_start, self.spin_event_start)
+ ff.addRow(self.lbl_event_end, self.spin_event_end)
+ ff.addRow(self.lbl_group_window, self.spin_group_window)
+ ff.addRow(self.lbl_dur_min, self.spin_dur_min)
+ ff.addRow(self.lbl_dur_max, self.spin_dur_max)
+
+ # ═══════════════════════════════════════════════════════
+ # Section 3 — PSTH metrics
+ # ═══════════════════════════════════════════════════════
+ grp_met = QtWidgets.QGroupBox("PSTH metrics")
+ grp_met.setStyleSheet(_psth_section_qss)
+ fm = QtWidgets.QFormLayout(grp_met)
+ fm.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows)
+ fm.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop)
met_row = QtWidgets.QHBoxLayout()
- met_row.setContentsMargins(0, 0, 0, 0)
- met_row.setSpacing(6)
- met_row.addWidget(self.cb_metrics)
- met_row.addStretch(1)
+ met_row.setContentsMargins(0, 0, 0, 0); met_row.setSpacing(6)
+ met_row.addWidget(self.cb_metrics); met_row.addStretch(1)
met_row.addWidget(self.btn_hide_metrics)
met_widget = QtWidgets.QWidget(); met_widget.setLayout(met_row)
- fopt.addRow(met_widget)
+ fm.addRow(met_widget)
self.lbl_metric = QtWidgets.QLabel("Metric")
- self.lbl_metric_pre = QtWidgets.QLabel("Metric pre (s)")
- self.lbl_metric_post = QtWidgets.QLabel("Metric post (s)")
- fopt.addRow(self.lbl_metric, self.combo_metric)
- fopt.addRow(self.lbl_metric_pre, metric_pre_widget)
- fopt.addRow(self.lbl_metric_post, metric_post_widget)
-
- self.cb_global_metrics = QtWidgets.QCheckBox("Enable global metrics")
- self.cb_global_metrics.setChecked(True)
- fopt.addRow(self.cb_global_metrics)
-
- self.spin_global_start = QtWidgets.QDoubleSpinBox()
- self.spin_global_start.setRange(-1e6, 1e6)
- self.spin_global_start.setValue(0.0)
- self.spin_global_start.setDecimals(2)
- self.spin_global_end = QtWidgets.QDoubleSpinBox()
- self.spin_global_end.setRange(-1e6, 1e6)
- self.spin_global_end.setValue(0.0)
- self.spin_global_end.setDecimals(2)
- self.spin_global_start.setMinimumWidth(60)
- self.spin_global_end.setMinimumWidth(60)
- self.spin_global_start.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
- self.spin_global_end.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
-
- global_row = QtWidgets.QGridLayout()
- global_row.setHorizontalSpacing(6)
- global_row.setContentsMargins(0, 0, 0, 0)
- global_row.addWidget(QtWidgets.QLabel("Start:"), 0, 0)
- global_row.addWidget(self.spin_global_start, 0, 1)
- global_row.addWidget(QtWidgets.QLabel("End:"), 0, 2)
- global_row.addWidget(self.spin_global_end, 0, 3)
- global_row.setColumnStretch(1, 1)
- global_row.setColumnStretch(3, 1)
- global_widget = QtWidgets.QWidget()
- global_widget.setLayout(global_row)
- fopt.addRow("Global range (s)", global_widget)
-
- self.cb_global_amp = QtWidgets.QCheckBox("Peak amplitude")
- self.cb_global_amp.setChecked(True)
- self.cb_global_freq = QtWidgets.QCheckBox("Transient frequency")
- self.cb_global_freq.setChecked(True)
+ self.lbl_metric_pre = QtWidgets.QLabel("Pre window (s)")
+ self.lbl_metric_post = QtWidgets.QLabel("Post window (s)")
+ fm.addRow(self.lbl_metric, self.combo_metric)
+ fm.addRow(self.lbl_metric_pre, metric_pre_widget)
+ fm.addRow(self.lbl_metric_post, metric_post_widget)
+
+ # ═══════════════════════════════════════════════════════
+ # Section 4 — Global metrics
+ # ═══════════════════════════════════════════════════════
+ grp_global = QtWidgets.QGroupBox("Global metrics")
+ grp_global.setStyleSheet(_psth_section_qss)
+ fg = QtWidgets.QFormLayout(grp_global)
+ fg.setRowWrapPolicy(QtWidgets.QFormLayout.RowWrapPolicy.WrapLongRows)
+ fg.setLabelAlignment(QtCore.Qt.AlignmentFlag.AlignLeft | QtCore.Qt.AlignmentFlag.AlignTop)
+ fg.addRow(self.cb_global_metrics)
+ fg.addRow("Range (s)", global_widget)
global_opts = QtWidgets.QHBoxLayout()
- global_opts.setContentsMargins(0, 0, 0, 0)
- global_opts.setSpacing(6)
+ global_opts.setContentsMargins(0, 0, 0, 0); global_opts.setSpacing(6)
global_opts.addWidget(self.cb_global_amp)
global_opts.addWidget(self.cb_global_freq)
global_opts.addStretch(1)
- global_opts_widget = QtWidgets.QWidget()
- global_opts_widget.setLayout(global_opts)
- fopt.addRow("Global metrics", global_opts_widget)
-
- self.lbl_global_metrics = QtWidgets.QLabel("Global metrics: -")
- self.lbl_global_metrics.setProperty("class", "hint")
- fopt.addRow("", self.lbl_global_metrics)
-
- for w in (self.spin_global_start, self.spin_global_end):
- w.setMinimumWidth(60)
- w.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
+ global_opts_widget = QtWidgets.QWidget(); global_opts_widget.setLayout(global_opts)
+ fg.addRow("Compute", global_opts_widget)
+ fg.addRow("", self.lbl_global_metrics)
+
+ # Container for all subsections (replaces old grp_opt)
+ grp_opt = QtWidgets.QWidget()
+ _psth_vbox = QtWidgets.QVBoxLayout(grp_opt)
+ _psth_vbox.setContentsMargins(0, 0, 0, 0)
+ _psth_vbox.setSpacing(4)
+ _psth_vbox.addWidget(grp_window)
+ _psth_vbox.addWidget(grp_filt)
+ _psth_vbox.addWidget(grp_met)
+ _psth_vbox.addWidget(grp_global)
self.btn_compute = QtWidgets.QPushButton("Postprocessing (compute PSTH)")
self.btn_compute.setProperty("class", "compactPrimarySmall")
@@ -979,6 +977,9 @@ def _build_ui(self) -> None:
self.btn_load_cfg = QtWidgets.QPushButton("Load config")
self.btn_load_cfg.setProperty("class", "compactSmall")
self.btn_load_cfg.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
+ self.btn_reset_cfg = QtWidgets.QPushButton("Reset defaults")
+ self.btn_reset_cfg.setProperty("class", "compactSmall")
+ self.btn_reset_cfg.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
self.btn_new_project = QtWidgets.QPushButton("New project")
self.btn_new_project.setProperty("class", "compactSmall")
self.btn_new_project.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
@@ -1007,6 +1008,16 @@ def _build_ui(self) -> None:
self.spin_peak_prominence.setRange(0.0, 1e6)
self.spin_peak_prominence.setValue(0.5)
self.spin_peak_prominence.setDecimals(4)
+ self.cb_peak_auto_mad = QtWidgets.QCheckBox("Auto transient threshold (MAD noise)")
+ self.cb_peak_auto_mad.setChecked(False)
+ self.cb_peak_auto_mad.setToolTip(
+ "Estimate trace noise as 1.4826 x MAD after baseline/smoothing and use "
+ "the multiplier below as the minimum peak prominence."
+ )
+ self.spin_peak_mad_multiplier = QtWidgets.QDoubleSpinBox()
+ self.spin_peak_mad_multiplier.setRange(0.5, 50.0)
+ self.spin_peak_mad_multiplier.setValue(5.0)
+ self.spin_peak_mad_multiplier.setDecimals(2)
self.spin_peak_height = QtWidgets.QDoubleSpinBox()
self.spin_peak_height.setRange(0.0, 1e6)
self.spin_peak_height.setValue(0.0)
@@ -1033,8 +1044,19 @@ def _build_ui(self) -> None:
self.spin_peak_auc_window.setRange(0.0, 30.0)
self.spin_peak_auc_window.setValue(0.5)
self.spin_peak_auc_window.setDecimals(3)
+ self.cb_peak_norm_prominence = QtWidgets.QCheckBox("Baseline-prominence normalized amplitude")
+ self.cb_peak_norm_prominence.setChecked(False)
+ self.cb_peak_norm_prominence.setToolTip(
+ "Report peak amplitude after scaling by the top baseline peak prominences."
+ )
self.cb_peak_overlay = QtWidgets.QCheckBox("Show detected peaks on trace")
self.cb_peak_overlay.setChecked(True)
+ self.cb_peak_noise_overlay = QtWidgets.QCheckBox("Show noise trace / MAD threshold overlay")
+ self.cb_peak_noise_overlay.setChecked(False)
+ self.cb_peak_noise_overlay.setToolTip(
+ "After peak detection, overlay the preprocessed detection trace, robust noise band, "
+ "and effective prominence threshold used for the visible file."
+ )
self.btn_detect_peaks = QtWidgets.QPushButton("Detect peaks")
self.btn_detect_peaks.setProperty("class", "compactPrimarySmall")
self.btn_export_peaks = QtWidgets.QPushButton("Export peaks CSV")
@@ -1046,14 +1068,18 @@ def _build_ui(self) -> None:
f_signal.addRow("File", self.combo_signal_file)
f_signal.addRow("Method", self.combo_signal_method)
f_signal.addRow("Min prominence", self.spin_peak_prominence)
+ f_signal.addRow(self.cb_peak_auto_mad)
+ f_signal.addRow("MAD multiplier", self.spin_peak_mad_multiplier)
f_signal.addRow("Min height (0=off)", self.spin_peak_height)
f_signal.addRow("Min distance (s)", self.spin_peak_distance)
f_signal.addRow("Smooth sigma (s)", self.spin_peak_smooth)
f_signal.addRow("Baseline handling", self.combo_peak_baseline)
f_signal.addRow("Baseline window (s)", self.spin_peak_baseline_window)
+ f_signal.addRow(self.cb_peak_norm_prominence)
f_signal.addRow("Rate bin (s)", self.spin_peak_rate_bin)
f_signal.addRow("AUC window (+/- s)", self.spin_peak_auc_window)
f_signal.addRow(self.cb_peak_overlay)
+ f_signal.addRow(self.cb_peak_noise_overlay)
signal_btn_row = QtWidgets.QHBoxLayout()
signal_btn_row.addWidget(self.btn_detect_peaks)
signal_btn_row.addWidget(self.btn_export_peaks)
@@ -1258,7 +1284,7 @@ def _build_ui(self) -> None:
setup_layout.setContentsMargins(6, 6, 6, 6)
setup_layout.setSpacing(8)
setup_layout.addWidget(grp_src)
- setup_layout.addWidget(grp_align)
+ setup_layout.addWidget(grp_align, stretch=1)
setup_btn_row = QtWidgets.QHBoxLayout()
self.btn_setup_load = QtWidgets.QPushButton("Load")
self.btn_setup_load.setProperty("class", "compactPrimarySmall")
@@ -1303,6 +1329,11 @@ def _build_ui(self) -> None:
behavior_layout.addWidget(self.lbl_behavior_summary)
behavior_layout.addStretch(1)
+ self.section_temporal = TemporalModelingWidget()
+ self.section_temporal.statusMessage.connect(
+ lambda msg, ms: self.statusUpdate.emit(msg, ms)
+ )
+
self.section_spatial = QtWidgets.QWidget()
spatial_layout = QtWidgets.QVBoxLayout(self.section_spatial)
spatial_layout.setContentsMargins(6, 6, 6, 6)
@@ -1318,6 +1349,7 @@ def _build_ui(self) -> None:
export_layout.addWidget(self.btn_export_img)
export_layout.addWidget(self.btn_save_cfg)
export_layout.addWidget(self.btn_load_cfg)
+ export_layout.addWidget(self.btn_reset_cfg)
export_layout.addWidget(self.btn_new_project)
export_layout.addWidget(self.btn_save_project)
export_layout.addWidget(self.btn_load_project)
@@ -1353,6 +1385,12 @@ def _build_ui(self) -> None:
self.btn_action_compute.setProperty("class", "compactPrimarySmall")
self.btn_action_export = QtWidgets.QPushButton("Export")
self.btn_action_export.setProperty("class", "compactPrimarySmall")
+ self.btn_action_undo = QtWidgets.QPushButton("Undo")
+ self.btn_action_undo.setProperty("class", "compactSmall")
+ self.btn_action_undo.setToolTip("Undo last postprocessing setting or view action")
+ self.btn_action_redo = QtWidgets.QPushButton("Redo")
+ self.btn_action_redo.setProperty("class", "compactSmall")
+ self.btn_action_redo.setToolTip("Redo last undone postprocessing setting or view action")
self.btn_action_hide = QtWidgets.QPushButton("Hide Panels")
self.btn_action_hide.setProperty("class", "compactSmall")
@@ -1362,6 +1400,7 @@ def _build_ui(self) -> None:
self.btn_panel_export = QtWidgets.QPushButton("Export")
self.btn_panel_signal = QtWidgets.QPushButton("Signal")
self.btn_panel_behavior = QtWidgets.QPushButton("Behavior")
+ self.btn_panel_temporal = QtWidgets.QPushButton("Temporal")
self._section_buttons = {
"setup": self.btn_panel_setup,
"psth": self.btn_panel_psth,
@@ -1369,6 +1408,7 @@ def _build_ui(self) -> None:
"export": self.btn_panel_export,
"signal": self.btn_panel_signal,
"behavior": self.btn_panel_behavior,
+ "temporal": self.btn_panel_temporal,
}
for b in self._section_buttons.values():
b.setCheckable(True)
@@ -1378,7 +1418,7 @@ def _build_ui(self) -> None:
# and put workflow actions in a thin transport bar above the plots.
from styles import (
_make_icon, _paint_sliders, _paint_chart, _paint_grid,
- _paint_export, _paint_pulse, _paint_paw,
+ _paint_export, _paint_pulse, _paint_paw, _paint_temporal,
)
_post_rail_meta = {
"setup": ("Setup", _paint_sliders),
@@ -1387,6 +1427,7 @@ def _build_ui(self) -> None:
"export": ("Export panel", _paint_export),
"signal": ("Signal events", _paint_pulse),
"behavior": ("Behavior", _paint_paw),
+ "temporal": ("Temporal modeling (GLM / FLMM)", _paint_temporal),
}
for key, btn in self._section_buttons.items():
tip, painter = _post_rail_meta[key]
@@ -1404,7 +1445,7 @@ def _build_ui(self) -> None:
rail_layout = QtWidgets.QVBoxLayout(self._post_side_rail)
rail_layout.setContentsMargins(8, 10, 8, 10)
rail_layout.setSpacing(6)
- for key in ("setup", "psth", "spatial", "signal", "behavior", "export"):
+ for key in ("setup", "psth", "spatial", "temporal", "signal", "behavior", "export"):
rail_layout.addWidget(self._section_buttons[key], 0,
QtCore.Qt.AlignmentFlag.AlignHCenter)
rail_layout.addStretch(1)
@@ -1420,8 +1461,10 @@ def _build_ui(self) -> None:
self.btn_action_export.setText("Run Export")
self.btn_action_hide.setText("Hide drawer")
for b in (self.btn_action_load, self.btn_action_compute,
- self.btn_action_export, self.btn_style):
+ self.btn_action_export, self.btn_action_undo,
+ self.btn_action_redo, self.btn_style):
tb_layout.addWidget(b)
+ self._update_history_buttons()
tb_layout.addStretch(1)
tb_layout.addWidget(self.btn_action_hide)
# action_row is no longer used; left in scope so any later reference
@@ -1520,6 +1563,7 @@ def _build_ui(self) -> None:
self.curve_trace = self.plot_trace.plot(pen=pg.mkPen(self._style["trace"], width=1.1))
self.curve_behavior = self.plot_trace.plot(pen=pg.mkPen(self._style["behavior"], width=1.0))
+ self.curve_behavior.setVisible(False)
self.curve_peak_markers = self.plot_trace.plot(
pen=None,
symbol="o",
@@ -1638,6 +1682,10 @@ def _build_ui(self) -> None:
)
self.plot_metrics.addItem(self.metrics_err_pre)
self.plot_metrics.addItem(self.metrics_err_post)
+ self.metrics_p_text = pg.TextItem("", color=(230, 236, 246), anchor=(0.5, 1.0))
+ self.metrics_p_text.setZValue(20)
+ self.metrics_p_text.setVisible(False)
+ self.plot_metrics.addItem(self.metrics_p_text)
self.plot_metrics.setXRange(-0.5, 1.5, padding=0)
self.plot_metrics.getAxis("bottom").setTicks([[(0, "pre"), (1, "post")]])
@@ -1844,6 +1892,8 @@ def _build_ui(self) -> None:
self.act_open_plot_style.triggered.connect(self._open_style_dialog)
self.btn_action_compute.clicked.connect(self._compute_psth)
self.btn_action_export.clicked.connect(self._export_results)
+ self.btn_action_undo.clicked.connect(self._undo_post_action)
+ self.btn_action_redo.clicked.connect(self._redo_post_action)
self.btn_action_hide.clicked.connect(self._hide_all_section_popups)
for key, btn in self._section_buttons.items():
btn.toggled.connect(lambda checked, section_key=key: self._toggle_section_popup(section_key, checked))
@@ -1875,6 +1925,7 @@ def _build_ui(self) -> None:
self.btn_style.clicked.connect(self._open_style_dialog)
self.btn_save_cfg.clicked.connect(self._save_config_file)
self.btn_load_cfg.clicked.connect(self._load_config_file)
+ self.btn_reset_cfg.clicked.connect(self._reset_config_defaults)
self.btn_new_project.clicked.connect(self._new_project)
self.btn_save_project.clicked.connect(self._save_project_file)
self.btn_load_project.clicked.connect(self._load_project_file)
@@ -1889,9 +1940,15 @@ def _build_ui(self) -> None:
self.cb_peak_overlay.toggled.connect(self._refresh_signal_overlay)
self.combo_signal_source.currentIndexChanged.connect(self._refresh_signal_file_combo)
self.combo_signal_scope.currentIndexChanged.connect(self._refresh_signal_file_combo)
+ self.combo_signal_file.currentIndexChanged.connect(self._on_signal_file_changed)
+ self.cb_peak_auto_mad.toggled.connect(self._update_peak_auto_mad_enabled)
+ self.cb_peak_noise_overlay.toggled.connect(self._refresh_signal_overlay)
+ self.cb_peak_norm_prominence.toggled.connect(lambda _checked=False: self._save_settings())
self.tab_sources.currentChanged.connect(self._refresh_signal_file_combo)
self.tab_visual_mode.currentChanged.connect(self._on_visual_mode_changed)
+ self.tab_visual_mode.currentChanged.connect(self._queue_settings_save)
self.combo_individual_file.currentIndexChanged.connect(self._on_individual_file_changed)
+ self.combo_individual_file.currentIndexChanged.connect(self._queue_settings_save)
self.combo_align.currentIndexChanged.connect(self._update_align_ui)
self.combo_behavior_file_type.currentIndexChanged.connect(self._update_align_ui)
@@ -2025,6 +2082,7 @@ def _section_widget_map(self) -> Dict[str, Tuple[str, QtWidgets.QWidget]]:
"setup": ("Setup", self.section_setup),
"psth": ("PSTH", self.section_psth),
"spatial": ("Spatial", self.section_spatial),
+ "temporal": ("Temporal Modeling", self.section_temporal),
"export": ("Export", self.section_export),
"signal": ("Signal Event Analyzer", self.section_signal),
"behavior": ("Behavior Analysis", self.section_behavior),
@@ -2376,6 +2434,7 @@ def _setup_section_popups(self) -> None:
"signal": ("Signal Event Analyzer", self.section_signal),
"behavior": ("Behavior Analysis", self.section_behavior),
"spatial": ("Spatial", self.section_spatial),
+ "temporal": ("Temporal Modeling", self.section_temporal),
"export": ("Export", self.section_export),
}
for key, (title, widget) in section_map.items():
@@ -2393,6 +2452,7 @@ def _setup_section_popups(self) -> None:
dock = QtWidgets.QDockWidget(title, host)
dock.setObjectName(f"post.{key}.dock")
+ dock.setWindowModality(QtCore.Qt.WindowModality.NonModal)
dock.setAllowedAreas(QtCore.Qt.DockWidgetArea.AllDockWidgetAreas)
dock.setMinimumWidth(320)
dock.setFeatures(
@@ -2590,7 +2650,8 @@ def _toggle_section_popup(self, key: str, checked: bool) -> None:
self._section_popup_initialized.add(key)
dock.show()
dock.raise_()
- dock.activateWindow()
+ if key != "temporal":
+ dock.activateWindow()
self._last_opened_section = key
else:
dock.hide()
@@ -3224,6 +3285,29 @@ def _load_behavior_paths(self, paths: List[str], replace: bool) -> None:
self._project_dirty = True
self._update_data_availability()
self._update_status_strip()
+ self._sync_temporal_modeling_context()
+
+ def _sync_temporal_modeling_context(self) -> None:
+ if not hasattr(self, "section_temporal"):
+ return
+ try:
+ self.section_temporal.set_data(
+ processed_trials=self._processed,
+ psth_mat=self._last_mat,
+ psth_tvec=self._last_tvec,
+ event_times=self._last_events,
+ file_ids=self._all_file_ids,
+ per_file_mats=self._per_file_mats,
+ behavior_sources=self._behavior_sources,
+ event_rows=self._last_event_rows,
+ group_mat=self._group_mat,
+ group_tvec=self._group_tvec,
+ group_labels=self._group_labels,
+ visual_mode=int(self.tab_visual_mode.currentIndex()) if hasattr(self, "tab_visual_mode") else 0,
+ group_mode=bool(self.tab_sources.currentIndex() == 1) if hasattr(self, "tab_sources") else False,
+ )
+ except Exception:
+ _LOG.debug("Could not sync temporal modeling context", exc_info=True)
def _load_processed_paths(self, paths: List[str], replace: bool) -> None:
loaded: List[ProcessedTrial] = []
@@ -3415,6 +3499,50 @@ def _refresh_signal_file_combo(self) -> None:
has_multi = self.combo_signal_file.count() > 1
self.combo_signal_scope.setEnabled(has_multi)
self.combo_signal_file.setEnabled(self.combo_signal_scope.currentText() == "Per file")
+ self._refresh_signal_overlay()
+
+ def _on_signal_file_changed(self, _index: int = 0) -> None:
+ if not hasattr(self, "combo_signal_file"):
+ return
+ if self.combo_signal_source.currentText().startswith("Use PSTH input trace"):
+ self._refresh_signal_overlay()
+ return
+ if self.combo_signal_scope.currentText() != "Per file":
+ self._refresh_signal_overlay()
+ return
+
+ file_id = self.combo_signal_file.currentText().strip()
+ if not file_id:
+ self._refresh_signal_overlay()
+ return
+ try:
+ idx = self.combo_individual_file.findText(file_id)
+ if idx >= 0 and self.combo_individual_file.currentIndex() != idx:
+ self.combo_individual_file.setCurrentIndex(idx)
+ else:
+ self._update_trace_preview()
+ except Exception:
+ self._update_trace_preview()
+ self._refresh_signal_overlay()
+
+ def _current_signal_overlay_file_id(self) -> str:
+ if self.combo_signal_source.currentText().startswith("Use PSTH input trace"):
+ return "psth_trace"
+ if self.combo_signal_scope.currentText() == "Per file":
+ file_id = self.combo_signal_file.currentText().strip()
+ if file_id:
+ return file_id
+ try:
+ if self.tab_visual_mode.currentIndex() == 0:
+ file_id = self.combo_individual_file.currentText().strip()
+ if file_id:
+ return file_id
+ except Exception:
+ pass
+ if self._processed:
+ proc = self._processed[0]
+ return os.path.splitext(os.path.basename(proc.path))[0] if proc.path else "import"
+ return ""
def _refresh_individual_file_combo(self) -> None:
if not hasattr(self, "combo_individual_file"):
@@ -3437,10 +3565,44 @@ def _on_visual_mode_changed(self, index: int) -> None:
is_individual = index == 0
self.combo_individual_file.setVisible(is_individual)
self._rerender_visual_from_cache()
+ # Mirror to the Temporal Modeling scope: Individual -> Active file, Group -> Per-file batch.
+ try:
+ section = getattr(self, "section_temporal", None)
+ if section is not None and hasattr(section, "combo_fit_scope"):
+ target = "active" if is_individual else "batch"
+ idx = section.combo_fit_scope.findData(target)
+ if idx >= 0 and section.combo_fit_scope.currentIndex() != idx:
+ section.combo_fit_scope.blockSignals(True)
+ section.combo_fit_scope.setCurrentIndex(idx)
+ section.combo_fit_scope.blockSignals(False)
+ section._fit_mode = target
+ section._update_fit_state_label()
+ except Exception:
+ pass
def _on_individual_file_changed(self, _index: int = 0) -> None:
if self.tab_visual_mode.currentIndex() == 0:
self._rerender_visual_from_cache()
+ # Push the new active file into the Temporal Modeling section so its
+ # scope strip stays in sync with the global Individual file picker.
+ try:
+ file_id = self.combo_individual_file.currentText().strip()
+ section = getattr(self, "section_temporal", None)
+ if section is None or not file_id:
+ return
+ combo = getattr(section, "combo_active_file", None)
+ if combo is not None:
+ idx = combo.findData(file_id)
+ if idx < 0:
+ idx = combo.findText(file_id)
+ if idx >= 0 and combo.currentIndex() != idx:
+ combo.blockSignals(True)
+ combo.setCurrentIndex(idx)
+ combo.blockSignals(False)
+ section._active_file_id = file_id
+ section._on_active_file_changed()
+ except Exception:
+ pass
def _rerender_visual_from_cache(self) -> None:
visual_mode = self.tab_visual_mode.currentIndex()
@@ -3471,6 +3633,7 @@ def _rerender_visual_from_cache(self) -> None:
self.plot_avg.setTitle("Average across trials +/- SEM")
# Always refresh the trace preview to match the selected file
self._update_trace_preview()
+ self._sync_temporal_modeling_context()
def _update_data_availability(self) -> None:
has_processed = bool(self._processed)
@@ -3490,17 +3653,21 @@ def _update_data_availability(self) -> None:
self.combo_signal_scope,
self.combo_signal_file,
self.combo_signal_method,
- self.spin_peak_prominence,
+ self.cb_peak_auto_mad,
+ self.spin_peak_mad_multiplier,
self.spin_peak_height,
self.spin_peak_distance,
self.spin_peak_smooth,
self.combo_peak_baseline,
self.spin_peak_baseline_window,
+ self.cb_peak_norm_prominence,
self.spin_peak_rate_bin,
self.spin_peak_auc_window,
self.cb_peak_overlay,
+ self.cb_peak_noise_overlay,
):
w.setEnabled(has_processed)
+ self._update_peak_auto_mad_enabled(queue=False)
for w in (
self.btn_compute_behavior,
self.btn_export_behavior_metrics,
@@ -3600,9 +3767,120 @@ def _update_global_metrics_enabled(self) -> None:
self._apply_view_layout()
self._queue_settings_save()
+ def _update_peak_auto_mad_enabled(self, _checked: object = None, *, queue: bool = True) -> None:
+ has_processed = bool(self._processed)
+ auto_mad = bool(self.cb_peak_auto_mad.isChecked())
+ self.spin_peak_prominence.setEnabled(has_processed and not auto_mad)
+ self.spin_peak_mad_multiplier.setEnabled(has_processed and auto_mad)
+ if queue:
+ self._queue_settings_save()
+
+ def _history_snapshot(self) -> Dict[str, object]:
+ state = self._collect_settings()
+ try:
+ state["visual_mode"] = int(self.tab_visual_mode.currentIndex())
+ except Exception:
+ state["visual_mode"] = 0
+ try:
+ state["individual_file"] = self.combo_individual_file.currentText().strip()
+ except Exception:
+ state["individual_file"] = ""
+ return copy.deepcopy(state)
+
+ def _history_state_key(self, state: Dict[str, object]) -> str:
+ try:
+ return json.dumps(state, sort_keys=True, separators=(",", ":"), default=str)
+ except Exception:
+ return repr(state)
+
+ def _update_history_buttons(self) -> None:
+ for button, enabled in (
+ (getattr(self, "btn_action_undo", None), bool(self._history_undo)),
+ (getattr(self, "btn_action_redo", None), bool(self._history_redo)),
+ ):
+ if button is not None:
+ button.setEnabled(enabled)
+
+ def _reset_history_snapshot(self) -> None:
+ if not hasattr(self, "combo_align"):
+ return
+ self._history_undo.clear()
+ self._history_redo.clear()
+ self._history_current = self._history_snapshot()
+ self._history_key = self._history_state_key(self._history_current)
+ self._update_history_buttons()
+
+ def _record_history_change(self) -> None:
+ if self._is_restoring_settings or self._history_restoring:
+ return
+ state = self._history_snapshot()
+ key = self._history_state_key(state)
+ if self._history_current is None:
+ self._history_current = copy.deepcopy(state)
+ self._history_key = key
+ self._update_history_buttons()
+ return
+ if key == self._history_key:
+ return
+ self._history_undo.append(copy.deepcopy(self._history_current))
+ if len(self._history_undo) > self._history_limit:
+ self._history_undo = self._history_undo[-self._history_limit:]
+ self._history_redo.clear()
+ self._history_current = copy.deepcopy(state)
+ self._history_key = key
+ self._update_history_buttons()
+
+ def _restore_history_state(self, state: Dict[str, object]) -> None:
+ was_restoring = self._is_restoring_settings
+ self._history_restoring = True
+ self._is_restoring_settings = True
+ try:
+ self._apply_settings(copy.deepcopy(state))
+ if "visual_mode" in state:
+ idx = int(state.get("visual_mode") or 0)
+ if 0 <= idx < self.tab_visual_mode.count():
+ self.tab_visual_mode.setCurrentIndex(idx)
+ if "individual_file" in state:
+ file_id = str(state.get("individual_file") or "").strip()
+ if file_id:
+ combo_idx = self.combo_individual_file.findText(file_id)
+ if combo_idx >= 0:
+ self.combo_individual_file.setCurrentIndex(combo_idx)
+ self._rerender_visual_from_cache()
+ self._update_trace_preview()
+ self._save_settings()
+ finally:
+ self._is_restoring_settings = was_restoring
+ self._history_restoring = False
+
+ def _undo_post_action(self) -> None:
+ if not self._history_undo:
+ return
+ current = self._history_snapshot()
+ previous = self._history_undo.pop()
+ self._history_redo.append(copy.deepcopy(current))
+ self._restore_history_state(previous)
+ self._history_current = copy.deepcopy(previous)
+ self._history_key = self._history_state_key(previous)
+ self._update_history_buttons()
+ self.statusUpdate.emit("Undid postprocessing action.", 2500)
+
+ def _redo_post_action(self) -> None:
+ if not self._history_redo:
+ return
+ current = self._history_snapshot()
+ next_state = self._history_redo.pop()
+ self._history_undo.append(copy.deepcopy(current))
+ self._restore_history_state(next_state)
+ self._history_current = copy.deepcopy(next_state)
+ self._history_key = self._history_state_key(next_state)
+ self._update_history_buttons()
+ self.statusUpdate.emit("Redid postprocessing action.", 2500)
+
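The undo/redo machinery above can be sketched GUI-free: each snapshot is reduced to a canonical JSON key, and a new history entry is pushed only when that key changes, so repeated autosave signals carrying identical state never grow the stack. `SettingsHistory` and `state_key` are illustrative names for this sketch, not part of the application.

```python
import copy
import json
from typing import Dict, List, Optional


def state_key(state: Dict[str, object]) -> str:
    # Canonical, order-independent serialization; default=str keeps
    # non-JSON values (e.g. numpy scalars) from raising.
    return json.dumps(state, sort_keys=True, separators=(",", ":"), default=str)


class SettingsHistory:
    """Bounded undo/redo stack deduplicated by a canonical state digest."""

    def __init__(self, limit: int = 50) -> None:
        self.limit = limit
        self.undo_stack: List[Dict[str, object]] = []
        self.redo_stack: List[Dict[str, object]] = []
        self.current: Optional[Dict[str, object]] = None
        self.key: Optional[str] = None

    def record(self, state: Dict[str, object]) -> None:
        key = state_key(state)
        if self.current is None:
            self.current, self.key = copy.deepcopy(state), key
            return
        if key == self.key:
            return  # no-op change: ignore
        self.undo_stack.append(copy.deepcopy(self.current))
        del self.undo_stack[:-self.limit]  # enforce the history limit
        self.redo_stack.clear()  # a new edit invalidates the redo branch
        self.current, self.key = copy.deepcopy(state), key

    def undo(self) -> Optional[Dict[str, object]]:
        if not self.undo_stack:
            return None
        self.redo_stack.append(copy.deepcopy(self.current))
        self.current = self.undo_stack.pop()
        self.key = state_key(self.current)
        return self.current

    def redo(self) -> Optional[Dict[str, object]]:
        if not self.redo_stack:
            return None
        self.undo_stack.append(copy.deepcopy(self.current))
        self.current = self.redo_stack.pop()
        self.key = state_key(self.current)
        return self.current
```

Toggling the same checkbox back and forth produces an identical key at the end, so only genuine state changes consume history slots.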
def _queue_settings_save(self, *_args: object) -> None:
if self._is_restoring_settings:
return
+ self._record_history_change()
if not self._autosave_restoring:
self._project_dirty = True
timer = getattr(self, "_settings_save_timer", None)
@@ -3656,6 +3934,7 @@ def _wire_settings_autosave(self) -> None:
self.spin_global_start,
self.spin_global_end,
self.spin_peak_prominence,
+ self.spin_peak_mad_multiplier,
self.spin_peak_height,
self.spin_peak_distance,
self.spin_peak_smooth,
@@ -3679,7 +3958,10 @@ def _wire_settings_autosave(self) -> None:
self.cb_global_metrics,
self.cb_global_amp,
self.cb_global_freq,
+ self.cb_peak_auto_mad,
+ self.cb_peak_norm_prominence,
self.cb_peak_overlay,
+ self.cb_peak_noise_overlay,
self.cb_behavior_aligned,
self.cb_spatial_clip,
self.cb_spatial_time_filter,
@@ -3832,7 +4114,6 @@ def _find_col(names: List[str]) -> Optional[int]:
iso_vals = []
dio_vals = []
- from pyBer.analysis_core import coerce_time_value
for r in data_rows:
if time_idx is None or output_idx is None:
continue
@@ -3970,6 +4251,7 @@ def _refresh_behavior_list(self) -> None:
self._refresh_spatial_columns()
self._compute_spatial_heatmap()
self._update_data_availability()
+ self._sync_temporal_modeling_context()
return
behavior_names: set[str] = set()
for info in self._behavior_sources.values():
@@ -3992,6 +4274,7 @@ def _refresh_behavior_list(self) -> None:
self._compute_spatial_heatmap()
self._update_data_availability()
self._update_status_strip()
+ self._sync_temporal_modeling_context()
def _guess_spatial_column(self, columns: List[str], axis: str) -> Optional[str]:
if not columns:
@@ -5112,55 +5395,10 @@ def _update_trace_preview(self) -> None:
self._refresh_signal_overlay()
def _update_behavior_overlay(self, proc: ProcessedTrial) -> None:
- if not proc or proc.time is None:
- self.curve_behavior.setData([], [])
- return
- if not self.combo_align.currentText().startswith("Behavior"):
- self.curve_behavior.setData([], [])
- return
- info = self._match_behavior_source(proc)
- if not info:
- self.curve_behavior.setData([], [])
- return
- behaviors = info.get("behaviors") or {}
- beh = self.combo_behavior_name.currentText().strip()
- if not beh and behaviors:
- beh = next(iter(behaviors.keys()))
- if beh not in behaviors:
- self.curve_behavior.setData([], [])
- return
- t_proc = np.asarray(proc.time, float)
- if t_proc.size == 0:
- self.curve_behavior.setData([], [])
- return
- kind = str(info.get("kind", _BEHAVIOR_PARSE_BINARY))
- if kind == _BEHAVIOR_PARSE_TIMESTAMPS:
- events = np.asarray(behaviors[beh], float)
- events = events[np.isfinite(events)]
- if events.size == 0:
- self.curve_behavior.setData([], [])
- return
- events = np.sort(np.unique(events))
- marker = np.zeros_like(t_proc, dtype=float)
- for ev in events:
- pos = int(np.searchsorted(t_proc, ev, side="left"))
- if pos <= 0:
- idx = 0
- elif pos >= t_proc.size:
- idx = t_proc.size - 1
- else:
- idx = pos if abs(float(t_proc[pos] - ev)) <= abs(float(t_proc[pos - 1] - ev)) else (pos - 1)
- marker[idx] = 1.0
- self.curve_behavior.setData(t_proc, marker, connect="finite", skipFiniteCheck=True)
- return
-
- t = np.asarray(info.get("time", np.array([], float)), float)
- if t.size == 0:
- self.curve_behavior.setData([], [])
- return
- b = np.asarray(behaviors[beh], float)
- b_interp = np.interp(t_proc, t, b)
- self.curve_behavior.setData(t_proc, b_interp, connect="finite", skipFiniteCheck=True)
+ # The trace preview should only show the processed signal trace.
+ # Behavior/event data remain available for alignment and analysis,
+ # but the binary edge overlay is intentionally not rendered here.
+ self.curve_behavior.setData([], [])
def _resolve_signal_detection_targets(self) -> List[Tuple[str, np.ndarray, np.ndarray]]:
targets: List[Tuple[str, np.ndarray, np.ndarray]] = []
@@ -5195,14 +5433,14 @@ def _resolve_signal_detection_targets(self) -> List[Tuple[str, np.ndarray, np.nd
targets.append((file_id, np.asarray(proc.time, float), np.asarray(proc.output, float)))
return targets
- def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
t = np.asarray(t, float)
y = np.asarray(y, float)
m = np.isfinite(t) & np.isfinite(y)
t = t[m]
y = y[m]
if t.size < 3:
- return np.array([], float), np.array([], float)
+ return np.array([], float), np.array([], float), np.array([], float)
dt = float(np.nanmedian(np.diff(t))) if t.size > 2 else np.nan
if not np.isfinite(dt) or dt <= 0:
@@ -5214,6 +5452,7 @@ def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np
if win % 2 == 0:
win += 1
+ y_trace = y.copy()
y_proc = y.copy()
if baseline_mode.endswith("rolling median"):
try:
@@ -5237,7 +5476,121 @@ def _preprocess_signal_for_peaks(self, t: np.ndarray, y: np.ndarray) -> Tuple[np
except Exception:
pass
- return t, y_proc
+ return t, y_proc, y_trace
+
+ def _signal_baseline_prominence_stats(
+ self,
+ t: np.ndarray,
+ y: np.ndarray,
+ min_prominence: float,
+ ) -> Dict[str, float]:
+ t = np.asarray(t, float)
+ y = np.asarray(y, float)
+ finite = np.isfinite(t) & np.isfinite(y)
+ if t.size < 3 or y.size != t.size or not np.any(finite):
+ return {
+ "scale": np.nan,
+ "baseline_median": np.nan,
+ "n_baseline_peaks": 0.0,
+ "baseline_duration_s": 0.0,
+ "scale_source": "unavailable",
+ "mad_noise_sigma": np.nan,
+ }
+
+ t_finite = t[finite]
+ t0 = float(np.nanmin(t_finite))
+ baseline_window = max(0.1, float(self.spin_peak_baseline_window.value()))
+ keep = finite & (t <= t0 + baseline_window)
+ if np.sum(keep) < 3:
+ keep = finite
+
+ baseline = y[keep]
+ baseline_median = float(np.nanmedian(baseline)) if baseline.size else np.nan
+ if not np.isfinite(baseline_median):
+ return {
+ "scale": np.nan,
+ "baseline_median": np.nan,
+ "n_baseline_peaks": 0.0,
+ "baseline_duration_s": 0.0,
+ "scale_source": "unavailable",
+ "mad_noise_sigma": np.nan,
+ }
+
+ centered_baseline = np.asarray(baseline, float) - baseline_median
+ mad_stats = self._signal_mad_noise_stats(centered_baseline)
+ mad_sigma = float(mad_stats.get("noise_sigma", np.nan))
+ if not np.isfinite(mad_sigma) or mad_sigma <= 1e-12:
+ full_centered = np.asarray(y[finite], float) - float(np.nanmedian(y[finite]))
+ full_mad_stats = self._signal_mad_noise_stats(full_centered)
+ mad_sigma = float(full_mad_stats.get("noise_sigma", np.nan))
+ try:
+ from scipy.signal import find_peaks
+ _, props = find_peaks(centered_baseline, prominence=max(0.0, float(min_prominence)))
+ except Exception:
+ props = {}
+ proms = np.asarray(props.get("prominences", np.array([], float)), float)
+ proms = proms[np.isfinite(proms) & (proms > 1e-12)]
+ if proms.size == 0:
+ scale = mad_sigma if np.isfinite(mad_sigma) and mad_sigma > 1e-12 else np.nan
+ scale_source = "mad_noise_fallback" if np.isfinite(scale) else "unavailable"
+ else:
+ top_count = max(1, int(np.ceil(proms.size * 0.10)))
+ scale = float(np.nanmean(np.sort(proms)[-top_count:]))
+ scale_source = "baseline_peak_prominence"
+
+ duration = float(np.nanmax(t[keep]) - np.nanmin(t[keep])) if np.sum(keep) >= 2 else 0.0
+ return {
+ "scale": scale,
+ "baseline_median": baseline_median,
+ "n_baseline_peaks": float(proms.size),
+ "baseline_duration_s": duration,
+ "scale_source": scale_source,
+ "mad_noise_sigma": mad_sigma,
+ }
+
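The normalization scale computed in `_signal_baseline_prominence_stats` can be illustrated in isolation: find peaks in a median-centered baseline, average the largest 10% of their prominences, and fall back to a MAD-derived sigma when the baseline holds no usable peaks. The free-function name and the flat-array interface here are hypothetical; the real method additionally restricts the samples to a baseline time window.

```python
import numpy as np
from scipy.signal import find_peaks


def baseline_prominence_scale(baseline: np.ndarray, min_prominence: float = 0.0) -> float:
    """Amplitude scale from the top 10% of baseline peak prominences."""
    y = np.asarray(baseline, float)
    y = y[np.isfinite(y)]
    centered = y - np.median(y)
    _, props = find_peaks(centered, prominence=max(0.0, float(min_prominence)))
    proms = np.asarray(props.get("prominences", []), float)
    proms = proms[np.isfinite(proms) & (proms > 1e-12)]
    if proms.size:
        top = max(1, int(np.ceil(proms.size * 0.10)))  # largest 10%, at least one
        return float(np.mean(np.sort(proms)[-top:]))
    # Fallback: robust noise sigma from the median absolute deviation.
    mad = float(np.median(np.abs(centered - np.median(centered))))
    return 1.4826 * mad
```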
+ @staticmethod
+ def _signal_mad_noise_stats(y: np.ndarray) -> Dict[str, float]:
+ arr = np.asarray(y, float)
+ arr = arr[np.isfinite(arr)]
+ if arr.size < 5:
+ return {"center": np.nan, "mad": np.nan, "noise_sigma": np.nan, "n_samples": float(arr.size)}
+
+ center = float(np.nanmedian(arr))
+ abs_dev = np.abs(arr - center)
+ mad = float(np.nanmedian(abs_dev))
+ sigma = 1.4826 * mad
+
+ # Re-estimate from the central mass so large transients do not inflate the noise estimate.
+ if np.isfinite(sigma) and sigma > 1e-12:
+ keep = abs_dev <= (3.0 * sigma)
+ if np.sum(keep) >= max(5, int(0.10 * arr.size)):
+ core = arr[keep]
+ center = float(np.nanmedian(core))
+ mad = float(np.nanmedian(np.abs(core - center)))
+ sigma = 1.4826 * mad
+
+ if not np.isfinite(sigma) or sigma <= 1e-12:
+ q25, q75 = np.nanpercentile(arr, [25.0, 75.0])
+ sigma = float((q75 - q25) / 1.349) if np.isfinite(q25) and np.isfinite(q75) else np.nan
+ if not np.isfinite(sigma) or sigma <= 1e-12:
+ sigma = float(np.nanstd(arr))
+ if not np.isfinite(sigma) or sigma <= 1e-12:
+ sigma = np.nan
+
+ return {
+ "center": center,
+ "mad": mad,
+ "noise_sigma": sigma,
+ "n_samples": float(arr.size),
+ }
+
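`_signal_mad_noise_stats` rests on a two-pass robust estimate; a compact sketch of just the sigma part (with a hypothetical `robust_noise_sigma` name): compute a first MAD-based sigma, then re-estimate on the samples within 3 sigma so sparse large transients cannot inflate the noise floor. The factor 1.4826 scales a Gaussian MAD to a standard deviation.

```python
import numpy as np


def robust_noise_sigma(y: np.ndarray) -> float:
    """Outlier-resistant noise sigma via a two-pass MAD estimate."""
    arr = np.asarray(y, float)
    arr = arr[np.isfinite(arr)]
    center = float(np.median(arr))
    sigma = 1.4826 * float(np.median(np.abs(arr - center)))
    if sigma > 0.0:
        # Second pass on the central mass only (within 3 sigma), so large
        # transients do not bias the estimate upward.
        core = arr[np.abs(arr - center) <= 3.0 * sigma]
        if core.size >= 5:
            core_center = float(np.median(core))
            sigma = 1.4826 * float(np.median(np.abs(core - core_center)))
    return sigma
```

On a unit-variance Gaussian trace, the estimate stays near 1.0 even when a sparse set of large transients is added, which is exactly the property the auto-MAD prominence threshold relies on.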
+ @staticmethod
+ def _trapz_area(y: np.ndarray, x: np.ndarray) -> float:
+ try:
+ return float(np.trapezoid(y, x))
+ except AttributeError:
+ return float(np.trapz(y, x))
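NumPy 2.0 renamed `np.trapz` to `np.trapezoid`, which is the version skew the `_trapz_area` shim above guards against. A standalone equivalent (hypothetical free-function name) for reference:

```python
import numpy as np


def trapz_area(y, x) -> float:
    """Trapezoidal AUC that works on NumPy both before and after 2.0."""
    try:
        return float(np.trapezoid(y, x))  # NumPy >= 2.0
    except AttributeError:
        return float(np.trapz(y, x))  # NumPy < 2.0
```

For a linear trace `y = x` on [0, 1] both paths give the exact area 0.5, since the trapezoid rule is exact for piecewise-linear integrands.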
def _refresh_signal_overlay(self) -> None:
for ln in self._signal_peak_lines:
@@ -5246,21 +5599,32 @@ def _refresh_signal_overlay(self) -> None:
except Exception:
pass
self._signal_peak_lines = []
+ for item in self._signal_noise_items:
+ try:
+ self.plot_trace.removeItem(item)
+ except Exception:
+ pass
+ self._signal_noise_items = []
self.curve_peak_markers.setData([], [])
- if not self.cb_peak_overlay.isChecked():
- return
if not self.last_signal_events or not self._processed:
return
- current_file = os.path.splitext(os.path.basename(self._processed[0].path))[0] if self._processed[0].path else "import"
+ current_file = self._current_signal_overlay_file_id()
+ self._draw_signal_noise_overlay(current_file)
+
+ if not self.cb_peak_overlay.isChecked():
+ return
+
file_ids = self.last_signal_events.get("file_ids", [])
times = np.asarray(self.last_signal_events.get("peak_times_sec", np.array([], float)), float)
- heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float)
+ heights = np.asarray(self.last_signal_events.get("peak_trace_values", np.array([], float)), float)
+ if heights.size != times.size:
+ heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float)
if times.size == 0 or heights.size == 0:
return
- if file_ids and len(file_ids) == times.size:
- mask = np.asarray([fid == current_file or fid == "psth_trace" for fid in file_ids], bool)
+ if current_file and file_ids and len(file_ids) == times.size:
+ mask = np.asarray([str(fid) == current_file for fid in file_ids], bool)
times = times[mask]
heights = heights[mask]
if times.size == 0:
@@ -5275,6 +5639,55 @@ def _refresh_signal_overlay(self) -> None:
self.plot_trace.addItem(ln)
self._signal_peak_lines.append(ln)
+ def _draw_signal_noise_overlay(self, current_file: str) -> None:
+ cb = getattr(self, "cb_peak_noise_overlay", None)
+ if cb is None or not cb.isChecked():
+ return
+ overlays = self.last_signal_events.get("noise_overlay_by_file", {}) if self.last_signal_events else {}
+ if not isinstance(overlays, dict) or not overlays:
+ return
+ overlay = overlays.get(str(current_file or ""))
+ if overlay is None and len(overlays) == 1:
+ overlay = next(iter(overlays.values()))
+ if not isinstance(overlay, dict):
+ return
+
+ t = np.asarray(overlay.get("time", np.array([], float)), float)
+ y = np.asarray(overlay.get("detection_trace", np.array([], float)), float)
+ if t.size != y.size or t.size < 2:
+ return
+
+ step = max(1, int(np.ceil(t.size / 5000)))
+ trace_item = pg.PlotDataItem(
+ t[::step],
+ y[::step],
+ pen=pg.mkPen((80, 220, 220, 150), width=1.0, style=QtCore.Qt.PenStyle.DashLine),
+ name="detection trace",
+ )
+ trace_item.setZValue(8)
+ self.plot_trace.addItem(trace_item)
+ self._signal_noise_items.append(trace_item)
+
+ center = float(overlay.get("center", np.nan))
+ sigma = float(overlay.get("noise_sigma", np.nan))
+ used_prominence = float(overlay.get("used_prominence", np.nan))
+ for y0, color, width, style in (
+ (center, (80, 220, 220, 170), 1.0, QtCore.Qt.PenStyle.DotLine),
+ (center + sigma, (80, 220, 220, 110), 0.8, QtCore.Qt.PenStyle.DotLine),
+ (center - sigma, (80, 220, 220, 110), 0.8, QtCore.Qt.PenStyle.DotLine),
+ (center + used_prominence, (255, 210, 80, 190), 1.2, QtCore.Qt.PenStyle.DashLine),
+ ):
+ if not np.isfinite(y0):
+ continue
+ ln = pg.InfiniteLine(
+ pos=float(y0),
+ angle=0,
+ pen=pg.mkPen(color, width=width, style=style),
+ movable=False,
+ )
+ ln.setZValue(9)
+ self.plot_trace.addItem(ln)
+ self._signal_noise_items.append(ln)
+
def _detect_signal_events(self) -> None:
self.last_signal_events = None
targets = self._resolve_signal_detection_targets()
@@ -5287,20 +5700,57 @@ def _detect_signal_events(self) -> None:
all_times: List[float] = []
all_idx: List[int] = []
all_heights: List[float] = []
+ all_signal_heights: List[float] = []
+ all_trace_values: List[float] = []
all_proms: List[float] = []
+ all_norm_proms: List[float] = []
+ all_norm_scales: List[float] = []
+ all_mad_sigmas: List[float] = []
+ all_auto_prominence_thresholds: List[float] = []
all_widths_sec: List[float] = []
all_auc: List[float] = []
all_file_ids: List[str] = []
+ normalization_by_file: Dict[str, Dict[str, float]] = {}
+ mad_threshold_by_file: Dict[str, Dict[str, float]] = {}
+ noise_overlay_by_file: Dict[str, Dict[str, object]] = {}
+ normalize_amplitude = bool(self.cb_peak_norm_prominence.isChecked())
+ auto_mad = bool(self.cb_peak_auto_mad.isChecked())
+ mad_multiplier = float(self.spin_peak_mad_multiplier.value())
for file_id, t_raw, y_raw in targets:
- t, y = self._preprocess_signal_for_peaks(t_raw, y_raw)
+ t, y, y_trace = self._preprocess_signal_for_peaks(t_raw, y_raw)
if t.size < 5:
continue
dt = float(np.nanmedian(np.diff(t)))
if not np.isfinite(dt) or dt <= 0:
continue
- prominence = max(0.0, float(self.spin_peak_prominence.value()))
+ manual_prominence = max(0.0, float(self.spin_peak_prominence.value()))
+ prominence = manual_prominence
+ mad_stats = self._signal_mad_noise_stats(y)
+ mad_sigma = float(mad_stats.get("noise_sigma", np.nan))
+ auto_prominence = np.nan
+ if auto_mad:
+ if np.isfinite(mad_sigma) and mad_sigma > 1e-12:
+ auto_prominence = max(0.0, mad_multiplier * mad_sigma)
+ prominence = auto_prominence
+ mad_threshold_by_file[str(file_id)] = {
+ **mad_stats,
+ "multiplier": mad_multiplier,
+ "auto_prominence": auto_prominence,
+ "fallback_manual_prominence": manual_prominence,
+ "used_prominence": prominence,
+ }
+ noise_overlay_by_file[str(file_id)] = {
+ "time": t.copy(),
+ "detection_trace": y.copy(),
+ "center": float(mad_stats.get("center", np.nan)),
+ "mad": float(mad_stats.get("mad", np.nan)),
+ "noise_sigma": mad_sigma,
+ "auto_prominence": auto_prominence,
+ "manual_prominence": manual_prominence,
+ "used_prominence": prominence,
+ }
min_height = float(self.spin_peak_height.value())
min_distance_sec = max(0.0, float(self.spin_peak_distance.value()))
min_dist_samples = max(1, int(round(min_distance_sec / dt))) if min_distance_sec > 0 else None
@@ -5325,8 +5775,24 @@ def _detect_signal_events(self) -> None:
if peaks.size == 0:
continue
- p_heights = y[peaks]
+ p_signal_heights = np.asarray(y[peaks], float)
+ p_trace_values = np.asarray(y_trace[peaks], float) if y_trace.size == y.size else p_signal_heights.copy()
p_proms = np.asarray(props.get("prominences", np.full(peaks.size, np.nan)), float)
+ p_heights = p_signal_heights.copy()
+ p_norm_proms = np.full(peaks.size, np.nan, float)
+ p_norm_scales = np.full(peaks.size, np.nan, float)
+
+ if normalize_amplitude:
+ norm_stats = self._signal_baseline_prominence_stats(t, y, prominence)
+ normalization_by_file[str(file_id)] = norm_stats
+ scale = float(norm_stats.get("scale", np.nan))
+ baseline_median = float(norm_stats.get("baseline_median", np.nan))
+ if np.isfinite(scale) and scale > 1e-12 and np.isfinite(baseline_median):
+ p_heights = (p_signal_heights - baseline_median) / scale
+ p_norm_proms = p_proms / scale
+ p_norm_scales[:] = scale
+ else:
+ p_heights = np.full(peaks.size, np.nan, float)
try:
widths_samp = peak_widths(y, peaks, rel_height=0.5)[0]
widths_sec = np.asarray(widths_samp, float) * dt
@@ -5345,12 +5811,18 @@ def _detect_signal_events(self) -> None:
if i1 - i0 < 2:
auc_vals.append(np.nan)
continue
- auc_vals.append(float(np.trapz(y[i0:i1], t[i0:i1])))
+ auc_vals.append(self._trapz_area(y[i0:i1], t[i0:i1]))
all_times.extend(t[peaks].tolist())
all_idx.extend(peaks.tolist())
all_heights.extend(np.asarray(p_heights, float).tolist())
+ all_signal_heights.extend(np.asarray(p_signal_heights, float).tolist())
+ all_trace_values.extend(np.asarray(p_trace_values, float).tolist())
all_proms.extend(np.asarray(p_proms, float).tolist())
+ all_norm_proms.extend(np.asarray(p_norm_proms, float).tolist())
+ all_norm_scales.extend(np.asarray(p_norm_scales, float).tolist())
+ all_mad_sigmas.extend([mad_sigma] * peaks.size)
+ all_auto_prominence_thresholds.extend([auto_prominence] * peaks.size)
all_widths_sec.extend(np.asarray(widths_sec, float).tolist())
all_auc.extend(np.asarray(auc_vals, float).tolist())
all_file_ids.extend([file_id] * peaks.size)
@@ -5364,7 +5836,13 @@ def _detect_signal_events(self) -> None:
peak_times = np.asarray(all_times, float)
peak_idx = np.asarray(all_idx, int)
peak_heights = np.asarray(all_heights, float)
+ peak_signal_heights = np.asarray(all_signal_heights, float)
+ peak_trace_values = np.asarray(all_trace_values, float)
peak_proms = np.asarray(all_proms, float)
+ peak_norm_proms = np.asarray(all_norm_proms, float)
+ peak_norm_scales = np.asarray(all_norm_scales, float)
+ peak_mad_sigmas = np.asarray(all_mad_sigmas, float)
+ peak_auto_prominence_thresholds = np.asarray(all_auto_prominence_thresholds, float)
peak_widths_sec = np.asarray(all_widths_sec, float)
peak_auc = np.asarray(all_auc, float)
@@ -5372,7 +5850,13 @@ def _detect_signal_events(self) -> None:
peak_times = peak_times[sort_idx]
peak_idx = peak_idx[sort_idx]
peak_heights = peak_heights[sort_idx]
+ peak_signal_heights = peak_signal_heights[sort_idx]
+ peak_trace_values = peak_trace_values[sort_idx]
peak_proms = peak_proms[sort_idx]
+ peak_norm_proms = peak_norm_proms[sort_idx]
+ peak_norm_scales = peak_norm_scales[sort_idx]
+ peak_mad_sigmas = peak_mad_sigmas[sort_idx]
+ peak_auto_prominence_thresholds = peak_auto_prominence_thresholds[sort_idx]
peak_widths_sec = peak_widths_sec[sort_idx]
peak_auc = peak_auc[sort_idx]
all_file_ids = [all_file_ids[i] for i in sort_idx]
@@ -5389,31 +5873,115 @@ def _detect_signal_events(self) -> None:
"peak_frequency_per_min": freq_per_min,
"mean_inter_peak_interval_s": float(np.nanmean(ipi)) if ipi.size else np.nan,
"mean_auc": float(np.nanmean(peak_auc)) if np.any(np.isfinite(peak_auc)) else np.nan,
+ "baseline_prominence_normalized": bool(normalize_amplitude),
+ "mad_auto_threshold_enabled": bool(auto_mad),
}
+ if auto_mad:
+ metrics.update(
+ {
+ "mad_multiplier": float(mad_multiplier),
+ "mean_mad_noise_sigma": (
+ float(np.nanmean(peak_mad_sigmas)) if np.any(np.isfinite(peak_mad_sigmas)) else np.nan
+ ),
+ "mean_auto_prominence_threshold": (
+ float(np.nanmean(peak_auto_prominence_thresholds))
+ if np.any(np.isfinite(peak_auto_prominence_thresholds))
+ else np.nan
+ ),
+ "mad_threshold_files_with_estimate": float(
+ sum(
+ 1
+ for stats in mad_threshold_by_file.values()
+ if np.isfinite(float(stats.get("noise_sigma", np.nan)))
+ )
+ ),
+ }
+ )
+ if normalize_amplitude:
+ metrics.update(
+ {
+ "mean_raw_amplitude": float(np.nanmean(peak_signal_heights)),
+ "median_raw_amplitude": float(np.nanmedian(peak_signal_heights)),
+ "mean_normalized_prominence": (
+ float(np.nanmean(peak_norm_proms)) if np.any(np.isfinite(peak_norm_proms)) else np.nan
+ ),
+ "mean_baseline_prominence_scale": (
+ float(np.nanmean(peak_norm_scales)) if np.any(np.isfinite(peak_norm_scales)) else np.nan
+ ),
+ "baseline_prominence_files_with_scale": float(
+ sum(
+ 1
+ for stats in normalization_by_file.values()
+ if np.isfinite(float(stats.get("scale", np.nan)))
+ )
+ ),
+ "baseline_prominence_files_with_peak_scale": float(
+ sum(
+ 1
+ for stats in normalization_by_file.values()
+ if str(stats.get("scale_source", "")) == "baseline_peak_prominence"
+ )
+ ),
+ "baseline_prominence_files_with_mad_fallback": float(
+ sum(
+ 1
+ for stats in normalization_by_file.values()
+ if str(stats.get("scale_source", "")) == "mad_noise_fallback"
+ )
+ ),
+ }
+ )
self.last_signal_events = {
"peak_times_sec": peak_times,
"peak_indices": peak_idx,
"peak_heights": peak_heights,
+ "peak_signal_heights": peak_signal_heights,
+ "peak_trace_values": peak_trace_values,
"peak_prominences": peak_proms,
+ "peak_normalized_prominences": peak_norm_proms,
+ "peak_baseline_prominence_scale": peak_norm_scales,
+ "peak_mad_noise_sigma": peak_mad_sigmas,
+ "peak_auto_prominence_threshold": peak_auto_prominence_thresholds,
"peak_widths_sec": peak_widths_sec,
"peak_auc": peak_auc,
"file_ids": all_file_ids,
"derived_metrics": metrics,
+ "normalization_by_file": normalization_by_file,
+ "mad_threshold_by_file": mad_threshold_by_file,
+ "noise_overlay_by_file": noise_overlay_by_file,
"params": {
"method": self.combo_signal_method.currentText(),
"prominence": float(self.spin_peak_prominence.value()),
+ "auto_mad_threshold": bool(auto_mad),
+ "mad_multiplier": float(mad_multiplier),
"min_height": float(self.spin_peak_height.value()),
"min_distance_sec": float(self.spin_peak_distance.value()),
"smooth_sigma_sec": float(self.spin_peak_smooth.value()),
"baseline_mode": self.combo_peak_baseline.currentText(),
"baseline_window_sec": float(self.spin_peak_baseline_window.value()),
+ "baseline_prominence_normalized": bool(normalize_amplitude),
"rate_bin_sec": float(self.spin_peak_rate_bin.value()),
"auc_half_window_sec": float(self.spin_peak_auc_window.value()),
},
}
- self.statusUpdate.emit(f"Detected {peak_times.size} peak(s).", 5000)
+ msg = f"Detected {peak_times.size} peak(s)."
+ if auto_mad:
+ auto_vals = [
+ float(stats.get("auto_prominence", np.nan))
+ for stats in mad_threshold_by_file.values()
+ if np.isfinite(float(stats.get("auto_prominence", np.nan)))
+ ]
+ if auto_vals:
+ msg += f" Auto-MAD prominence ~{float(np.nanmean(auto_vals)):.4g}."
+ else:
+ msg += " Auto-MAD estimate unavailable; manual prominence used."
+ if normalize_amplitude and not any(np.isfinite(float(s.get("scale", np.nan))) for s in normalization_by_file.values()):
+ msg += " Baseline prominence scale unavailable."
+ elif normalize_amplitude and any(str(s.get("scale_source", "")) == "mad_noise_fallback" for s in normalization_by_file.values()):
+ msg += " Normalization used MAD fallback for files without baseline peaks."
+ self.statusUpdate.emit(msg, 5000)
self._refresh_signal_overlay()
self._render_signal_event_plots()
self._update_signal_metrics_table()
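For context, the auto-MAD prominence reported in the status message above amounts to "threshold = multiplier × robust noise sigma". A minimal sketch of that idea follows; `mad_noise_sigma` and `auto_prominence` are illustrative names, and the GUI's actual per-file estimator may work on a smoothed or detrended trace rather than the raw samples:

```python
import numpy as np

def mad_noise_sigma(trace):
    """Robust noise estimate: scaled median absolute deviation (MAD).

    The 1.4826 factor makes the MAD a consistent estimator of the
    standard deviation for Gaussian noise.
    """
    x = np.asarray(trace, float)
    x = x[np.isfinite(x)]
    if x.size == 0:
        return np.nan
    med = np.median(x)
    return 1.4826 * np.median(np.abs(x - med))

def auto_prominence(trace, multiplier=5.0):
    """Prominence threshold as a multiple of the MAD noise sigma."""
    sigma = mad_noise_sigma(trace)
    return multiplier * sigma if np.isfinite(sigma) else np.nan
```

When the sigma estimate is non-finite (e.g. an all-NaN file), the sketch returns NaN, which matches the diff's "Auto-MAD estimate unavailable; manual prominence used" fallback message.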
@@ -5431,6 +5999,11 @@ def _render_signal_event_plots(self) -> None:
peak_heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float)
if peak_times.size == 0 or peak_heights.size == 0:
return
+ metrics = self.last_signal_events.get("derived_metrics", {}) or {}
+ if bool(metrics.get("baseline_prominence_normalized", False)):
+ self.plot_peak_amp.setLabel("bottom", "Prominence-normalized amplitude")
+ else:
+ self.plot_peak_amp.setLabel("bottom", "Amplitude")
def _bar_hist(plot: pg.PlotWidget, values: np.ndarray, color: Tuple[int, int, int]) -> None:
vals = np.asarray(values, float)
@@ -5461,10 +6034,13 @@ def _update_signal_metrics_table(self) -> None:
if not self.last_signal_events:
return
metrics = self.last_signal_events.get("derived_metrics", {}) or {}
+ normalized = bool(metrics.get("baseline_prominence_normalized", False))
+ amp_label = "mean amplitude (prom-norm)" if normalized else "mean amplitude"
+ med_amp_label = "median amplitude (prom-norm)" if normalized else "median amplitude"
rows = [
("number of peaks", metrics.get("number_of_peaks", np.nan)),
- ("mean amplitude", metrics.get("mean_amplitude", np.nan)),
- ("median amplitude", metrics.get("median_amplitude", np.nan)),
+ (amp_label, metrics.get("mean_amplitude", np.nan)),
+ (med_amp_label, metrics.get("median_amplitude", np.nan)),
("amplitude std", metrics.get("amplitude_std", np.nan)),
("mean prominence", metrics.get("mean_prominence", np.nan)),
("mean width at half prom (s)", metrics.get("mean_width_half_prom_s", np.nan)),
@@ -5472,6 +6048,26 @@ def _update_signal_metrics_table(self) -> None:
("mean inter-peak interval (s)", metrics.get("mean_inter_peak_interval_s", np.nan)),
("mean AUC", metrics.get("mean_auc", np.nan)),
]
+ if normalized:
+ rows.extend(
+ [
+ ("mean raw amplitude", metrics.get("mean_raw_amplitude", np.nan)),
+ ("mean normalized prominence", metrics.get("mean_normalized_prominence", np.nan)),
+ ("baseline prominence scale", metrics.get("mean_baseline_prominence_scale", np.nan)),
+ ("files with scale", metrics.get("baseline_prominence_files_with_scale", np.nan)),
+ ("files using baseline peaks", metrics.get("baseline_prominence_files_with_peak_scale", np.nan)),
+ ("files using MAD fallback", metrics.get("baseline_prominence_files_with_mad_fallback", np.nan)),
+ ]
+ )
+ if bool(metrics.get("mad_auto_threshold_enabled", False)):
+ rows.extend(
+ [
+ ("MAD multiplier", metrics.get("mad_multiplier", np.nan)),
+ ("mean MAD noise sigma", metrics.get("mean_mad_noise_sigma", np.nan)),
+ ("mean auto prominence", metrics.get("mean_auto_prominence_threshold", np.nan)),
+ ("files with MAD estimate", metrics.get("mad_threshold_files_with_estimate", np.nan)),
+ ]
+ )
for key, value in rows:
r = self.tbl_signal_metrics.rowCount()
self.tbl_signal_metrics.insertRow(r)
@@ -5492,8 +6088,33 @@ def _export_signal_events_csv(self) -> None:
self._remember_export_dir(out_dir)
peak_times = np.asarray(self.last_signal_events.get("peak_times_sec", np.array([], float)), float)
peak_heights = np.asarray(self.last_signal_events.get("peak_heights", np.array([], float)), float)
+ peak_signal_heights = np.asarray(
+ self.last_signal_events.get("peak_signal_heights", peak_heights),
+ float,
+ )
+ peak_trace_values = np.asarray(
+ self.last_signal_events.get("peak_trace_values", peak_signal_heights),
+ float,
+ )
peak_proms = np.asarray(self.last_signal_events.get("peak_prominences", np.array([], float)), float)
+ peak_norm_proms = np.asarray(
+ self.last_signal_events.get("peak_normalized_prominences", np.full_like(peak_proms, np.nan)),
+ float,
+ )
+ peak_norm_scales = np.asarray(
+ self.last_signal_events.get("peak_baseline_prominence_scale", np.full_like(peak_heights, np.nan)),
+ float,
+ )
+ peak_mad_sigmas = np.asarray(
+ self.last_signal_events.get("peak_mad_noise_sigma", np.full_like(peak_heights, np.nan)),
+ float,
+ )
+ peak_auto_prominence = np.asarray(
+ self.last_signal_events.get("peak_auto_prominence_threshold", np.full_like(peak_heights, np.nan)),
+ float,
+ )
peak_widths = np.asarray(self.last_signal_events.get("peak_widths_sec", np.array([], float)), float)
+ peak_auc = np.asarray(self.last_signal_events.get("peak_auc", np.array([], float)), float)
file_ids = self.last_signal_events.get("file_ids", [])
if peak_times.size == 0:
return
@@ -5504,7 +6125,22 @@ def _export_signal_events_csv(self) -> None:
import csv
with open(out_path, "w", newline="") as f:
w = csv.writer(f)
- w.writerow(["peak_time_sec", "height", "prominence", "width_sec", "file_id"])
+ w.writerow(
+ [
+ "peak_time_sec",
+ "height",
+ "prominence",
+ "width_sec",
+ "auc",
+ "trace_value",
+ "signal_height",
+ "normalized_prominence",
+ "baseline_prominence_scale",
+ "mad_noise_sigma",
+ "auto_prominence_threshold",
+ "file_id",
+ ]
+ )
for i in range(peak_times.size):
fid = file_ids[i] if isinstance(file_ids, list) and i < len(file_ids) else ""
w.writerow(
@@ -5513,6 +6149,13 @@ def _export_signal_events_csv(self) -> None:
float(peak_heights[i]) if i < peak_heights.size else np.nan,
float(peak_proms[i]) if i < peak_proms.size else np.nan,
float(peak_widths[i]) if i < peak_widths.size else np.nan,
+ float(peak_auc[i]) if i < peak_auc.size else np.nan,
+ float(peak_trace_values[i]) if i < peak_trace_values.size else np.nan,
+ float(peak_signal_heights[i]) if i < peak_signal_heights.size else np.nan,
+ float(peak_norm_proms[i]) if i < peak_norm_proms.size else np.nan,
+ float(peak_norm_scales[i]) if i < peak_norm_scales.size else np.nan,
+ float(peak_mad_sigmas[i]) if i < peak_mad_sigmas.size else np.nan,
+ float(peak_auto_prominence[i]) if i < peak_auto_prominence.size else np.nan,
fid,
]
)
@@ -5805,6 +6448,7 @@ def _compute_psth(self) -> None:
if hasattr(self, "lbl_global_metrics"):
self.lbl_global_metrics.setText("Global metrics: -")
self._update_status_strip()
+ self._sync_temporal_modeling_context()
return
# Update trace preview each time (also updates event lines)
@@ -5893,6 +6537,7 @@ def _compute_psth(self) -> None:
self._last_event_rows = []
self._last_durations = np.array([], float)
self._update_status_strip()
+ self._sync_temporal_modeling_context()
return
mat_events = np.vstack(mats)
@@ -5906,6 +6551,7 @@ def _compute_psth(self) -> None:
self._last_event_rows = []
self._last_durations = np.array([], float)
self._update_status_strip()
+ self._sync_temporal_modeling_context()
return
mat_display = self._group_mat
display_labels = self._group_labels
@@ -5953,13 +6599,24 @@ def _compute_psth(self) -> None:
self._update_metric_regions()
self._update_status_strip()
self._save_settings()
+ self._sync_temporal_modeling_context()
except Exception as e:
self.statusUpdate.emit(f"Postprocessing error: {e}", 5000)
self._update_status_strip()
def _render_heatmap(self, mat: np.ndarray, tvec: np.ndarray, labels: Optional[List[str]] = None) -> None:
if mat.size == 0:
- self.img.setImage(np.zeros((1, 1)))
+ self._suppress_heatmap_level_store = True
+ try:
+ self.img.setImage(np.zeros((1, 1)), autoLevels=False)
+ self.img.setLevels([0.0, 1.0])
+ if hasattr(self, "heat_lut") and getattr(self.heat_lut, "item", None) is not None:
+ try:
+ self.heat_lut.item.setLevels(0.0, 1.0)
+ except TypeError:
+ self.heat_lut.item.setLevels((0.0, 1.0))
+ finally:
+ self._suppress_heatmap_level_store = False
self.heat_zero_line.setVisible(False)
return
@@ -5974,12 +6631,38 @@ def _render_heatmap(self, mat: np.ndarray, tvec: np.ndarray, labels: Optional[Li
self.heat_lut.item.gradient.setColorMap(cmap)
except Exception:
pass
+ finite = img[np.isfinite(img)]
+ if finite.size:
+ lo = float(np.nanmin(finite))
+ hi = float(np.nanmax(finite))
+ else:
+ lo, hi = 0.0, 1.0
+ manual_levels = bool(self._style.get("heatmap_levels_manual", False))
+ if not manual_levels:
+ self._style["heatmap_min"] = None
+ self._style["heatmap_max"] = None
hmin = self._style.get("heatmap_min", None)
hmax = self._style.get("heatmap_max", None)
- # set image (auto-level)
- self.img.setImage(img, autoLevels=True)
- if hmin is not None and hmax is not None:
- self.img.setLevels([float(hmin), float(hmax)])
+ if manual_levels and hmin is not None and hmax is not None:
+ try:
+ lo = float(hmin)
+ hi = float(hmax)
+ except Exception:
+ pass
+ if (not np.isfinite(lo)) or (not np.isfinite(hi)) or hi <= lo:
+ center = lo if np.isfinite(lo) else 0.0
+ lo, hi = center - 0.5, center + 0.5
+ self._suppress_heatmap_level_store = True
+ try:
+ self.img.setImage(img, autoLevels=False)
+ self.img.setLevels([lo, hi])
+ if hasattr(self, "heat_lut") and getattr(self.heat_lut, "item", None) is not None:
+ try:
+ self.heat_lut.item.setLevels(lo, hi)
+ except TypeError:
+ self.heat_lut.item.setLevels((lo, hi))
+ finally:
+ self._suppress_heatmap_level_store = False
# Map image to time axis using a rect (avoids scale() incompatibilities)
x0 = float(tvec[0]) if tvec.size else 0.0
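The level-guard logic added to `_render_heatmap` above (finite min/max over the image, optional manual override, and a half-unit fallback span for degenerate data) reduces to something like this standalone sketch; `safe_levels` is an illustrative name, not part of the codebase:

```python
import numpy as np

def safe_levels(img, manual=None):
    """Pick finite (lo, hi) display levels for a heatmap.

    Falls back to a half-unit span around a center value whenever the
    candidate levels are non-finite or collapsed (hi <= lo).
    """
    vals = np.asarray(img, float)
    finite = vals[np.isfinite(vals)]
    lo, hi = (float(finite.min()), float(finite.max())) if finite.size else (0.0, 1.0)
    if manual is not None:  # user-pinned levels take precedence
        lo, hi = float(manual[0]), float(manual[1])
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        center = lo if np.isfinite(lo) else 0.0
        lo, hi = center - 0.5, center + 0.5
    return lo, hi
```

Guarding against `hi <= lo` matters because pyqtgraph level widgets behave poorly when given a zero-width or NaN range, which is exactly the situation an all-NaN or constant PSTH matrix produces.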
@@ -6094,6 +6777,7 @@ def _render_metrics(self, mat: np.ndarray, tvec: np.ndarray) -> None:
self.metrics_scatter_post.setData([], [])
self._set_error_bar(self.metrics_err_pre, 0.0, 0.0, 0.0)
self._set_error_bar(self.metrics_err_post, 1.0, 0.0, 0.0)
+ self.metrics_p_text.setVisible(False)
self._last_metrics = None
return
metric = self.combo_metric.currentText()
@@ -6118,6 +6802,7 @@ def _window_vals(a: float, b: float) -> np.ndarray:
self.metrics_scatter_post.setData([], [])
self._set_error_bar(self.metrics_err_pre, 0.0, 0.0, 0.0)
self._set_error_bar(self.metrics_err_post, 1.0, 0.0, 0.0)
+ self.metrics_p_text.setVisible(False)
self._last_metrics = None
return
@@ -6146,10 +6831,24 @@ def _metric_vals(win: np.ndarray, duration: float) -> np.ndarray:
pair_mask = np.isfinite(pre_vals_all) & np.isfinite(post_vals_all)
pre_pair = pre_vals_all[pair_mask]
post_pair = post_vals_all[pair_mask]
+ p_value = np.nan
+ n_pair = int(min(pre_pair.size, post_pair.size))
if pre_pair.size and post_pair.size:
- n_pair = int(min(pre_pair.size, post_pair.size))
pre_pair = pre_pair[:n_pair]
post_pair = post_pair[:n_pair]
+ if n_pair >= 2:
+ diffs = post_pair - pre_pair
+ diffs = diffs[np.isfinite(diffs)]
+ if diffs.size >= 2:
+ if float(np.nanstd(diffs, ddof=1)) == 0.0:
+ p_value = 1.0 if float(np.nanmean(diffs)) == 0.0 else 0.0
+ else:
+ try:
+ from scipy import stats
+ result = stats.ttest_rel(pre_pair, post_pair, nan_policy="omit")
+ p_value = float(result.pvalue)
+ except Exception:
+ p_value = np.nan
# Build segmented polyline: (0, pre_i) -> (1, post_i), NaN separator.
x_line = np.empty(n_pair * 3, dtype=float)
y_line = np.empty(n_pair * 3, dtype=float)
@@ -6207,6 +6906,17 @@ def _metric_vals(win: np.ndarray, duration: float) -> np.ndarray:
if ymin == ymax:
ymax = ymin + 1.0
self.plot_metrics.setYRange(ymin, ymax, padding=0.2)
+ span = float(ymax - ymin) if np.isfinite(ymax - ymin) and ymax != ymin else 1.0
+ if n_pair >= 2:
+ if np.isfinite(p_value):
+ p_label = "paired p < 1e-4" if p_value < 1e-4 else f"paired p = {p_value:.4g}"
+ else:
+ p_label = "paired p = n/a"
+ self.metrics_p_text.setText(p_label)
+ self.metrics_p_text.setPos(0.5, ymax + 0.14 * span)
+ self.metrics_p_text.setVisible(True)
+ else:
+ self.metrics_p_text.setVisible(False)
self._last_metrics = {
"pre": pre_mean,
"post": post_mean,
@@ -6214,6 +6924,8 @@ def _metric_vals(win: np.ndarray, duration: float) -> np.ndarray:
"post_sem": post_sem,
"pre_n": float(pre_n),
"post_n": float(post_n),
+ "paired_n": float(n_pair),
+ "paired_p": float(p_value) if np.isfinite(p_value) else np.nan,
"metric": metric,
}
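The paired-test path above short-circuits zero-variance differences before calling `scipy.stats.ttest_rel` (identical pairs give p = 1.0, a constant nonzero shift gives p = 0.0, since the t statistic is undefined when the difference variance is zero). A self-contained sketch of that fallback, assuming SciPy is available (`paired_p_value` is a hypothetical helper name):

```python
import numpy as np
from scipy import stats

def paired_p_value(pre, post):
    """Paired p-value with a guard for zero-variance differences."""
    pre = np.asarray(pre, float)
    post = np.asarray(post, float)
    n = min(pre.size, post.size)
    pre, post = pre[:n], post[:n]
    diffs = post - pre
    diffs = diffs[np.isfinite(diffs)]
    if diffs.size < 2:
        return np.nan  # not enough pairs for a test
    if float(np.std(diffs, ddof=1)) == 0.0:
        # Degenerate case: t statistic is undefined.
        return 1.0 if float(np.mean(diffs)) == 0.0 else 0.0
    return float(stats.ttest_rel(pre, post, nan_policy="omit").pvalue)
```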
@@ -6492,7 +7204,7 @@ def _apply_plot_style(self) -> None:
pass
def _on_heatmap_levels_changed(self) -> None:
- if self._is_restoring_settings:
+ if self._is_restoring_settings or self._suppress_heatmap_level_store:
return
if not hasattr(self, "heat_lut") or getattr(self.heat_lut, "item", None) is None:
return
@@ -6507,6 +7219,7 @@ def _on_heatmap_levels_changed(self) -> None:
if np.isfinite(lo) and np.isfinite(hi) and hi > lo:
self._style["heatmap_min"] = lo
self._style["heatmap_max"] = hi
+ self._style["heatmap_levels_manual"] = True
self._queue_settings_save()
def _open_style_dialog(self) -> None:
@@ -6515,6 +7228,7 @@ def _open_style_dialog(self) -> None:
return
self._style = dlg.get_style()
self._apply_plot_style()
+ self._record_history_change()
self._render_heatmap(self._last_mat if self._last_mat is not None else np.zeros((1, 1)), self._last_tvec if self._last_tvec is not None else np.array([0.0, 1.0]))
self._render_spatial_heatmap(
self._last_spatial_occupancy_map,
@@ -6657,7 +7371,13 @@ def _save_signal_events_h5(self, parent: h5py.Group) -> None:
"peak_times_sec",
"peak_indices",
"peak_heights",
+ "peak_signal_heights",
+ "peak_trace_values",
"peak_prominences",
+ "peak_normalized_prominences",
+ "peak_baseline_prominence_scale",
+ "peak_mad_noise_sigma",
+ "peak_auto_prominence_threshold",
"peak_widths_sec",
"peak_auc",
):
@@ -6665,6 +7385,16 @@ def _save_signal_events_h5(self, parent: h5py.Group) -> None:
self._write_h5_str_list(group, "file_ids", [str(v) for v in self.last_signal_events.get("file_ids", []) or []])
self._write_h5_json(group, "derived_metrics_json", dict(self.last_signal_events.get("derived_metrics", {}) or {}))
self._write_h5_json(group, "params_json", dict(self.last_signal_events.get("params", {}) or {}))
+ self._write_h5_json_any(
+ group,
+ "normalization_by_file_json",
+ dict(self.last_signal_events.get("normalization_by_file", {}) or {}),
+ )
+ self._write_h5_json_any(
+ group,
+ "mad_threshold_by_file_json",
+ dict(self.last_signal_events.get("mad_threshold_by_file", {}) or {}),
+ )
def _load_signal_events_h5(self, parent: Optional[h5py.Group]) -> Optional[Dict[str, object]]:
if parent is None:
@@ -6683,12 +7413,20 @@ def _num(name: str) -> np.ndarray:
"peak_times_sec": _num("peak_times_sec"),
"peak_indices": _num("peak_indices"),
"peak_heights": _num("peak_heights"),
+ "peak_signal_heights": _num("peak_signal_heights"),
+ "peak_trace_values": _num("peak_trace_values"),
"peak_prominences": _num("peak_prominences"),
+ "peak_normalized_prominences": _num("peak_normalized_prominences"),
+ "peak_baseline_prominence_scale": _num("peak_baseline_prominence_scale"),
+ "peak_mad_noise_sigma": _num("peak_mad_noise_sigma"),
+ "peak_auto_prominence_threshold": _num("peak_auto_prominence_threshold"),
"peak_widths_sec": _num("peak_widths_sec"),
"peak_auc": _num("peak_auc"),
"file_ids": self._read_h5_str_list(group, "file_ids"),
"derived_metrics": self._read_h5_json(group, "derived_metrics_json"),
"params": self._read_h5_json(group, "params_json"),
+ "normalization_by_file": self._read_h5_json_any(group, "normalization_by_file_json", {}),
+ "mad_threshold_by_file": self._read_h5_json_any(group, "mad_threshold_by_file_json", {}),
}
return out
@@ -6795,6 +7533,13 @@ def _restore_cached_analysis_outputs(self, payload: Dict[str, object]) -> None:
self.last_behavior_analysis = behavior_analysis
self._render_behavior_analysis_outputs()
+ temporal_state = payload.get("temporal_modeling")
+ if isinstance(temporal_state, dict) and temporal_state and hasattr(self, "section_temporal"):
+ try:
+ self.section_temporal.deserialize_state(temporal_state)
+ except Exception:
+ _LOG.debug("Could not restore temporal modeling state", exc_info=True)
+
def _save_project_h5(self, path: str) -> None:
with h5py.File(path, "w") as f:
f.attrs["project_type"] = "pyber_postprocessing_project"
@@ -6906,6 +7651,13 @@ def _save_project_h5(self, path: str) -> None:
analysis_group = f.create_group("analysis")
self._save_signal_events_h5(analysis_group)
self._save_behavior_analysis_h5(analysis_group)
+ try:
+ if hasattr(self, "section_temporal") and self.section_temporal is not None:
+ state = self.section_temporal.serialize_state()
+ if state:
+ self._write_h5_json_any(analysis_group, "temporal_modeling_json", state)
+ except Exception:
+ _LOG.debug("Could not save temporal modeling state to project", exc_info=True)
def _load_project_h5(self, path: str) -> Dict[str, object]:
payload: Dict[str, object] = {
@@ -7064,6 +7816,10 @@ def _aligned(values: Optional[np.ndarray], fill_nan: bool = True) -> np.ndarray:
if isinstance(analysis_group, h5py.Group):
payload["signal_events"] = self._load_signal_events_h5(analysis_group)
payload["behavior_analysis"] = self._load_behavior_analysis_h5(analysis_group)
+ try:
+ payload["temporal_modeling"] = self._read_h5_json_any(analysis_group, "temporal_modeling_json", {})
+ except Exception:
+ payload["temporal_modeling"] = {}
payload["processed"] = loaded_processed
payload["behavior_sources"] = loaded_behavior
@@ -7188,10 +7944,7 @@ def _confirm_discard_current_project(self) -> bool:
)
return ask == QtWidgets.QMessageBox.StandardButton.Yes
- def _new_project(self) -> None:
- if not self._confirm_discard_current_project():
- return
-
+ def _reset_project_state(self) -> None:
was_restoring = self._is_restoring_settings
self._is_restoring_settings = True
try:
@@ -7199,6 +7952,7 @@ def _new_project(self) -> None:
self._processed = []
self._behavior_sources = {}
self._pending_project_recompute_from_current = False
+ self._dio_cache.clear()
self.lbl_group.setText("(none)")
self.lbl_beh.setText("(none)")
self.lbl_behavior_msg.setText("")
@@ -7220,6 +7974,18 @@ def _new_project(self) -> None:
self._project_recovered_from_autosave = False
self._clear_project_autosave_cache(delete_file=True)
self._update_status_strip()
+
+ def reset_for_new_preprocessing_project(self) -> None:
+ self._reset_project_state()
+ self._reset_history_snapshot()
+ self.statusUpdate.emit("Cleared postprocessing project state.", 5000)
+
+ def _new_project(self) -> None:
+ if not self._confirm_discard_current_project():
+ return
+
+ self._reset_project_state()
+ self._reset_history_snapshot()
self.statusUpdate.emit("Started a new postprocessing project.", 5000)
def _import_project_source_paths(self, recent_paths: Dict[str, object]) -> bool:
@@ -7343,8 +8109,116 @@ def _load_project_from_path(self, path: str, from_autosave: bool = False) -> boo
self.statusUpdate.emit("Recovered autosaved postprocessing project.", 5000)
else:
self.statusUpdate.emit(f"Project loaded: {os.path.basename(path)}", 5000)
+ self._reset_history_snapshot()
return True
+ def _default_settings_payload(self) -> Dict[str, object]:
+ return {
+ "align": "Behavior (CSV/XLSX)",
+ "dio_channel": self.combo_dio.currentText(),
+ "dio_polarity": "Event high (0->1)",
+ "dio_align": "Align to onset",
+ "behavior_file_type": "Binary states (time + 0/1 columns)",
+ "behavior_time_fps": 30.0,
+ "behavior": self.combo_behavior_name.currentText(),
+ "behavior_align": self.combo_behavior_align.currentText(),
+ "behavior_from": self.combo_behavior_from.currentText(),
+ "behavior_to": self.combo_behavior_to.currentText(),
+ "transition_gap": 1.0,
+ "window_pre": 2.0,
+ "window_post": 5.0,
+ "baseline_start": -1.0,
+ "baseline_end": 0.0,
+ "resample": 50.0,
+ "smooth": 0.0,
+ "filter_enabled": True,
+ "event_start": 1,
+ "event_end": 0,
+ "group_window_s": 0.0,
+ "dur_min": 0.0,
+ "dur_max": 0.0,
+ "metrics_enabled": True,
+ "metric": "AUC",
+ "metric_pre0": -1.0,
+ "metric_pre1": 0.0,
+ "metric_post0": 0.0,
+ "metric_post1": 1.0,
+ "global_metrics_enabled": True,
+ "global_start": 0.0,
+ "global_end": 0.0,
+ "global_amp": True,
+ "global_freq": True,
+ "view_layout": "Standard",
+ "visual_mode": 0,
+ "individual_file": self.combo_individual_file.currentText().strip(),
+ "signal_source": "Use processed output trace (loaded file)",
+ "signal_scope": "Per file",
+ "signal_file": self.combo_signal_file.currentText(),
+ "signal_method": "SciPy find_peaks",
+ "signal_prominence": 0.5,
+ "signal_auto_mad": False,
+ "signal_mad_multiplier": 5.0,
+ "signal_height": 0.0,
+ "signal_distance": 0.5,
+ "signal_smooth": 0.0,
+ "signal_baseline": "Use trace as-is",
+ "signal_baseline_window": 10.0,
+ "signal_norm_prominence": False,
+ "signal_rate_bin": 60.0,
+ "signal_auc_window": 0.5,
+ "signal_overlay": True,
+ "signal_noise_overlay": False,
+ "behavior_analysis_name": self.combo_behavior_analysis.currentText(),
+ "behavior_analysis_bin": 30.0,
+ "behavior_analysis_aligned": False,
+ "spatial_x": self.combo_spatial_x.currentText(),
+ "spatial_y": self.combo_spatial_y.currentText(),
+ "spatial_bins_x": 64,
+ "spatial_bins_y": 64,
+ "spatial_weight": "Occupancy (samples)",
+ "spatial_clip": True,
+ "spatial_clip_low": 1.0,
+ "spatial_clip_high": 99.0,
+ "spatial_time_filter": False,
+ "spatial_time_min": 0.0,
+ "spatial_time_max": 0.0,
+ "spatial_smooth": 0.0,
+ "spatial_activity_mode": "Mean z-score/bin (occupancy normalized)",
+ "spatial_activity_norm": True,
+ "spatial_log": False,
+ "spatial_invert_y": False,
+ "style": {
+ "trace": (90, 190, 255),
+ "behavior": (220, 180, 80),
+ "avg": (90, 190, 255),
+ "sem_edge": (152, 201, 143),
+ "sem_fill": (188, 230, 178, 96),
+ "plot_bg": (248, 250, 255) if self._app_theme_mode == "light" else (36, 42, 52),
+ "grid_enabled": True,
+ "grid_alpha": 0.25,
+ "heatmap_cmap": "viridis",
+ "heatmap_min": None,
+ "heatmap_max": None,
+ "heatmap_levels_manual": False,
+ },
+ }
+
+ def _reset_config_defaults(self) -> None:
+ previous_restoring = self._is_restoring_settings
+ previous_history_restoring = self._history_restoring
+ self._is_restoring_settings = True
+ self._history_restoring = True
+ try:
+ self._apply_settings(self._default_settings_payload())
+ finally:
+ self._is_restoring_settings = previous_restoring
+ self._history_restoring = previous_history_restoring
+ self._record_history_change()
+ self._rerender_visual_from_cache()
+ self._compute_spatial_heatmap()
+ self._save_settings()
+ self.statusUpdate.emit("Reset postprocessing parameters to defaults.", 3000)
+
def _save_config_file(self) -> None:
start_dir = self._settings.value("postprocess_last_dir", os.getcwd(), type=str)
path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "Save config", os.path.join(start_dir, "postprocess_config.json"), "JSON (*.json)")
@@ -7412,19 +8286,25 @@ def _collect_settings(self) -> Dict[str, object]:
"global_amp": self.cb_global_amp.isChecked(),
"global_freq": self.cb_global_freq.isChecked(),
"view_layout": self.combo_view_layout.currentText(),
+ "visual_mode": int(self.tab_visual_mode.currentIndex()),
+ "individual_file": self.combo_individual_file.currentText().strip(),
"signal_source": self.combo_signal_source.currentText(),
"signal_scope": self.combo_signal_scope.currentText(),
"signal_file": self.combo_signal_file.currentText(),
"signal_method": self.combo_signal_method.currentText(),
"signal_prominence": float(self.spin_peak_prominence.value()),
+ "signal_auto_mad": self.cb_peak_auto_mad.isChecked(),
+ "signal_mad_multiplier": float(self.spin_peak_mad_multiplier.value()),
"signal_height": float(self.spin_peak_height.value()),
"signal_distance": float(self.spin_peak_distance.value()),
"signal_smooth": float(self.spin_peak_smooth.value()),
"signal_baseline": self.combo_peak_baseline.currentText(),
"signal_baseline_window": float(self.spin_peak_baseline_window.value()),
+ "signal_norm_prominence": self.cb_peak_norm_prominence.isChecked(),
"signal_rate_bin": float(self.spin_peak_rate_bin.value()),
"signal_auc_window": float(self.spin_peak_auc_window.value()),
"signal_overlay": self.cb_peak_overlay.isChecked(),
+ "signal_noise_overlay": self.cb_peak_noise_overlay.isChecked(),
"behavior_analysis_name": self.combo_behavior_analysis.currentText(),
"behavior_analysis_bin": float(self.spin_behavior_bin.value()),
"behavior_analysis_aligned": self.cb_behavior_aligned.isChecked(),
@@ -7513,6 +8393,19 @@ def _set_combo(combo: QtWidgets.QComboBox, val: object) -> None:
if "global_freq" in data:
self.cb_global_freq.setChecked(bool(data["global_freq"]))
_set_combo(self.combo_view_layout, data.get("view_layout"))
+ if "visual_mode" in data:
+ try:
+ idx = int(data.get("visual_mode") or 0)
+ if 0 <= idx < self.tab_visual_mode.count():
+ self.tab_visual_mode.setCurrentIndex(idx)
+ except Exception:
+ pass
+ if "individual_file" in data:
+ file_id = str(data.get("individual_file") or "").strip()
+ if file_id:
+ idx = self.combo_individual_file.findText(file_id)
+ if idx >= 0:
+ self.combo_individual_file.setCurrentIndex(idx)
_set_combo(self.combo_signal_source, data.get("signal_source"))
_set_combo(self.combo_signal_scope, data.get("signal_scope"))
self._refresh_signal_file_combo()
@@ -7520,6 +8413,11 @@ def _set_combo(combo: QtWidgets.QComboBox, val: object) -> None:
_set_combo(self.combo_signal_method, data.get("signal_method"))
if "signal_prominence" in data:
self.spin_peak_prominence.setValue(float(data["signal_prominence"]))
+ if "signal_auto_mad" in data:
+ self.cb_peak_auto_mad.setChecked(bool(data["signal_auto_mad"]))
+ if "signal_mad_multiplier" in data:
+ self.spin_peak_mad_multiplier.setValue(float(data["signal_mad_multiplier"]))
+ self._update_peak_auto_mad_enabled(queue=False)
if "signal_height" in data:
self.spin_peak_height.setValue(float(data["signal_height"]))
if "signal_distance" in data:
@@ -7529,12 +8427,16 @@ def _set_combo(combo: QtWidgets.QComboBox, val: object) -> None:
_set_combo(self.combo_peak_baseline, data.get("signal_baseline"))
if "signal_baseline_window" in data:
self.spin_peak_baseline_window.setValue(float(data["signal_baseline_window"]))
+ if "signal_norm_prominence" in data:
+ self.cb_peak_norm_prominence.setChecked(bool(data["signal_norm_prominence"]))
if "signal_rate_bin" in data:
self.spin_peak_rate_bin.setValue(float(data["signal_rate_bin"]))
if "signal_auc_window" in data:
self.spin_peak_auc_window.setValue(float(data["signal_auc_window"]))
if "signal_overlay" in data:
self.cb_peak_overlay.setChecked(bool(data["signal_overlay"]))
+ if "signal_noise_overlay" in data:
+ self.cb_peak_noise_overlay.setChecked(bool(data["signal_noise_overlay"]))
_set_combo(self.combo_behavior_analysis, data.get("behavior_analysis_name"))
if "behavior_analysis_bin" in data:
self.spin_behavior_bin.setValue(float(data["behavior_analysis_bin"]))
@@ -9292,8 +10194,10 @@ def _pick_color(self, key: str, with_alpha: bool = False) -> None:
def get_style(self) -> Dict[str, object]:
self._style["heatmap_cmap"] = self.combo_cmap.currentText()
- self._style["heatmap_min"] = float(self.spin_hmin.value()) if self.spin_hmin.value() != 0.0 else None
- self._style["heatmap_max"] = float(self.spin_hmax.value()) if self.spin_hmax.value() != 0.0 else None
+ manual_levels = self.spin_hmin.value() != 0.0 or self.spin_hmax.value() != 0.0
+ self._style["heatmap_min"] = float(self.spin_hmin.value()) if manual_levels else None
+ self._style["heatmap_max"] = float(self.spin_hmax.value()) if manual_levels else None
+ self._style["heatmap_levels_manual"] = bool(manual_levels)
self._style["grid_enabled"] = bool(self.cb_grid.isChecked())
self._style["grid_alpha"] = float(self.spin_grid_alpha.value())
self._style.setdefault("plot_bg", (36, 42, 52))
diff --git a/pyBer/gui_preprocessing.py b/pyBer/gui_preprocessing.py
index c2451b7..d845f42 100644
--- a/pyBer/gui_preprocessing.py
+++ b/pyBer/gui_preprocessing.py
@@ -17,6 +17,7 @@
BASELINE_METHODS,
REFERENCE_FIT_METHODS,
SMOOTHING_METHODS,
+ ARTIFACT_HANDLING_MODES,
)
@@ -74,6 +75,12 @@ def _compact_combo(combo: QtWidgets.QComboBox, min_chars: int = 6) -> None:
"Raw signal (465)": "output = filtered/resampled 465 signal",
}
+_DFF_OUTPUT_MODES = {
+ "dFF (non motion corrected)",
+ "dFF (motion corrected via subtraction)",
+ "dFF (motion corrected with fitted ref)",
+}
+
def _system_locale() -> QtCore.QLocale:
return QtCore.QLocale.system()
@@ -1339,6 +1346,13 @@ def _build_help_texts(self) -> Dict[str, str]:
"- Global MAD: single threshold for the full trace (fast, stable).\n"
"- Adaptive MAD: threshold computed in sliding windows (handles drift)."
),
+ "artifact_handling": (
+ "How detected/manual artifact windows affect processing:\n"
+ "- Interpolate: replace windows by linear interpolation before filtering.\n"
+ "- Cut: remove artifact samples from the final processed trace.\n"
+ "- Strong local low-pass: replace only artifact-window samples with a strongly smoothed trace.\n"
+ "- Do nothing: keep the samples unchanged; overlays still show detected artifacts."
+ ),
"mad_k": (
"MAD threshold (k) scales the derivative threshold.\n"
"Higher k = fewer artifacts flagged; lower k = more sensitive."
@@ -1458,6 +1472,10 @@ def _build_help_texts(self) -> Dict[str, str]:
"Choose one or more DIO channels to export.\n"
"If none are checked, export uses the current overlay trigger."
),
+ "auto_export": (
+ "When enabled, the Export button writes selected files directly beside their source raw data.\n"
+ "All available analog channels are exported with the current analysis parameters."
+ ),
}
def _show_help(self, key: str, title: str) -> None:
@@ -1511,6 +1529,9 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
self.combo_artifact = QtWidgets.QComboBox()
self.combo_artifact.addItems(["Global MAD (dx)", "Adaptive MAD (windowed)"])
_compact_combo(self.combo_artifact, min_chars=6)
+ self.combo_artifact_handling = QtWidgets.QComboBox()
+ self.combo_artifact_handling.addItems(ARTIFACT_HANDLING_MODES)
+ _compact_combo(self.combo_artifact_handling, min_chars=8)
self.spin_mad = mk_dspin()
self.spin_mad.setRange(1.0, 50.0)
self.spin_mad.setValue(8.0)
@@ -1527,6 +1548,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
art_form.addRow(self.cb_artifact)
art_form.addRow(self.cb_show_artifact_overlay)
art_form.addRow(self._label_with_help("Method", "artifact_mode"), self.combo_artifact)
+ art_form.addRow(self._label_with_help("Handling", "artifact_handling"), self.combo_artifact_handling)
art_form.addRow(self._label_with_help("MAD threshold (k)", "mad_k"), self.spin_mad)
art_form.addRow(self._label_with_help("Adaptive window (s)", "adaptive_window_s"), self.spin_adapt_win)
art_form.addRow(self._label_with_help("Artifact pad (s)", "artifact_pad_s"), self.spin_pad)
@@ -1725,6 +1747,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
self.btn_advanced = QtWidgets.QPushButton("Cutting / Sectioning")
self.btn_save_config = QtWidgets.QPushButton("Save config")
self.btn_load_config = QtWidgets.QPushButton("Load config")
+ self.btn_reset_defaults = QtWidgets.QPushButton("Reset defaults")
for b in (
self.btn_artifacts_panel,
self.btn_qc,
@@ -1734,12 +1757,14 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
self.btn_advanced,
self.btn_save_config,
self.btn_load_config,
+ self.btn_reset_defaults,
):
b.setProperty("class", "compactSmall")
b.setSizePolicy(QtWidgets.QSizePolicy.Policy.Expanding, QtWidgets.QSizePolicy.Policy.Fixed)
self.btn_export.setProperty("class", "compactPrimarySmall")
self.btn_save_config.clicked.connect(self._save_config)
self.btn_load_config.clicked.connect(self._load_config)
+ self.btn_reset_defaults.clicked.connect(self._reset_defaults)
self.btn_artifacts_panel.clicked.connect(self.artifactsRequested.emit)
self.btn_qc.clicked.connect(self.qcRequested.emit)
self.btn_qc_batch.clicked.connect(self.batchQcRequested.emit)
@@ -1747,6 +1772,11 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
self.btn_metadata.clicked.connect(self.metadataRequested.emit)
self.btn_advanced.clicked.connect(self.advancedOptionsRequested.emit)
+ self.chk_auto_export = QtWidgets.QCheckBox("Use source folder + all channels")
+ self.chk_auto_export.setChecked(False)
+ self.chk_auto_export.setToolTip(
+ "Skip the folder picker and export all analog channels beside each source file."
+ )
self.chk_export_raw = QtWidgets.QCheckBox("Raw 465")
self.chk_export_raw.setChecked(True)
self.chk_export_iso = QtWidgets.QCheckBox("Isobestic 405")
@@ -1765,9 +1795,20 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
self.list_export_dio = CheckableListWidget()
self.list_export_dio.setMaximumHeight(84)
self.list_export_dio.setMinimumHeight(54)
+ self.list_export_outputs = CheckableListWidget()
+ self.list_export_outputs.setMaximumHeight(116)
+ self.list_export_outputs.setMinimumHeight(72)
+ self.list_export_outputs.setToolTip("Select one or more processed output traces to write during export.")
self._pending_export_channel_names: List[str] = []
self._pending_export_trigger_names: List[str] = []
+ self._pending_export_output_modes: List[str] = [self.combo_output.currentText()]
+ self._export_outputs_follow_current_mode = True
+ self.list_export_outputs.set_items(
+ [(mode, mode) for mode in OUTPUT_MODES],
+ checked_values=self._pending_export_output_modes,
+ )
self.list_export_dio.setEnabled(self.chk_export_dio.isChecked())
+ self.list_export_outputs.setEnabled(self.chk_export_output.isChecked())
self.export_options_group = QtWidgets.QGroupBox("Export fields")
export_form = QtWidgets.QFormLayout(self.export_options_group)
@@ -1784,6 +1825,8 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
export_checks.addWidget(self.chk_export_baseline_sig, 2, 0)
export_checks.addWidget(self.chk_export_baseline_ref, 2, 1)
export_form.addRow(export_checks)
+ export_form.addRow(self._label_with_help("Auto export", "auto_export"), self.chk_auto_export)
+ export_form.addRow(self._label_with_help("Output traces", "output_mode"), self.list_export_outputs)
export_form.addRow(self._label_with_help("AN channels", "export_analog_channels"), self.list_export_channels)
export_form.addRow(self._label_with_help("DIO channels", "export_dio_channels"), self.list_export_dio)
@@ -1799,6 +1842,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
qc_grid.addWidget(self.btn_qc_batch, 2, 1)
qc_grid.addWidget(self.btn_metadata, 3, 0)
qc_grid.addWidget(self.btn_save_config, 3, 1)
+ qc_grid.addWidget(self.btn_reset_defaults, 4, 0)
qc_grid.addWidget(self.btn_load_config, 4, 1)
qc_grid.setColumnStretch(0, 1)
qc_grid.setColumnStretch(1, 1)
@@ -1837,6 +1881,7 @@ def mk_spin(minw=60) -> QtWidgets.QSpinBox:
self._update_output_definition()
self._update_output_controls()
self._update_smoothing_controls(emit_signal=False)
+ self._update_auto_export_controls()
self._update_section_summaries()
def _update_artifact_enabled(self) -> None:
@@ -1922,6 +1967,7 @@ def _fmt_num(self, value: float, decimals: int = 3) -> str:
return text if text else "0"
def _update_section_summaries(self) -> None:
+ handling = self.combo_artifact_handling.currentText()
if self.cb_artifact.isChecked():
mode = self.combo_artifact.currentText()
method = "Adaptive MAD" if mode.startswith("Adaptive") else "Global MAD"
@@ -1929,15 +1975,15 @@ def _update_section_summaries(self) -> None:
summary = (
f"{method}, k={self._fmt_num(self.spin_mad.value(), 2)}, "
f"window={self._fmt_num(self.spin_adapt_win.value(), 2)}s, "
- f"pad={self._fmt_num(self.spin_pad.value(), 2)}s"
+ f"pad={self._fmt_num(self.spin_pad.value(), 2)}s, {handling}"
)
else:
summary = (
f"{method}, k={self._fmt_num(self.spin_mad.value(), 2)}, "
- f"pad={self._fmt_num(self.spin_pad.value(), 2)}s"
+ f"pad={self._fmt_num(self.spin_pad.value(), 2)}s, {handling}"
)
else:
- summary = "Off"
+ summary = f"Detection off, {handling}"
self.card_artifacts.set_summary(summary)
if self.cb_filtering.isChecked():
@@ -1979,6 +2025,7 @@ def emit_noargs(*_args) -> None:
widgets = (
self.combo_artifact,
+ self.combo_artifact_handling,
self.spin_mad,
self.spin_adapt_win,
self.spin_pad,
@@ -2016,6 +2063,7 @@ def emit_noargs(*_args) -> None:
self.spin_lam_y.valueChanged.connect(lambda *_: self._update_lambda_preview())
self.combo_output.currentIndexChanged.connect(lambda *_: self._update_output_definition())
self.combo_output.currentIndexChanged.connect(lambda *_: self._update_output_controls())
+ self.combo_output.currentIndexChanged.connect(lambda *_: self._sync_export_outputs_to_current_mode())
self.combo_ref_fit.currentIndexChanged.connect(lambda *_: self._update_output_controls())
# Keep collapsed card summaries synchronized with current values.
@@ -2031,11 +2079,16 @@ def emit_noargs(*_args) -> None:
self.combo_smoothing.currentIndexChanged.connect(lambda *_: self._update_smoothing_controls())
self.cb_invert.stateChanged.connect(emit_noargs)
self.cb_show_artifact_overlay.toggled.connect(lambda v: self.artifactOverlayToggled.emit(bool(v)))
+ self.chk_auto_export.toggled.connect(self._update_auto_export_controls)
+ self.chk_auto_export.toggled.connect(emit_noargs)
self.chk_export_dio.toggled.connect(self.list_export_dio.setEnabled)
+ self.chk_export_output.toggled.connect(self.list_export_outputs.setEnabled)
self.list_export_channels.changed.connect(self._on_export_channel_selection_changed)
self.list_export_channels.changed.connect(emit_noargs)
self.list_export_dio.changed.connect(self._on_export_trigger_selection_changed)
self.list_export_dio.changed.connect(emit_noargs)
+ self.list_export_outputs.changed.connect(self._on_export_output_selection_changed)
+ self.list_export_outputs.changed.connect(emit_noargs)
for cb in (
self.chk_export_raw,
self.chk_export_iso,
@@ -2054,6 +2107,9 @@ def _toggle_advanced_baseline(self) -> None:
"Hide advanced baseline options" if not is_visible else "Show advanced baseline options"
)
+ def _update_auto_export_controls(self, *_args) -> None:
+ self.list_export_channels.setEnabled(not self.auto_export_enabled())
+
def set_config_state_hooks(
self,
exporter: Optional[Callable[[], Dict[str, object]]],
@@ -2070,16 +2126,23 @@ def export_selection(self) -> ExportSelection:
dio=self.chk_export_dio.isChecked(),
baseline_sig=self.chk_export_baseline_sig.isChecked(),
baseline_ref=self.chk_export_baseline_ref.isChecked(),
+ output_modes=self.export_output_modes(),
)
def export_selection_summary(self) -> str:
selection = self.export_selection()
parts: List[str] = []
- chans = self.export_channel_names()
- if chans:
- parts.append(f"ANx{len(chans)}")
+ if self.auto_export_enabled():
+ n_channels = self.list_export_channels.count()
+ parts.append(f"ANx{n_channels}" if n_channels else "all AN")
+ parts.append("source folder")
+ else:
+ chans = self.export_channel_names()
+ if chans:
+ parts.append(f"ANx{len(chans)}")
if selection.output:
- parts.append("output")
+ modes = self.export_output_modes()
+ parts.append(f"outputsx{len(modes)}" if len(modes) > 1 else "output")
if selection.dio:
dio_names = self.export_trigger_names()
parts.append(f"DIOx{len(dio_names)}" if dio_names else "DIO")
@@ -2106,6 +2169,10 @@ def set_export_selection(self, selection: ExportSelection) -> None:
self.chk_export_dio.setChecked(bool(selection.dio))
self.chk_export_baseline_sig.setChecked(bool(selection.baseline_sig))
self.chk_export_baseline_ref.setChecked(bool(selection.baseline_ref))
+ if selection.output_modes:
+ self.set_export_output_modes(selection.output_modes, follow_current=False)
+ else:
+ self.set_export_output_modes([self.combo_output.currentText()], follow_current=True)
def export_channel_names(self) -> List[str]:
checked = self.list_export_channels.checked_values()
@@ -2115,12 +2182,28 @@ def export_trigger_names(self) -> List[str]:
checked = self.list_export_dio.checked_values()
return checked or list(self._pending_export_trigger_names)
+ def export_output_modes(self) -> List[str]:
+ checked = self.list_export_outputs.checked_values()
+ modes = checked or list(self._pending_export_output_modes)
+ modes = [mode for mode in modes if mode in OUTPUT_MODES]
+ return modes or [self.combo_output.currentText()]
+
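The fallback chain in `export_output_modes()` (checked values, then pending values, filtered against the allowed list, then the current combo text) can be sketched as a pure function. The `OUTPUT_MODES` entries below are illustrative stand-ins; the real list lives in `analysis_core`.

```python
from typing import List

# Illustrative allowed list; the real OUTPUT_MODES is imported from analysis_core.
OUTPUT_MODES = ["dFF (non motion corrected)", "zscore (non motion corrected)"]

def resolve_output_modes(checked: List[str], pending: List[str], current: str) -> List[str]:
    """Mirror export_output_modes(): prefer checked values, fall back to the
    pending selection, drop anything not in OUTPUT_MODES, and finally fall
    back to the currently displayed mode."""
    modes = checked or list(pending)
    modes = [m for m in modes if m in OUTPUT_MODES]
    return modes or [current]
```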
def _on_export_channel_selection_changed(self) -> None:
self._pending_export_channel_names = self.list_export_channels.checked_values()
def _on_export_trigger_selection_changed(self) -> None:
self._pending_export_trigger_names = self.list_export_dio.checked_values()
+ def _on_export_output_selection_changed(self) -> None:
+ self._pending_export_output_modes = self.list_export_outputs.checked_values()
+ self._export_outputs_follow_current_mode = False
+ self._update_section_summaries()
+
+ def _sync_export_outputs_to_current_mode(self) -> None:
+ if not getattr(self, "_export_outputs_follow_current_mode", True):
+ return
+ self.set_export_output_modes([self.combo_output.currentText()], follow_current=True)
+
def set_available_export_channels(self, channels: List[str], preferred: Optional[List[str]] = None) -> None:
current = list(preferred if preferred is not None else self.export_channel_names())
items = [(name, name) for name in channels or [] if str(name or "").strip()]
@@ -2143,6 +2226,20 @@ def set_export_trigger_names(self, trigger_names: List[str]) -> None:
self._pending_export_trigger_names = names
self.list_export_dio.set_checked_values(names)
+ def set_export_output_modes(self, output_modes: List[str], follow_current: bool = False) -> None:
+ modes = [str(mode or "").strip() for mode in output_modes or [] if str(mode or "").strip() in OUTPUT_MODES]
+ if not modes:
+ modes = [self.combo_output.currentText()]
+ self._pending_export_output_modes = modes
+ self._export_outputs_follow_current_mode = bool(follow_current)
+ self.list_export_outputs.set_checked_values(modes)
+
+ def auto_export_enabled(self) -> bool:
+ return bool(self.chk_auto_export.isChecked())
+
+ def set_auto_export_enabled(self, enabled: bool) -> None:
+ self.chk_auto_export.setChecked(bool(enabled))
+
def _save_config(self) -> None:
"""Save current preprocessing parameters to a JSON file."""
params = self.get_params()
@@ -2219,6 +2316,19 @@ def _load_config(self) -> None:
except Exception as e:
QtWidgets.QMessageBox.warning(self, "Error", f"Failed to load config: {e}")
+ def _reset_defaults(self) -> None:
+ """Restore processing parameters and export toggles to application defaults."""
+ self.set_params(ProcessingParams())
+ self.cb_artifact.setChecked(True)
+ self.cb_filtering.setChecked(True)
+ self.cb_show_artifact_overlay.setChecked(True)
+ self.set_export_selection(ExportSelection())
+ self.set_export_output_modes([self.combo_output.currentText()], follow_current=True)
+ self.chk_auto_export.setChecked(False)
+ self._update_auto_export_controls()
+ self._update_section_summaries()
+ self.paramsChanged.emit()
+
def _lambda_value(self) -> float:
x = float(self.spin_lam_x.value())
y = int(self.spin_lam_y.value())
@@ -2228,6 +2338,7 @@ def get_params(self) -> ProcessingParams:
return ProcessingParams(
artifact_detection_enabled=self.cb_artifact.isChecked(),
artifact_mode=self.combo_artifact.currentText(),
+ artifact_handling=self.combo_artifact_handling.currentText(),
mad_k=float(self.spin_mad.value()),
adaptive_window_s=float(self.spin_adapt_win.value()),
artifact_pad_s=float(self.spin_pad.value()),
@@ -2263,6 +2374,7 @@ def set_params(self, params: ProcessingParams) -> None:
return
self.cb_artifact.setChecked(bool(getattr(params, "artifact_detection_enabled", True)))
self.combo_artifact.setCurrentText(str(params.artifact_mode))
+ self.combo_artifact_handling.setCurrentText(str(getattr(params, "artifact_handling", "Interpolate")))
self.spin_mad.setValue(float(params.mad_k))
self.spin_adapt_win.setValue(float(params.adaptive_window_s))
self.spin_pad.setValue(float(params.artifact_pad_s))
@@ -2406,6 +2518,8 @@ class PlotDashboard(QtWidgets.QWidget):
manualRegionFromSelectorRequested = QtCore.Signal()
manualRegionFromDragRequested = QtCore.Signal(float, float)
clearManualRegionsRequested = QtCore.Signal()
+ undoRequested = QtCore.Signal()
+ redoRequested = QtCore.Signal()
showArtifactsRequested = QtCore.Signal()
boxSelectionCleared = QtCore.Signal()
boxSelectionContextRequested = QtCore.Signal()
@@ -2416,8 +2530,10 @@ class PlotDashboard(QtWidgets.QWidget):
def __init__(self, parent=None) -> None:
super().__init__(parent)
self._sync_guard = False
+ self._last_xrange: Optional[Tuple[float, float]] = None
self._artifact_overlay_visible = True
self._artifact_thresholds_visible = True
+ self._dio_overlay_visible = False
self._plot_background_mode = "dark"
self._plot_grid_visible = True
self._artifact_regions: List[pg.LinearRegionItem] = []
@@ -2450,6 +2566,8 @@ def _build_ui(self) -> None:
tools = QtWidgets.QHBoxLayout()
self.btn_add_region = QtWidgets.QPushButton("Add from selector")
self.btn_clear_regions = QtWidgets.QPushButton("Clear manual")
+ self.btn_undo = QtWidgets.QPushButton("Undo")
+ self.btn_redo = QtWidgets.QPushButton("Redo")
self.btn_artifacts = QtWidgets.QPushButton("Artifacts")
self.btn_box_select = QtWidgets.QPushButton("Box select")
self.btn_box_select.setCheckable(True)
@@ -2459,6 +2577,8 @@ def _build_ui(self) -> None:
for b in (
self.btn_add_region,
self.btn_clear_regions,
+ self.btn_undo,
+ self.btn_redo,
self.btn_artifacts,
self.btn_box_select,
self.btn_thresholds,
@@ -2466,6 +2586,8 @@ def _build_ui(self) -> None:
b.setProperty("class", "compactSmall")
tools.addWidget(self.btn_add_region)
tools.addWidget(self.btn_clear_regions)
+ tools.addWidget(self.btn_undo)
+ tools.addWidget(self.btn_redo)
tools.addWidget(self.btn_artifacts)
tools.addWidget(self.btn_box_select)
tools.addWidget(self.btn_thresholds)
@@ -2479,17 +2601,9 @@ def _build_ui(self) -> None:
for w in (self.plot_raw, self.plot_proc, self.plot_out):
_optimize_plot(w)
- # Primary axis for 465 signal
+ # Raw traces share the same primary y-axis.
self.curve_465 = self.plot_raw.plot(pen=pg.mkPen((80, 250, 160), width=1.3))
-
- # Twin axis for 405 (isobestic) signal - share Y axis with 465
- self.plot_raw_pi = self.plot_raw.getPlotItem()
- self.plot_raw_pi.showAxis("right")
- self.plot_raw_pi.getAxis("right").setLabel("405 (isobestic)", color=(160, 120, 255))
-
- # For true twin axis, we plot both curves on the same plot area
- # The 405 curve will use the same Y axis scaling as 465
- self.curve_405 = self.plot_raw.plot(pen=pg.mkPen((160, 120, 255, 128), width=1.2)) # Alpha for isobestic
+ self.curve_405 = self.plot_raw.plot(pen=pg.mkPen((160, 120, 255, 128), width=1.2))
pen_env = pg.mkPen((240, 200, 90), width=1.0, style=QtCore.Qt.PenStyle.DashLine)
self.curve_thr_hi = self.plot_raw.plot(pen=pen_env)
@@ -2524,6 +2638,7 @@ def _build_ui(self) -> None:
self.vb_dio_raw, self.curve_dio_raw = self._add_dio_axis(self.plot_raw, "A/D")
self.vb_dio_proc, self.curve_dio_proc = self._add_dio_axis(self.plot_proc, "A/D")
self.vb_dio_out, self.curve_dio_out = self._add_dio_axis(self.plot_out, "A/D")
+ self._set_dio_overlay_visible(False)
self._align_plot_axis_layouts()
self.lbl_log = QtWidgets.QLabel("")
@@ -2540,6 +2655,8 @@ def _build_ui(self) -> None:
self.btn_add_region.clicked.connect(self.manualRegionFromSelectorRequested.emit)
self.btn_clear_regions.clicked.connect(self.clearManualRegionsRequested.emit)
+ self.btn_undo.clicked.connect(self.undoRequested.emit)
+ self.btn_redo.clicked.connect(self.redoRequested.emit)
self.btn_artifacts.clicked.connect(self.showArtifactsRequested.emit)
self.btn_box_select.toggled.connect(self._toggle_box_select)
self.btn_thresholds.toggled.connect(self._on_thresholds_toggled)
@@ -2555,8 +2672,13 @@ def _build_ui(self) -> None:
self._sync_artifact_threshold_curves_visibility()
self._toggle_box_select(False)
+ self.set_history_available(False, False)
self.set_plot_appearance(self._plot_background_mode, self._plot_grid_visible)
+ def set_history_available(self, can_undo: bool, can_redo: bool) -> None:
+ self.btn_undo.setEnabled(bool(can_undo))
+ self.btn_redo.setEnabled(bool(can_redo))
+
def _normalize_plot_background_mode(self, value: object) -> str:
mode = str(value or "").strip().lower()
if mode in {"white", "light", "w"}:
@@ -2605,7 +2727,7 @@ def _align_plot_axis_layouts(self) -> None:
# Fixed side gutters keep the three view boxes the same width, so shared x ranges
# produce vertically aligned ticks and grid lines across the stacked plots.
left_width = 64
- right_width = 56
+ right_width = 56 if self._dio_overlay_visible else 0
bottom_height = 24
for plot in (self.plot_raw, self.plot_proc, self.plot_out):
pi = plot.getPlotItem()
@@ -2624,6 +2746,34 @@ def _align_plot_axis_layouts(self) -> None:
except Exception:
pass
+ def _set_dio_overlay_visible(self, visible: bool, label: str = "A/D") -> None:
+ self._dio_overlay_visible = bool(visible)
+ for plot, vb, curve in (
+ (self.plot_raw, self.vb_dio_raw, self.curve_dio_raw),
+ (self.plot_proc, self.vb_dio_proc, self.curve_dio_proc),
+ (self.plot_out, self.vb_dio_out, self.curve_dio_out),
+ ):
+ pi = plot.getPlotItem()
+ axis = pi.getAxis("right")
+ if axis is not None:
+ try:
+ axis.setLabel(label if self._dio_overlay_visible else "")
+ except Exception:
+ pass
+ try:
+ pi.showAxis("right", self._dio_overlay_visible)
+ except Exception:
+ pass
+ try:
+ vb.setVisible(self._dio_overlay_visible)
+ except Exception:
+ pass
+ try:
+ curve.setVisible(self._dio_overlay_visible)
+ except Exception:
+ pass
+ self._align_plot_axis_layouts()
+
def _add_dio_axis(self, plot: pg.PlotWidget, label: str):
pi = plot.getPlotItem()
vb = pg.ViewBox()
@@ -2651,6 +2801,7 @@ def _emit_xrange_from_any(self, _vb, x_range) -> None:
return
try:
x0, x1 = x_range
+ self._last_xrange = (float(x0), float(x1))
self.xRangeChanged.emit(float(x0), float(x1))
except Exception:
pass
@@ -2684,6 +2835,9 @@ def _on_thresholds_toggled(self, checked: bool) -> None:
self.artifactThresholdsToggled.emit(bool(checked))
def set_xrange_all(self, x0: float, x1: float) -> None:
+ if not np.isfinite(x0) or not np.isfinite(x1) or float(x1) <= float(x0):
+ return
+ self._last_xrange = (float(x0), float(x1))
self._sync_guard = True
try:
self.plot_raw.setXRange(x0, x1, padding=0)
@@ -2692,6 +2846,20 @@ def set_xrange_all(self, x0: float, x1: float) -> None:
finally:
self._sync_guard = False
+ def current_xrange(self) -> Optional[Tuple[float, float]]:
+ try:
+ (x0, x1), _ = self.plot_raw.getViewBox().viewRange()
+ if np.isfinite(x0) and np.isfinite(x1) and x1 > x0:
+ return (float(x0), float(x1))
+ except Exception:
+ pass
+ if self._last_xrange is None:
+ return None
+ x0, x1 = self._last_xrange
+ if np.isfinite(x0) and np.isfinite(x1) and x1 > x0:
+ return (float(x0), float(x1))
+ return None
+
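Both `set_xrange_all()` and `current_xrange()` apply the same validity guard: reject ranges with non-finite endpoints or a non-positive width. A minimal Qt-free sketch of that guard:

```python
import math
from typing import Optional, Tuple

def valid_xrange(x0: float, x1: float) -> Optional[Tuple[float, float]]:
    """Return (x0, x1) only when both endpoints are finite and strictly
    ordered, mirroring the checks in set_xrange_all()/current_xrange()."""
    if math.isfinite(x0) and math.isfinite(x1) and x1 > x0:
        return (float(x0), float(x1))
    return None
```

`current_xrange()` applies this first to the live view-box range, then to the cached `_last_xrange`, returning `None` only when both fail.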
def set_full_xrange(self, t: np.ndarray) -> None:
if t is None or np.asarray(t).size < 2:
return
@@ -2915,6 +3083,7 @@ def _set_dio(self, t: np.ndarray, dio: Optional[np.ndarray], name: str = "") ->
self.curve_dio_raw.setData([], [])
self.curve_dio_proc.setData([], [])
self.curve_dio_out.setData([], [])
+ self._set_dio_overlay_visible(False)
return
tt = np.asarray(t, float)
@@ -2925,16 +3094,7 @@ def _set_dio(self, t: np.ndarray, dio: Optional[np.ndarray], name: str = "") ->
self.curve_dio_raw.setData(tt, yy, connect="finite", skipFiniteCheck=True)
self.curve_dio_proc.setData(tt, yy, connect="finite", skipFiniteCheck=True)
self.curve_dio_out.setData(tt, yy, connect="finite", skipFiniteCheck=True)
-
- if name:
- self.plot_raw.getPlotItem().getAxis("right").setLabel(f"A/D ({name})")
- self.plot_proc.getPlotItem().getAxis("right").setLabel(f"A/D ({name})")
- self.plot_out.getPlotItem().getAxis("right").setLabel(f"A/D ({name})")
- else:
- self.plot_raw.getPlotItem().getAxis("right").setLabel("A/D")
- self.plot_proc.getPlotItem().getAxis("right").setLabel("A/D")
- self.plot_out.getPlotItem().getAxis("right").setLabel("A/D")
- self._align_plot_axis_layouts()
+ self._set_dio_overlay_visible(True, label=(f"A/D ({name})" if name else "A/D"))
# -------------------- Compatibility API expected by main.py --------------------
@@ -3123,11 +3283,12 @@ def show_output(self, *args, **kwargs) -> None:
n = min(t.size, y.size)
t, y = t[:n], y[:n]
+ label = _first_not_none(kwargs, "label", "output_label", default="Output")
+
self.curve_out.setData(t, y, connect="finite", skipFiniteCheck=True)
if not preserve_view:
self.set_full_xrange(t)
- label = _first_not_none(kwargs, "label", "output_label", default="Output")
context = _first_not_none(kwargs, "output_context", "label_context", default="")
title = f"Output: {label}" if not context else f"Output: {label} | {context}"
self.plot_out.setTitle(title)
@@ -3140,6 +3301,7 @@ def show_output(self, *args, **kwargs) -> None:
def update_plots(self, processed: ProcessedTrial, preserve_view: bool = False) -> None:
t = np.asarray(processed.time, float)
+ kept_xrange = self.current_xrange() if preserve_view else None
self.show_raw(
t, processed.raw_signal, processed.raw_reference,
dio=processed.dio, dio_name=processed.dio_name,
@@ -3161,4 +3323,6 @@ def update_plots(self, processed: ProcessedTrial, preserve_view: bool = False) -
preserve_view=preserve_view,
)
self._update_artifact_overlays(t, processed.raw_signal, processed.artifact_regions_sec)
+ if kept_xrange is not None:
+ self.set_xrange_all(*kept_xrange)
diff --git a/pyBer/main.py b/pyBer/main.py
index 47e0afe..fe65d20 100644
--- a/pyBer/main.py
+++ b/pyBer/main.py
@@ -15,7 +15,70 @@
import json
import logging
import sys
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+
+_DLL_DIR_HANDLES = []
+
+
+def _bootstrap_windows_conda_runtime() -> None:
+ if os.name != "nt":
+ return
+
+ os.environ.setdefault("PYTHONNOUSERSITE", "1")
+
+ try:
+ import site
+ user_site = os.path.normcase(os.path.abspath(site.getusersitepackages()))
+ except Exception:
+ user_site = ""
+
+ appdata_python = ""
+ appdata = os.environ.get("APPDATA", "")
+ if appdata:
+ appdata_python = os.path.normcase(os.path.abspath(os.path.join(appdata, "Python")))
+
+ def _is_user_site_path(path: str) -> bool:
+ if not path:
+ return False
+ try:
+ norm = os.path.normcase(os.path.abspath(path))
+ except Exception:
+ return False
+ if user_site and (norm == user_site or norm.startswith(user_site + os.sep)):
+ return True
+ return bool(appdata_python and (norm == appdata_python or norm.startswith(appdata_python + os.sep)))
+
+ sys.path[:] = [path for path in sys.path if not _is_user_site_path(path)]
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ if script_dir and script_dir not in sys.path:
+ sys.path.insert(0, script_dir)
+
+ prefix = os.environ.get("CONDA_PREFIX") or sys.prefix
+ dll_dirs = [
+ prefix,
+ os.path.join(prefix, "Library", "mingw-w64", "bin"),
+ os.path.join(prefix, "Library", "usr", "bin"),
+ os.path.join(prefix, "Library", "bin"),
+ os.path.join(prefix, "Scripts"),
+ ]
+ existing = [path for path in dll_dirs if path and os.path.isdir(path)]
+
+ if hasattr(os, "add_dll_directory"):
+ for path in existing:
+ try:
+ _DLL_DIR_HANDLES.append(os.add_dll_directory(path))
+ except Exception:
+ pass
+
+ old_path = os.environ.get("PATH", "")
+ old_parts = [os.path.normcase(os.path.abspath(p)) for p in old_path.split(os.pathsep) if p]
+ prepend = [p for p in existing if os.path.normcase(os.path.abspath(p)) not in old_parts]
+ if prepend:
+ os.environ["PATH"] = os.pathsep.join(prepend + ([old_path] if old_path else []))
+
+
+_bootstrap_windows_conda_runtime()
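The bootstrap strips user-site entries from `sys.path` via a normalized prefix test (`normcase` + `abspath`, so the comparison is case-insensitive on Windows). A standalone sketch of that predicate:

```python
import os

def is_under(path: str, root: str) -> bool:
    """Prefix test in the spirit of _is_user_site_path(): True when `path`
    equals `root` or lies beneath it, after normcase/abspath normalization
    (case-insensitive and separator-normalized on Windows)."""
    if not path or not root:
        return False
    norm = os.path.normcase(os.path.abspath(path))
    norm_root = os.path.normcase(os.path.abspath(root))
    return norm == norm_root or norm.startswith(norm_root + os.sep)
```

Note the `+ os.sep` in the prefix: without it, `/a/bc` would falsely match root `/a/b`.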
from PySide6 import QtCore, QtGui, QtWidgets
import pyqtgraph as pg
@@ -24,6 +87,7 @@
from analysis_core import (
ExportSelection,
+ OUTPUT_MODES,
PhotometryProcessor,
ProcessingParams,
LoadedDoricFile,
@@ -37,6 +101,7 @@
zscore_median_std,
safe_divide,
_lowpass_sos,
+ coerce_time_value,
)
from gui_preprocessing import (
FileQueuePanel,
@@ -47,7 +112,20 @@
AdvancedOptionsDialog,
)
from gui_postprocessing import PostProcessingPanel
+from numeric_controls import install_spinbox_scrubbers
+from onboarding import (
+ ToastManager,
+ TutorialOverlay,
+ PreferencesDialog,
+ register_global_shortcuts,
+ attach_dirty_title,
+ install_close_confirmation,
+ reset_focused_plot_view,
+ build_default_tutorial,
+ add_empty_state_hint,
+)
from styles import (
+ apply_app_palette,
app_qss,
_make_icon,
_paint_database,
@@ -487,6 +565,12 @@ def __init__(self) -> None:
self._pending_main_tab_index: Optional[int] = None
self._force_fixed_dock_layouts: bool = bool(_FORCE_FIXED_DOCK_LAYOUTS)
self._app_theme_mode: str = "dark"
+ self._pre_history_undo: List[Dict[str, Any]] = []
+ self._pre_history_redo: List[Dict[str, Any]] = []
+ self._pre_history_current: Optional[Dict[str, Any]] = None
+ self._pre_history_key: str = ""
+ self._pre_history_restoring: bool = False
+ self._pre_history_limit: int = 60
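The `_pre_history_*` fields implement a bounded snapshot history: pushing a new state clears the redo stack, and undo/redo shuttle snapshots between the two stacks around a "current" snapshot, with at most `_pre_history_limit` undo entries retained. A minimal sketch of that discipline (the class name and API here are illustrative, not the app's actual methods):

```python
from typing import Any, List, Optional

class SnapshotHistory:
    """Bounded undo/redo stacks in the spirit of the _pre_history_* fields."""

    def __init__(self, limit: int = 60) -> None:
        self._undo: List[Any] = []
        self._redo: List[Any] = []
        self._current: Optional[Any] = None
        self._limit = limit

    def push(self, state: Any) -> None:
        # A new action invalidates any redo trail.
        if self._current is not None:
            self._undo.append(self._current)
            del self._undo[:-self._limit]  # keep at most `limit` entries
        self._current = state
        self._redo.clear()

    def undo(self) -> Optional[Any]:
        if not self._undo:
            return None
        self._redo.append(self._current)
        self._current = self._undo.pop()
        return self._current

    def redo(self) -> Optional[Any]:
        if not self._redo:
            return None
        self._undo.append(self._current)
        self._current = self._redo.pop()
        return self._current
```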
# Worker infra (stable)
self._pool = QtCore.QThreadPool.globalInstance()
@@ -510,6 +594,7 @@ def __init__(self) -> None:
self._build_ui()
self._restore_settings()
self._panel_layout_persistence_ready = True
+ self._reset_pre_history_snapshot()
# Enforce: preprocessing drawer is hidden until the user
# explicitly clicks a rail section button (overrides any saved state).
self._force_hide_pre_drawer_initially()
@@ -556,6 +641,35 @@ def _build_ui(self) -> None:
self._status_bar.addPermanentWidget(QtWidgets.QLabel("App theme"))
self._status_bar.addPermanentWidget(self.btn_app_theme)
+ # Busy / cancel indicator (left of theme widgets). Hidden until something runs.
+ self._busy_widget = QtWidgets.QFrame()
+ self._busy_widget.setObjectName("pyberBusyWidget")
+ self._busy_widget.setStyleSheet(
+ "QFrame#pyberBusyWidget { background: #2a3045; border: 1px solid #46527a;"
+ " border-radius: 6px; padding: 0 8px; }"
+ "QFrame#pyberBusyWidget QLabel { background: transparent; color: #d7e0ee; }"
+ "QFrame#pyberBusyWidget QPushButton { background: #543035; color: #ffd6dc;"
+ " border: 1px solid #8a3949; border-radius: 4px; padding: 1px 8px; }"
+ "QFrame#pyberBusyWidget QPushButton:hover { background: #6b3a40; }"
+ )
+ bl = QtWidgets.QHBoxLayout(self._busy_widget)
+ bl.setContentsMargins(6, 1, 6, 1)
+ bl.setSpacing(8)
+ self._busy_label = QtWidgets.QLabel("Busy...")
+ bl.addWidget(self._busy_label)
+ self._busy_cancel = QtWidgets.QPushButton("Cancel")
+ self._busy_cancel.setToolTip("Cancel the running batch operation (Esc).")
+ self._busy_cancel.clicked.connect(self._cancel_current_operation)
+ bl.addWidget(self._busy_cancel)
+ self._busy_widget.setVisible(False)
+ self._status_bar.addPermanentWidget(self._busy_widget)
+
+ # Poll the temporal panel's progress bar to surface batch state in status bar.
+ self._busy_poll = QtCore.QTimer(self)
+ self._busy_poll.setInterval(250)
+ self._busy_poll.timeout.connect(self._update_busy_indicator)
+ self._busy_poll.start()
+
# Preprocessing tab
self.pre_tab = QtWidgets.QWidget()
self.tabs.addTab(self.pre_tab, "Preprocessing")
@@ -866,6 +980,8 @@ def _build_ui(self) -> None:
self.plots.manualRegionFromSelectorRequested.connect(self._add_manual_region_from_selector)
self.plots.manualRegionFromDragRequested.connect(self._add_manual_region_from_drag)
self.plots.clearManualRegionsRequested.connect(self._clear_manual_regions_current)
+ self.plots.undoRequested.connect(self._undo_pre_action)
+ self.plots.redoRequested.connect(self._redo_pre_action)
self.plots.showArtifactsRequested.connect(self._toggle_artifacts_panel)
self.plots.boxSelectionCleared.connect(self._cancel_box_select_request)
self.plots.boxSelectionContextRequested.connect(self._show_box_selection_context_menu)
@@ -886,6 +1002,41 @@ def _build_ui(self) -> None:
self._update_plot_status()
self.setAcceptDrops(True)
+ # ----- UX polish: toasts, dirty-title, global shortcuts, tutorial -----
+ try:
+ self._toaster = ToastManager(self, max_visible=4)
+ except Exception:
+ self._toaster = None
+
+ # Mirror status-bar messages to toasts (longer-lived, easier to spot).
+ try:
+ self.post_tab.statusUpdate.connect(self._toast_from_status)
+ except Exception:
+ pass
+
+ # Dirty-title indicator: '*' suffix while postprocessing has unsaved changes.
+ def _is_dirty() -> bool:
+ try:
+ return bool(getattr(self.post_tab, "_project_dirty", False))
+ except Exception:
+ return False
+
+ self._refresh_dirty_title = attach_dirty_title(
+ self, "Pyber - Fiber Photometry", _is_dirty,
+ )
+ self._dirty_poll = QtCore.QTimer(self)
+ self._dirty_poll.setInterval(800)
+ self._dirty_poll.timeout.connect(self._refresh_dirty_title)
+ self._dirty_poll.start()
+
+ install_close_confirmation(self, _is_dirty, save_callback=self._save_post_project_for_close)
+
+ # Register the global shortcut bundle. Methods that don't exist become no-ops.
+ register_global_shortcuts(self)
+
+ # First-run tutorial.
+ QtCore.QTimer.singleShot(450, self._maybe_show_first_run_tutorial)
+
def _setup_section_popups(self) -> None:
"""Create preprocessing section panels using DockArea or legacy floating docks."""
if self._use_pg_dockarea_pre_layout and self._pre_dockarea_docks:
@@ -1294,10 +1445,12 @@ def _build_config_actions_widget(self) -> QtWidgets.QWidget:
self.param_panel.btn_metadata.setProperty("class", "blueSecondarySmall")
self.param_panel.btn_save_config.setProperty("class", "blueSecondarySmall")
self.param_panel.btn_load_config.setProperty("class", "blueSecondarySmall")
+ self.param_panel.btn_reset_defaults.setProperty("class", "blueSecondarySmall")
for btn in (
self.param_panel.btn_metadata,
self.param_panel.btn_save_config,
self.param_panel.btn_load_config,
+ self.param_panel.btn_reset_defaults,
):
btn.setSizePolicy(QtWidgets.QSizePolicy.Policy.Ignored, QtWidgets.QSizePolicy.Policy.Fixed)
btn.setMinimumWidth(90)
@@ -1949,6 +2102,232 @@ def _init_shortcuts(self) -> None:
self._bind_shortcut("S", self._assign_pending_box_to_section, require_non_text_focus=True)
self._bind_shortcut("Escape", self._close_focused_popup, require_non_text_focus=True)
+ # ----------------------------------------------------------------------
+ # UX polish: tutorial / toasts / global shortcut callbacks
+ # ----------------------------------------------------------------------
+
+ def _toast_from_status(self, message: str, timeout_ms: int = 0) -> None:
+ if not getattr(self, "_toaster", None) or not message:
+ return
+ text = str(message)
+ lower = text.lower()
+ sev = "info"
+ if "fail" in lower or "error" in lower or "could not" in lower:
+ sev = "error"
+ elif "warn" in lower or "dropped" in lower or "skipped" in lower:
+ sev = "warn"
+ elif "complete" in lower or "saved" in lower or "loaded" in lower or " ok" in lower:
+ sev = "ok"
+ timeout = int(timeout_ms) if timeout_ms and timeout_ms > 0 else int(self.settings.value("ui/toast_timeout_ms", 5000) or 5000)
+ self._toaster.post(text, sev, timeout)
+
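The severity chosen by `_toast_from_status()` is a keyword heuristic over the lowercased message, checked in priority order (error before warn before ok). Extracted as a pure function:

```python
def classify_severity(message: str) -> str:
    """Keyword heuristic used by _toast_from_status() to pick a toast style.
    Checks run in priority order, so 'error' wins over 'warn' over 'ok'."""
    lower = message.lower()
    if "fail" in lower or "error" in lower or "could not" in lower:
        return "error"
    if "warn" in lower or "dropped" in lower or "skipped" in lower:
        return "warn"
    if any(key in lower for key in ("complete", "saved", "loaded", " ok")):
        return "ok"
    return "info"
```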
+ def _maybe_show_first_run_tutorial(self) -> None:
+ try:
+ seen = self.settings.value("onboarding/first_run_completed", False)
+ show_pref = self.settings.value("onboarding/show_on_startup", True)
+ from onboarding import _to_bool
+ if _to_bool(seen, False) and not _to_bool(show_pref, True):
+ return
+ except Exception:
+ pass
+ self._show_tutorial_again()
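`QSettings.value()` can return booleans as strings (e.g. `"true"`/`"false"`) depending on the storage backend, which is why the code routes the raw values through `onboarding._to_bool`. That helper is not shown in this diff; a hedged guess at what such a coercion looks like:

```python
def to_bool(value, default: bool = False) -> bool:
    """Coerce QSettings-style values ('true', '1', 1, True, ...) to bool.
    A sketch of what onboarding._to_bool plausibly does; the real helper
    is not part of this diff."""
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        text = value.strip().lower()
        if text in {"1", "true", "yes", "on"}:
            return True
        if text in {"0", "false", "no", "off", ""}:
            return False
    return default
```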
+
+ def _show_tutorial_again(self) -> None:
+ try:
+ steps = build_default_tutorial(self)
+ overlay = TutorialOverlay(self, steps)
+ overlay.finished.connect(lambda: self.settings.setValue("onboarding/first_run_completed", True))
+ overlay.start()
+ except Exception:
+ pass
+
+ def _show_keyboard_cheatsheet(self) -> None:
+ # Open the Preferences dialog directly on the Keyboard tab.
+ try:
+ dlg = PreferencesDialog(self, self.settings)
+ tabs = dlg.findChild(QtWidgets.QTabWidget)
+ if tabs is not None:
+ for i in range(tabs.count()):
+ if tabs.tabText(i).lower().startswith("keyboard"):
+ tabs.setCurrentIndex(i)
+ break
+ dlg.exec()
+ except Exception:
+ pass
+
+ def _open_preferences(self) -> None:
+ try:
+ dlg = PreferencesDialog(self, self.settings)
+ if dlg.exec() == QtWidgets.QDialog.DialogCode.Accepted:
+ # Apply theme right away if it changed.
+ desired = str(self.settings.value("app/theme", "dark") or "dark").lower()
+ current = getattr(self, "_app_theme_mode", "dark")
+ if desired != current:
+ if desired == "light":
+ self.act_app_theme_light.setChecked(True)
+ else:
+ self.act_app_theme_dark.setChecked(True)
+ self._on_app_theme_changed()
+ if getattr(self, "_toaster", None):
+ self._toaster.ok("Preferences saved.")
+ except Exception:
+ pass
+
+ # --- tab navigation ---
+
+ def _focus_pre_tab(self) -> None:
+ try:
+ self.tabs.setCurrentIndex(0)
+ except Exception:
+ pass
+
+ def _focus_post_tab(self) -> None:
+ try:
+ self.tabs.setCurrentIndex(1)
+ except Exception:
+ pass
+
+ def _cycle_main_tab(self) -> None:
+ try:
+ n = self.tabs.count()
+ if n <= 1:
+ return
+ self.tabs.setCurrentIndex((self.tabs.currentIndex() + 1) % n)
+ except Exception:
+ pass
+
+ # --- file navigation in postprocessing/temporal ---
+
+ def _step_active_file_next(self) -> None:
+ self._step_active_file(+1)
+
+ def _step_active_file_prev(self) -> None:
+ self._step_active_file(-1)
+
+ def _step_active_file(self, delta: int) -> None:
+ # Try the temporal panel's combo (covers the GLM scope strip).
+ try:
+ section = getattr(self.post_tab, "section_temporal", None)
+ if section is not None and hasattr(section, "_step_active_file"):
+ section._step_active_file(delta)
+ return
+ except Exception:
+ pass
+ # Fall back to postprocessing's own file combo, if any.
+ try:
+ combo = getattr(self.post_tab, "combo_individual_file", None)
+ if combo is not None and combo.count() > 0:
+ idx = max(0, min(combo.count() - 1, combo.currentIndex() + int(delta)))
+ combo.setCurrentIndex(idx)
+ except Exception:
+ pass
+
+ def _toggle_individual_group(self) -> None:
+ try:
+ bar = getattr(self.post_tab, "tab_visual_mode", None)
+ if bar is None:
+ return
+ cur = bar.currentIndex()
+ bar.setCurrentIndex((cur + 1) % bar.count())
+ except Exception:
+ pass
+
+ # --- temporal modeling ---
+
+ def _fit_temporal_model(self) -> None:
+ try:
+ section = getattr(self.post_tab, "section_temporal", None)
+ if section is None:
+ return
+ self.tabs.setCurrentIndex(1)
+ section._on_fit_clicked()
+ except Exception:
+ pass
+
+ def _fit_temporal_all_files(self) -> None:
+ try:
+ section = getattr(self.post_tab, "section_temporal", None)
+ if section is None:
+ return
+ self.tabs.setCurrentIndex(1)
+ section._on_fit_all_files_clicked()
+ except Exception:
+ pass
+
+ def _recompute_psth(self) -> None:
+ try:
+ self.tabs.setCurrentIndex(1)
+ fn = getattr(self.post_tab, "_compute_psth", None)
+ if callable(fn):
+ fn()
+ except Exception:
+ pass
+
+ def _run_postprocess_export(self) -> None:
+ try:
+ self.tabs.setCurrentIndex(1)
+ for name in ("_run_export", "_export_current", "_on_run_export_clicked", "run_export"):
+ fn = getattr(self.post_tab, name, None)
+ if callable(fn):
+ fn()
+ return
+ except Exception:
+ pass
+
+ def _cancel_current_operation(self) -> None:
+ # Temporal modeling batch
+ try:
+ section = getattr(self.post_tab, "section_temporal", None)
+ if section is not None:
+ section._batch_cancel_requested = True
+ except Exception:
+ pass
+ # Preprocessing has its own Esc handling for popups; let it through too.
+ try:
+ self._close_focused_popup()
+ except Exception:
+ pass
+
+ def _reset_focused_plot_view(self) -> None:
+ try:
+ reset_focused_plot_view(self)
+ if getattr(self, "_toaster", None):
+ self._toaster.info("Reset plot view.", timeout_ms=1800)
+ except Exception:
+ pass
+
+ def _update_busy_indicator(self) -> None:
+ """Reflect any running batch op (currently: Temporal Modeling) in the status bar."""
+ try:
+ section = getattr(self.post_tab, "section_temporal", None)
+ if section is None or not hasattr(section, "progress_model"):
+ self._busy_widget.setVisible(False)
+ return
+ progress = section.progress_model
+ if progress.isVisible():
+ fmt = progress.format() or "Running..."
+ # Strip the trailing "%p%" placeholder for our compact label.
+ label = fmt.replace("%p%", "").strip(" :")
+ self._busy_label.setText(label or "Running...")
+ self._busy_widget.setVisible(True)
+ else:
+ self._busy_widget.setVisible(False)
+ except Exception:
+ self._busy_widget.setVisible(False)
+
+ def _save_post_project_for_close(self) -> bool:
+ """
+ Used by the close-confirmation handler. Returns True on success.
+ """
+ try:
+ fn = getattr(self.post_tab, "_save_project_dialog", None) or getattr(self.post_tab, "_save_project", None)
+ if callable(fn):
+ fn()
+ return not bool(getattr(self.post_tab, "_project_dirty", False))
+ except Exception:
+ pass
+ # No save handler available: let the user decide via Discard/Cancel.
+ return False
+
def _dock_area_from_settings(
self,
value: object,
@@ -2257,6 +2636,7 @@ def _export_preprocessing_ui_state_for_config(self) -> Dict[str, object]:
"export_selection": self.param_panel.export_selection().to_dict(),
"export_channel_names": self.param_panel.export_channel_names(),
"export_trigger_names": self.param_panel.export_trigger_names(),
+ "auto_export_to_source_dir": bool(self.param_panel.auto_export_enabled()),
"panel_layout": self._collect_panel_layout_payload(),
}
@@ -2311,6 +2691,8 @@ def _import_preprocessing_ui_state_from_config(self, ui_state: Dict[str, object]
self.param_panel.set_export_channel_names(list(ui_state.get("export_channel_names") or []))
if "export_trigger_names" in ui_state:
self.param_panel.set_export_trigger_names(list(ui_state.get("export_trigger_names") or []))
+ if "auto_export_to_source_dir" in ui_state:
+ self.param_panel.set_auto_export_enabled(_to_bool(ui_state.get("auto_export_to_source_dir"), False))
self._update_export_summary_label()
panel_layout = ui_state.get("panel_layout")
if isinstance(panel_layout, dict):
@@ -2319,6 +2701,166 @@ def _import_preprocessing_ui_state_from_config(self, ui_state: Dict[str, object]
self._save_panel_config_json()
self._save_settings()
+ def _pre_history_clone(self, state: Dict[str, Any]) -> Dict[str, Any]:
+ try:
+ return json.loads(json.dumps(state))
+ except Exception:
+ return dict(state)
+
+ def _pre_history_state_key(self, state: Dict[str, Any]) -> str:
+ try:
+ return json.dumps(state, sort_keys=True, separators=(",", ":"))
+ except Exception:
+ return repr(state)
+
+ def _pre_history_snapshot(self) -> Dict[str, Any]:
+ try:
+ params = self.param_panel.get_params().to_dict()
+ except Exception:
+ params = {}
+ start_s, end_s = self._time_window_bounds()
+ return {
+ "params": params,
+ "artifact_overlay_visible": bool(self.param_panel.artifact_overlay_visible()),
+ "artifact_thresholds_visible": bool(self.plots.artifact_thresholds_visible()),
+ "plot_background": self.plots.plot_background_mode(),
+ "plot_grid": bool(self.plots.plot_grid_visible()),
+ "export_selection": self.param_panel.export_selection().to_dict(),
+ "export_channel_names": self.param_panel.export_channel_names(),
+ "export_trigger_names": self.param_panel.export_trigger_names(),
+ "export_output_modes": self.param_panel.export_output_modes(),
+ "auto_export_to_source_dir": bool(self.param_panel.auto_export_enabled()),
+ "time_window": {"start_s": start_s, "end_s": end_s},
+ "manual_regions": self._keyed_regions_to_project(self._manual_regions_by_key),
+ "manual_exclude_regions": self._keyed_regions_to_project(self._manual_exclude_by_key),
+ "cutout_regions": self._keyed_regions_to_project(self._cutout_regions_by_key),
+ "sections": self._sections_to_project(),
+ }
+
+ def _update_pre_history_buttons(self) -> None:
+ try:
+ self.plots.set_history_available(bool(self._pre_history_undo), bool(self._pre_history_redo))
+ except Exception:
+ pass
+
+ def _reset_pre_history_snapshot(self) -> None:
+ self._pre_history_undo.clear()
+ self._pre_history_redo.clear()
+ self._pre_history_current = self._pre_history_snapshot()
+ self._pre_history_key = self._pre_history_state_key(self._pre_history_current)
+ self._update_pre_history_buttons()
+
+ def _record_pre_history_change(self) -> None:
+ if self._pre_history_restoring:
+ return
+ state = self._pre_history_snapshot()
+ key = self._pre_history_state_key(state)
+ if self._pre_history_current is None:
+ self._pre_history_current = self._pre_history_clone(state)
+ self._pre_history_key = key
+ self._update_pre_history_buttons()
+ return
+ if key == self._pre_history_key:
+ return
+ self._pre_history_undo.append(self._pre_history_clone(self._pre_history_current))
+ if len(self._pre_history_undo) > self._pre_history_limit:
+ self._pre_history_undo = self._pre_history_undo[-self._pre_history_limit:]
+ self._pre_history_redo.clear()
+ self._pre_history_current = self._pre_history_clone(state)
+ self._pre_history_key = key
+ self._update_pre_history_buttons()
+
+ def _refresh_artifact_panel_for_current(self) -> None:
+ key = self._current_key()
+ if not key:
+ self.artifact_panel.set_auto_regions([])
+ self.artifact_panel.set_regions([])
+ return
+ start_s, end_s = self._time_window_bounds()
+ manual_win = self._clip_regions_to_window(self._manual_regions_by_key.get(key, []), start_s, end_s)
+ ignore_win = self._clip_regions_to_window(self._manual_exclude_by_key.get(key, []), start_s, end_s)
+ auto_win = self._clip_regions_to_window(self._auto_regions_by_key.get(key, []), start_s, end_s)
+ checked_auto = [r for r in auto_win if not any(self._regions_match(r, ig) for ig in ignore_win)]
+ self.artifact_panel.set_auto_regions(auto_win, checked_regions=checked_auto)
+ self.artifact_panel.set_regions(manual_win)
+
+ def _restore_pre_history_state(self, state: Dict[str, Any]) -> None:
+ self._pre_history_restoring = True
+ try:
+ params = state.get("params")
+ if isinstance(params, dict):
+ self.param_panel.set_params(ProcessingParams.from_dict(params))
+ if "artifact_overlay_visible" in state:
+ visible = bool(state.get("artifact_overlay_visible"))
+ self.param_panel.set_artifact_overlay_visible(visible)
+ self.plots.set_artifact_overlay_visible(visible)
+ if "artifact_thresholds_visible" in state:
+ self.plots.set_artifact_thresholds_visible(bool(state.get("artifact_thresholds_visible")))
+ if "export_selection" in state:
+ self.param_panel.set_export_selection(ExportSelection.from_dict(state.get("export_selection")))
+ if "export_output_modes" in state:
+ self.param_panel.set_export_output_modes(list(state.get("export_output_modes") or []), follow_current=False)
+ if "export_channel_names" in state:
+ self.param_panel.set_export_channel_names(list(state.get("export_channel_names") or []))
+ if "export_trigger_names" in state:
+ self.param_panel.set_export_trigger_names(list(state.get("export_trigger_names") or []))
+ if "auto_export_to_source_dir" in state:
+ self.param_panel.set_auto_export_enabled(_to_bool(state.get("auto_export_to_source_dir"), False))
+ self._apply_pre_plot_style(
+ state.get("plot_background", self.plots.plot_background_mode()),
+ state.get("plot_grid", self.plots.plot_grid_visible()),
+ persist=False,
+ )
+ self._manual_regions_by_key = self._project_to_keyed_regions(state.get("manual_regions"))
+ self._manual_exclude_by_key = self._project_to_keyed_regions(state.get("manual_exclude_regions"))
+ self._cutout_regions_by_key = self._project_to_keyed_regions(state.get("cutout_regions"))
+ self._sections_by_key = self._project_to_sections(state.get("sections"))
+ self._pending_box_region_by_key.clear()
+ tw = state.get("time_window") if isinstance(state.get("time_window"), dict) else {}
+ for ed, value in (
+ (self.file_panel.edit_time_start, tw.get("start_s")),
+ (self.file_panel.edit_time_end, tw.get("end_s")),
+ ):
+ ed.blockSignals(True)
+ try:
+ ed.setText("" if value is None else f"{float(value):.6g}")
+ except Exception:
+ ed.setText("")
+ finally:
+ ed.blockSignals(False)
+ self._last_processed.clear()
+ self._refresh_artifact_panel_for_current()
+ self._update_export_summary_label()
+ self._update_raw_plot(preserve_view=True)
+ self._trigger_preview(preserve_view=True)
+ self._save_settings()
+ finally:
+ self._pre_history_restoring = False
+
+ def _undo_pre_action(self) -> None:
+ if not self._pre_history_undo:
+ return
+ current = self._pre_history_snapshot()
+ previous = self._pre_history_undo.pop()
+ self._pre_history_redo.append(self._pre_history_clone(current))
+ self._restore_pre_history_state(previous)
+ self._pre_history_current = self._pre_history_clone(previous)
+ self._pre_history_key = self._pre_history_state_key(previous)
+ self._update_pre_history_buttons()
+ self._show_status_message("Undid preprocessing action.", 2500)
+
+ def _redo_pre_action(self) -> None:
+ if not self._pre_history_redo:
+ return
+ current = self._pre_history_snapshot()
+ next_state = self._pre_history_redo.pop()
+ self._pre_history_undo.append(self._pre_history_clone(current))
+ self._restore_pre_history_state(next_state)
+ self._pre_history_current = self._pre_history_clone(next_state)
+ self._pre_history_key = self._pre_history_state_key(next_state)
+ self._update_pre_history_buttons()
+ self._show_status_message("Redid preprocessing action.", 2500)
+
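Taken together, `_record_pre_history_change`, `_undo_pre_action`, and `_redo_pre_action` implement a standard snapshot history: two stacks of JSON deep-clones, a canonical `sort_keys` JSON string as the dedup key, and a depth cap on undo. A minimal Qt-free sketch of the same pattern (hypothetical `History` class, illustration only):

```python
import json
from typing import Any, Dict, List


class History:
    """Snapshot-based undo/redo with JSON deep-clones and dedup keys."""

    def __init__(self, initial: Dict[str, Any], limit: int = 50) -> None:
        self._undo: List[Dict[str, Any]] = []
        self._redo: List[Dict[str, Any]] = []
        self._current = json.loads(json.dumps(initial))
        self._key = json.dumps(initial, sort_keys=True)
        self._limit = limit

    def record(self, state: Dict[str, Any]) -> None:
        key = json.dumps(state, sort_keys=True)
        if key == self._key:  # no-op change: same canonical key, skip
            return
        self._undo.append(self._current)
        self._undo = self._undo[-self._limit:]  # cap undo depth
        self._redo.clear()  # a fresh edit invalidates the redo stack
        self._current = json.loads(json.dumps(state))
        self._key = key

    def undo(self) -> Dict[str, Any]:
        self._redo.append(self._current)
        self._current = self._undo.pop()
        self._key = json.dumps(self._current, sort_keys=True)
        return self._current

    def redo(self) -> Dict[str, Any]:
        self._undo.append(self._current)
        self._current = self._redo.pop()
        self._key = json.dumps(self._current, sort_keys=True)
        return self._current
```

The JSON round-trip both deep-clones the snapshot and forces it to be JSON-serializable, which is why the GUI falls back to `dict(state)` / `repr(state)` when serialization fails.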
def _sync_section_button_states_from_docks(self) -> None:
if self._use_pg_dockarea_pre_layout:
self._last_opened_section = None
@@ -2585,6 +3127,12 @@ def _restore_settings(self) -> None:
self.plots.set_artifact_thresholds_visible(bool(show_thresholds))
except Exception:
pass
+ try:
+ auto_export = _to_bool(self.settings.value("auto_export_to_source_dir", False), False)
+ self.param_panel.set_auto_export_enabled(auto_export)
+ self._update_export_summary_label()
+ except Exception:
+ pass
try:
default_bg = "white" if self._app_theme_mode == "light" else "dark"
plot_bg = self.settings.value("pre_plot_background", default_bg, type=str)
@@ -2704,6 +3252,10 @@ def _save_settings(self) -> None:
self.settings.setValue("artifact_thresholds_visible", bool(self.plots.artifact_thresholds_visible()))
except Exception:
pass
+ try:
+ self.settings.setValue("auto_export_to_source_dir", bool(self.param_panel.auto_export_enabled()))
+ except Exception:
+ pass
try:
self.settings.setValue("pre_plot_background", str(self.plots.plot_background_mode()))
self.settings.setValue("pre_plot_grid", bool(self.plots.plot_grid_visible()))
@@ -2882,6 +3434,11 @@ def _new_preprocessing_project(self) -> None:
return
self._clear_preprocessing_project_state()
self._pre_project_path = None
+ try:
+ self.post_tab.reset_for_new_preprocessing_project()
+ except Exception:
+ pass
+ self._reset_pre_history_snapshot()
self._show_status_message("Started a new preprocessing project.", 5000)
def _keyed_regions_to_project(
@@ -3083,6 +3640,10 @@ def _load_preprocessing_project_from_path(self, path: str) -> None:
self._clear_preprocessing_project_state()
self._pre_project_path = path
+ try:
+ self.post_tab.reset_for_new_preprocessing_project()
+ except Exception:
+ pass
self._apply_preprocessing_config_payload(payload.get("preprocessing_config"))
session_mapping = payload.get("csv_mapping_session")
@@ -3136,6 +3697,7 @@ def _load_preprocessing_project_from_path(self, path: str) -> None:
"Open project",
"Some linked input files are missing and were skipped:\n" + "\n".join(missing_paths[:12]),
)
+ self._reset_pre_history_snapshot()
self._show_status_message(f"Preprocessing project loaded: {os.path.basename(path)}", 5000)
def _restore_file_selection(self, selected_paths: List[str], current_path: Optional[str]) -> None:
@@ -3174,7 +3736,6 @@ def _is_csv_time_column(self, value: object) -> bool:
return norm in {"time", "t", "timestamp", "times", "timesec", "times", "timems"} or "timestamp" in norm
def _parse_csv_float(self, value: object) -> float:
- from pyBer.analysis_core import coerce_time_value
text = str(value or "").strip()
if not text or text.lower() in {"nan", "none", "null", "na"}:
return np.nan
@@ -4209,10 +4770,12 @@ def _on_main_tab_changed(self, index: int) -> None:
def _on_artifact_overlay_toggled(self, visible: bool) -> None:
self.plots.set_artifact_overlay_visible(bool(visible))
+ self._record_pre_history_change()
self._save_settings()
def _on_artifact_thresholds_toggled(self, visible: bool) -> None:
self.plots.set_artifact_thresholds_visible(bool(visible))
+ self._record_pre_history_change()
self._save_settings()
def _normalize_app_theme_mode(self, value: object) -> str:
@@ -4240,6 +4803,7 @@ def _apply_app_theme(self, theme_mode: object, persist: bool = True) -> None:
self.act_app_theme_light.blockSignals(False)
try:
+ apply_app_palette(QtWidgets.QApplication.instance(), mode)
self.setStyleSheet(app_qss(mode))
except Exception:
pass
@@ -4299,6 +4863,7 @@ def _on_pre_plot_style_changed(self, *_args) -> None:
self.act_plot_grid.isChecked() if hasattr(self, "act_plot_grid") else True,
persist=True,
)
+ self._record_pre_history_change()
def _auto_range_for_processed(self, processed: ProcessedTrial) -> None:
try:
@@ -4471,6 +5036,7 @@ def _on_time_window_changed(self) -> None:
checked_auto = [r for r in auto_win if not any(self._regions_match(r, ig) for ig in ignore_win)]
self.artifact_panel.set_auto_regions(auto_win, checked_regions=checked_auto)
self.artifact_panel.set_regions(manual_win)
+ self._record_pre_history_change()
self._update_raw_plot()
self._trigger_preview()
self._update_plot_status()
@@ -4761,6 +5327,13 @@ def _mask_arr(arr: Optional[np.ndarray]) -> Optional[np.ndarray]:
processed.baseline_sig = _mask_arr(processed.baseline_sig)
processed.baseline_ref = _mask_arr(processed.baseline_ref)
processed.output = _mask_arr(processed.output)
+ if hasattr(processed, "outputs") and processed.outputs:
+ masked_outputs = {}
+ for label, values in processed.outputs.items():
+ masked = _mask_arr(values)
+ if masked is not None:
+ masked_outputs[str(label)] = masked
+ processed.outputs = masked_outputs
# Mask triggers too if requested by convention, but here we keep them as-is or NaN them
if hasattr(processed, "triggers") and processed.triggers:
@@ -4860,12 +5433,15 @@ def _artifact_param_signature(self, params: ProcessingParams) -> Tuple[object, .
return (
bool(getattr(params, "artifact_detection_enabled", True)),
str(params.artifact_mode),
+ str(getattr(params, "artifact_handling", "Interpolate")),
float(params.mad_k),
float(params.adaptive_window_s),
float(params.artifact_pad_s),
)
def _on_params_changed(self) -> None:
+ if self._pre_history_restoring:
+ return
try:
params = self.param_panel.get_params()
except Exception:
@@ -4889,6 +5465,7 @@ def _on_params_changed(self) -> None:
self._update_raw_plot(preserve_view=True)
except Exception:
pass
+ self._record_pre_history_change()
self._trigger_preview(preserve_view=True)
def _trigger_preview(self, preserve_view: bool = False) -> None:
@@ -5046,7 +5623,8 @@ def _add_manual_region_from_selector(self) -> None:
self._manual_regions_by_key[key] = regs
start_s, end_s = self._time_window_bounds()
self.artifact_panel.set_regions(self._clip_regions_to_window(regs, start_s, end_s))
- self._trigger_preview()
+ self._record_pre_history_change()
+ self._trigger_preview(preserve_view=True)
def _add_manual_region_from_drag(self, t0: float, t1: float) -> None:
if self._box_select_callback:
@@ -5073,7 +5651,8 @@ def _clear_manual_regions_current(self) -> None:
self._manual_exclude_by_key[key] = []
self._pending_box_region_by_key.pop(key, None)
self.artifact_panel.set_regions([])
- self._trigger_preview()
+ self._record_pre_history_change()
+ self._trigger_preview(preserve_view=True)
def _request_box_select(self, callback: Callable[[float, float], None]) -> None:
self._box_select_callback = callback
@@ -5132,7 +5711,8 @@ def _assign_pending_box_to_artifact(self) -> None:
self._manual_regions_by_key[key] = regs
start_s, end_s = self._time_window_bounds()
self.artifact_panel.set_regions(self._clip_regions_to_window(regs, start_s, end_s))
- self._trigger_preview()
+ self._record_pre_history_change()
+ self._trigger_preview(preserve_view=True)
def _assign_pending_box_to_cut(self) -> None:
region = self._consume_pending_box_region()
@@ -5144,6 +5724,7 @@ def _assign_pending_box_to_cut(self) -> None:
regs.sort(key=lambda x: x[0])
self._cutout_regions_by_key[key] = regs
self._last_processed.clear()
+ self._record_pre_history_change()
self._update_raw_plot()
self._trigger_preview()
@@ -5160,6 +5741,7 @@ def _assign_pending_box_to_section(self) -> None:
})
sections.sort(key=lambda sec: float(sec.get("start", 0.0)))
self._sections_by_key[key] = sections
+ self._record_pre_history_change()
self._show_status_message(f"Section added: {region[0]:.3f}s to {region[1]:.3f}s")
def _show_box_selection_context_menu(self) -> None:
@@ -5200,7 +5782,8 @@ def _contains(target: Tuple[float, float], arr: List[Tuple[float, float]]) -> bo
prev_ignore = self._manual_exclude_by_key.get(key, [])
self._manual_regions_by_key[key] = self._merge_regions_with_window(prev_manual, manual_add, start_s, end_s)
self._manual_exclude_by_key[key] = self._merge_regions_with_window(prev_ignore, manual_ignore, start_s, end_s)
- self._trigger_preview()
+ self._record_pre_history_change()
+ self._trigger_preview(preserve_view=True)
def _toggle_artifacts_panel(self) -> None:
if self._use_pg_dockarea_pre_layout:
@@ -5310,6 +5893,56 @@ def _remember_export_dir(self, out_dir: str, origin_dir: str) -> None:
except Exception:
pass
+ def _process_trial_for_export(
+ self,
+ trial: LoadedTrial,
+ params: ProcessingParams,
+ export_selection: ExportSelection,
+ manual_regions_sec: List[Tuple[float, float]],
+ manual_exclude_regions_sec: List[Tuple[float, float]],
+ ) -> ProcessedTrial:
+ modes: List[str] = []
+ if export_selection.output:
+ for mode in export_selection.output_modes or [params.output_mode]:
+ mode = str(mode or "").strip()
+ if mode in OUTPUT_MODES and mode not in modes:
+ modes.append(mode)
+ if not modes:
+ modes = [params.output_mode if params.output_mode in OUTPUT_MODES else OUTPUT_MODES[0]]
+
+ primary = params.output_mode if params.output_mode in modes else modes[0]
+ ordered_modes = [primary] + [mode for mode in modes if mode != primary]
+ base_processed: Optional[ProcessedTrial] = None
+ outputs: Dict[str, np.ndarray] = {}
+
+ for mode in ordered_modes:
+ mode_params = ProcessingParams.from_dict(params.to_dict())
+ mode_params.output_mode = mode
+ processed = self.processor.process_trial(
+ trial=trial,
+ params=mode_params,
+ manual_regions_sec=manual_regions_sec,
+ manual_exclude_regions_sec=manual_exclude_regions_sec,
+ preview_mode=False,
+ )
+ if base_processed is None:
+ base_processed = processed
+ if processed.output is not None:
+ outputs[str(processed.output_label or mode)] = np.asarray(processed.output, float)
+
+ if base_processed is None:
+ fallback_params = ProcessingParams.from_dict(params.to_dict())
+ base_processed = self.processor.process_trial(
+ trial=trial,
+ params=fallback_params,
+ manual_regions_sec=manual_regions_sec,
+ manual_exclude_regions_sec=manual_exclude_regions_sec,
+ preview_mode=False,
+ )
+ if export_selection.output:
+ base_processed.outputs = outputs
+ return base_processed
+
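The mode-normalization step at the top of `_process_trial_for_export` can be read as a pure function: validate requested modes against `OUTPUT_MODES`, dedupe while preserving order, fall back to the primary mode, and move the primary to the front. A standalone sketch (the `OUTPUT_MODES` values here are placeholders for illustration, not the app's real labels):

```python
from typing import List

OUTPUT_MODES = ["dFF", "zscore", "dFF_mc"]  # placeholder labels


def normalize_modes(requested: List[str], primary: str) -> List[str]:
    """Validate and dedupe requested modes; the primary mode goes first."""
    modes: List[str] = []
    for mode in requested or [primary]:
        mode = str(mode or "").strip()
        if mode in OUTPUT_MODES and mode not in modes:
            modes.append(mode)
    if not modes:
        # Nothing valid requested: fall back to the primary, then the default.
        modes = [primary if primary in OUTPUT_MODES else OUTPUT_MODES[0]]
    head = primary if primary in modes else modes[0]
    return [head] + [m for m in modes if m != head]
```

Running the primary mode first means the first `ProcessedTrial` can serve as `base_processed` while the remaining modes only contribute labelled entries to the `outputs` dict.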
def _export_selected_or_all(self) -> None:
selected = self._selected_paths()
if not selected:
@@ -5317,28 +5950,41 @@ def _export_selected_or_all(self) -> None:
if not selected:
return
+ auto_export = bool(self.param_panel.auto_export_enabled())
origin_dir = self._export_origin_dir(selected)
- start_dir = self._export_start_dir(selected)
- out_dir = QtWidgets.QFileDialog.getExistingDirectory(self, "Select export folder", start_dir)
- if not out_dir:
- return
- self._remember_export_dir(out_dir, origin_dir)
+ out_dir = ""
+ if not auto_export:
+ start_dir = self._export_start_dir(selected)
+ out_dir = QtWidgets.QFileDialog.getExistingDirectory(self, "Select export folder", start_dir)
+ if not out_dir:
+ return
+ self._remember_export_dir(out_dir, origin_dir)
params = self.param_panel.get_params()
export_selection = self.param_panel.export_selection()
- export_channel_names = self.param_panel.export_channel_names()
+ export_channel_names = [] if auto_export else self.param_panel.export_channel_names()
export_trigger_names = self.param_panel.export_trigger_names()
- # Process/export each selected file, for the currently selected channel.
+ # Process/export each selected file. Auto export writes beside each source
+ # file and intentionally exports every analog channel with the same params.
n_total = 0
+ exported_dirs = set()
for path in selected:
doric = self._loaded_files.get(path)
if not doric:
continue
- channels = [name for name in export_channel_names if name in doric.channels]
- if not channels:
- fallback = self._current_channel if (self._current_channel in doric.channels) else (doric.channels[0] if doric.channels else None)
- channels = [fallback] if fallback else []
+ if auto_export:
+ channels = list(doric.channels)
+ else:
+ channels = [name for name in export_channel_names if name in doric.channels]
+ if not channels:
+ fallback = self._current_channel if (self._current_channel in doric.channels) else (doric.channels[0] if doric.channels else None)
+ channels = [fallback] if fallback else []
+ path_out_dir = out_dir
+ if auto_export:
+ path_out_dir = os.path.dirname(path)
+ if not path_out_dir or not os.path.isdir(path_out_dir):
+ path_out_dir = origin_dir if origin_dir and os.path.isdir(origin_dir) else os.getcwd()
dio_names = [name for name in export_trigger_names if name in doric.trigger_by_name]
if not export_selection.dio:
dio_names = [None]
@@ -5373,10 +6019,11 @@ def _export_one(proc: ProcessedTrial, suffix: str = "") -> None:
stem = safe_stem_from_metadata(path, ch, meta)
if suffix:
stem = f"{stem}_{suffix}"
- csv_path = os.path.join(out_dir, f"{stem}.csv")
- h5_path = os.path.join(out_dir, f"{stem}.h5")
+ csv_path = os.path.join(path_out_dir, f"{stem}.csv")
+ h5_path = os.path.join(path_out_dir, f"{stem}.h5")
export_processed_csv(csv_path, proc, metadata=meta, selection=export_selection)
export_processed_h5(h5_path, proc, metadata=meta, selection=export_selection)
+ exported_dirs.add(path_out_dir)
n_total += 1
try:
@@ -5388,21 +6035,21 @@ def _export_one(proc: ProcessedTrial, suffix: str = "") -> None:
if sec_trial is None:
continue
sec_params = ProcessingParams.from_dict(sec.get("params", {})) if isinstance(sec.get("params"), dict) else params
- processed = self.processor.process_trial(
+ processed = self._process_trial_for_export(
trial=sec_trial,
params=sec_params,
+ export_selection=export_selection,
manual_regions_sec=manual,
manual_exclude_regions_sec=manual_exclude,
- preview_mode=False,
)
_export_one(processed, suffix=f"sec{i}_{s0:.2f}_{s1:.2f}")
else:
- processed = self.processor.process_trial(
+ processed = self._process_trial_for_export(
trial=trial,
params=params,
+ export_selection=export_selection,
manual_regions_sec=manual,
manual_exclude_regions_sec=manual_exclude,
- preview_mode=False,
)
_export_one(processed)
except Exception as e:
@@ -5411,7 +6058,16 @@ def _export_one(proc: ProcessedTrial, suffix: str = "") -> None:
"Export error",
f"Failed export:\n{path} [{ch}] [{primary_trigger or 'no DIO'}]\n\n{e}",
)
- self._show_status_message(f"Export complete: {n_total} recording(s) written to {out_dir}")
+ if auto_export:
+ if len(exported_dirs) == 1:
+ target = next(iter(exported_dirs))
+ elif exported_dirs:
+ target = f"{len(exported_dirs)} source folders"
+ else:
+ target = "source folders"
+ else:
+ target = out_dir
+ self._show_status_message(f"Export complete: {n_total} recording(s) written to {target}")
# Optional follow-up: refresh the post tab from the exported results (for now the user loads them manually).
@@ -5897,6 +6553,8 @@ def main() -> None:
pg.setConfigOptions(antialias=False)
smoke_test = str(os.environ.get("PYBER_SMOKE_TEST", "")).strip().lower() in {"1", "true", "yes", "on"}
app = QtWidgets.QApplication([])
+ apply_app_palette(app, "dark")
+ spinbox_scrubber = install_spinbox_scrubbers(app)
icon_path = _pyber_icon_path()
try:
if os.path.isfile(icon_path):
@@ -5917,6 +6575,7 @@ def main() -> None:
except Exception:
splash = None
w = MainWindow()
+ spinbox_scrubber.scan(w)
if smoke_test:
try:
diff --git a/pyBer/numeric_controls.py b/pyBer/numeric_controls.py
new file mode 100644
index 0000000..1321445
--- /dev/null
+++ b/pyBer/numeric_controls.py
@@ -0,0 +1,207 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from PySide6 import QtCore, QtGui, QtWidgets
+
+
+def _event_global_pos(event: QtCore.QEvent) -> QtCore.QPoint:
+ try:
+ return event.globalPosition().toPoint() # Qt 6
+ except Exception:
+ try:
+ return event.globalPos() # Qt 5 compatibility for older shims
+ except Exception:
+ return QtCore.QPoint()
+
+
+def _event_local_pos(event: QtCore.QEvent) -> QtCore.QPoint:
+ try:
+ return event.position().toPoint() # Qt 6
+ except Exception:
+ try:
+ return event.pos()
+ except Exception:
+ return QtCore.QPoint()
+
+
+class SpinBoxScrubber(QtCore.QObject):
+ """Turn spin boxes into arrowless, draggable numeric controls.
+
+ Users can still click into the field and type exact values. Dragging left/right
+ changes the value using the spin box's native stepping, so existing signal
+ wiring and validation continue to work.
+ """
+
+ _CONFIGURED_PROP = "_pyber_spin_scrubber_configured"
+
+ def __init__(self, parent: Optional[QtCore.QObject] = None) -> None:
+ super().__init__(parent)
+ self._press_spin: Optional[QtWidgets.QAbstractSpinBox] = None
+ self._press_pos = QtCore.QPoint()
+ self._press_local_pos = QtCore.QPoint()
+ self._last_steps = 0
+ self._dragging = False
+ self._override_cursor = False
+
+ def scan(self, root: QtCore.QObject) -> None:
+ if isinstance(root, QtWidgets.QAbstractSpinBox):
+ self._configure_spinbox(root)
+ if isinstance(root, QtWidgets.QWidget):
+ for spin in root.findChildren(QtWidgets.QAbstractSpinBox):
+ self._configure_spinbox(spin)
+
+ def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool:
+ etype = event.type()
+ if etype == QtCore.QEvent.Type.Show:
+ self._configure_object_tree(obj)
+
+ spin = self._spinbox_for_object(obj)
+ if spin is None:
+ return False
+ if etype in (
+ QtCore.QEvent.Type.MouseButtonPress,
+ QtCore.QEvent.Type.Wheel,
+ QtCore.QEvent.Type.FocusIn,
+ QtCore.QEvent.Type.KeyPress,
+ QtCore.QEvent.Type.Show,
+ ):
+ self._configure_spinbox(spin)
+ elif not bool(spin.property(self._CONFIGURED_PROP)):
+ return False
+
+ if etype == QtCore.QEvent.Type.MouseButtonPress:
+ if not self._left_button_event(event) or not spin.isEnabled():
+ return False
+ self._press_spin = spin
+ self._press_pos = _event_global_pos(event)
+ self._press_local_pos = _event_local_pos(event)
+ self._last_steps = 0
+ self._dragging = False
+ return False
+
+ if etype == QtCore.QEvent.Type.MouseMove and self._press_spin is spin:
+ if not self._left_button_held(event):
+ return False
+ dx_global = _event_global_pos(event).x() - self._press_pos.x()
+ dx_local = _event_local_pos(event).x() - self._press_local_pos.x()
+ dx = dx_global if abs(dx_global) >= abs(dx_local) else dx_local
+ if not self._dragging:
+ if abs(dx) < 5:
+ return False
+ self._dragging = True
+ spin.setFocus(QtCore.Qt.FocusReason.MouseFocusReason)
+ self._set_override_cursor(QtCore.Qt.CursorShape.SizeHorCursor)
+
+ steps = self._steps_from_drag(dx, event)
+ delta = steps - self._last_steps
+ if delta:
+ spin.stepBy(delta)
+ self._last_steps = steps
+ return True
+
+ if etype == QtCore.QEvent.Type.MouseButtonRelease and self._press_spin is spin:
+ was_dragging = self._dragging
+ self._press_spin = None
+ self._last_steps = 0
+ self._dragging = False
+ self._restore_override_cursor()
+ return was_dragging
+
+ if etype == QtCore.QEvent.Type.Leave and self._press_spin is spin and not self._left_button_held(event):
+ self._press_spin = None
+ self._last_steps = 0
+ self._dragging = False
+ self._restore_override_cursor()
+
+ return False
+
+ def _configure_object_tree(self, obj: QtCore.QObject) -> None:
+ spin = self._spinbox_for_object(obj)
+ if spin is not None:
+ self._configure_spinbox(spin)
+ return
+ if isinstance(obj, QtWidgets.QWidget):
+ for child in obj.findChildren(QtWidgets.QAbstractSpinBox):
+ self._configure_spinbox(child)
+
+ def _configure_spinbox(self, spin: QtWidgets.QAbstractSpinBox) -> None:
+ if bool(spin.property(self._CONFIGURED_PROP)):
+ return
+ spin.setProperty(self._CONFIGURED_PROP, True)
+ spin.setButtonSymbols(QtWidgets.QAbstractSpinBox.ButtonSymbols.NoButtons)
+ spin.setKeyboardTracking(False)
+ spin.setAccelerated(True)
+ spin.setCursor(QtCore.Qt.CursorShape.SizeHorCursor)
+ line_edit = spin.lineEdit()
+ if line_edit is not None:
+ line_edit.setCursor(QtCore.Qt.CursorShape.SizeHorCursor)
+ line_edit.setTextMargins(1, 0, 1, 0)
+ if isinstance(spin, QtWidgets.QDoubleSpinBox):
+ try:
+ spin.setStepType(QtWidgets.QAbstractSpinBox.StepType.AdaptiveDecimalStepType)
+ except Exception:
+ pass
+ tip = spin.toolTip().strip()
+ scrub_tip = "Drag left/right to adjust. Type for an exact value. Shift = faster, Ctrl = finer."
+ if scrub_tip not in tip:
+ spin.setToolTip(f"{tip}\n{scrub_tip}" if tip else scrub_tip)
+
+ def _spinbox_for_object(self, obj: QtCore.QObject) -> Optional[QtWidgets.QAbstractSpinBox]:
+ if isinstance(obj, QtWidgets.QAbstractSpinBox):
+ return obj
+ parent = obj.parent() if isinstance(obj, QtCore.QObject) else None
+ while parent is not None:
+ if isinstance(parent, QtWidgets.QAbstractSpinBox):
+ return parent
+ parent = parent.parent()
+ return None
+
+ def _steps_from_drag(self, dx: int, event: QtCore.QEvent) -> int:
+ pixels_per_step = 12.0
+ try:
+ mods = event.modifiers()
+ except Exception:
+ mods = QtCore.Qt.KeyboardModifier.NoModifier
+ if mods & QtCore.Qt.KeyboardModifier.ShiftModifier:
+ pixels_per_step = 5.0
+ elif mods & QtCore.Qt.KeyboardModifier.ControlModifier:
+ pixels_per_step = 28.0
+ return int(dx / pixels_per_step)
+
+ def _left_button_event(self, event: QtCore.QEvent) -> bool:
+ try:
+ return event.button() == QtCore.Qt.MouseButton.LeftButton
+ except Exception:
+ return False
+
+ def _left_button_held(self, event: QtCore.QEvent) -> bool:
+ try:
+ return bool(event.buttons() & QtCore.Qt.MouseButton.LeftButton)
+ except Exception:
+ return False
+
+ def _set_override_cursor(self, cursor: QtCore.Qt.CursorShape) -> None:
+ if self._override_cursor:
+ return
+ QtWidgets.QApplication.setOverrideCursor(QtGui.QCursor(cursor))
+ self._override_cursor = True
+
+ def _restore_override_cursor(self) -> None:
+ if not self._override_cursor:
+ return
+ try:
+ QtWidgets.QApplication.restoreOverrideCursor()
+ except Exception:
+ pass
+ self._override_cursor = False
+
+
+def install_spinbox_scrubbers(app: QtWidgets.QApplication) -> SpinBoxScrubber:
+ existing = getattr(app, "_pyber_spinbox_scrubber", None)
+ if isinstance(existing, SpinBoxScrubber):
+ return existing
+ scrubber = SpinBoxScrubber(app)
+ app.installEventFilter(scrubber)
+ setattr(app, "_pyber_spinbox_scrubber", scrubber)
+ return scrubber
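The scrubber above turns horizontal drags into `stepBy` calls: the total step count is recomputed from the full drag distance (`int(dx / pixels_per_step)`), and only the delta since the previous mouse-move is applied, so fast drags neither lose nor double-count steps. A dependency-free sketch of that accounting (the `ScrubAccumulator` class is illustrative, not part of this patch):

```python
# Sketch of the drag -> steps accounting used by SpinBoxScrubber:
# total steps derive from the cumulative drag distance, and only the
# delta since the last event is emitted to stepBy(). `ScrubAccumulator`
# is a hypothetical stand-in for the _last_steps bookkeeping above.

class ScrubAccumulator:
    def __init__(self, pixels_per_step: float = 12.0):
        self.pixels_per_step = pixels_per_step
        self.last_steps = 0

    def on_drag(self, dx: int) -> int:
        """Return the step delta to feed to QAbstractSpinBox.stepBy()."""
        steps = int(dx / self.pixels_per_step)  # int() truncates toward zero
        delta = steps - self.last_steps
        self.last_steps = steps
        return delta


acc = ScrubAccumulator()
deltas = [acc.on_drag(dx) for dx in (6, 13, 25, 24, -30)]
print(deltas)  # [0, 1, 1, 0, -4]
```

Because the delta is derived from the cumulative distance rather than per-event movement, reversing the drag direction (the final `-30`) correctly walks the value back past its starting point.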
diff --git a/pyBer/onboarding.py b/pyBer/onboarding.py
new file mode 100644
index 0000000..37a18fc
--- /dev/null
+++ b/pyBer/onboarding.py
@@ -0,0 +1,910 @@
+# onboarding.py
+"""
+User-experience polish layer for pyBer:
+
+* TutorialOverlay - first-run guided walkthrough with highlight cutout, callouts,
+ Next/Back/Skip controls. Reopen via F1 / Help menu.
+* ToastManager - lightweight non-blocking notifications (info/warn/error)
+ stacked top-right, click to dismiss.
+* PreferencesDialog - consolidates theme, autosave, default kernel window, behavior
+ defaults, and shows the keyboard cheat sheet.
+* register_global_shortcuts(window) - installs a wide set of global keyboard shortcuts.
+* attach_dirty_title(window) - shows "*" in window title while project is dirty.
+
+This module is intentionally self-contained: it touches no analysis code
+and never raises if the host window lacks an optional attribute.
+"""
+from __future__ import annotations
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from PySide6 import QtCore, QtGui, QtWidgets
+
+
+# ============================================================================
+# Toast notifications
+# ============================================================================
+
+_TOAST_QSS = {
+ "info": "background: #1f2a3a; color: #e9f0fb; border: 1px solid #355080;",
+ "warn": "background: #3a2d1d; color: #fde6c8; border: 1px solid #8a6a3a;",
+ "error": "background: #3b1f25; color: #ffd6dc; border: 1px solid #8a3949;",
+ "ok": "background: #1c2e22; color: #d4f4dc; border: 1px solid #2f7a4a;",
+}
+
+
+class _Toast(QtWidgets.QFrame):
+ closed = QtCore.Signal(object)
+
+ def __init__(self, parent: QtWidgets.QWidget, text: str, severity: str, timeout_ms: int):
+ super().__init__(parent)
+ self.setObjectName("pyberToast")
+ self.setAttribute(QtCore.Qt.WidgetAttribute.WA_DeleteOnClose, True)
+ self.setStyleSheet(
+ "QFrame#pyberToast { %s border-radius: 8px; padding: 8px 12px; } "
+ "QFrame#pyberToast QLabel { background: transparent; }"
+ % _TOAST_QSS.get(severity, _TOAST_QSS["info"])
+ )
+ lay = QtWidgets.QHBoxLayout(self)
+ lay.setContentsMargins(10, 7, 10, 7)
+ lay.setSpacing(8)
+ icon = {"info": "i", "warn": "!", "error": "x", "ok": "v"}.get(severity, "i")
+ badge = QtWidgets.QLabel(icon)
+ badge.setFixedSize(20, 20)
+ badge.setAlignment(QtCore.Qt.AlignmentFlag.AlignCenter)
+ badge.setStyleSheet("background: rgba(255,255,255,0.08); border-radius: 10px; font-weight: 700;")
+ lay.addWidget(badge)
+ self._label = QtWidgets.QLabel(str(text))
+ self._label.setWordWrap(True)
+ self._label.setMaximumWidth(360)
+ lay.addWidget(self._label, 1)
+ self.setMinimumWidth(220)
+ self.setMaximumWidth(420)
+ self.setCursor(QtCore.Qt.CursorShape.PointingHandCursor)
+ if timeout_ms > 0:
+ QtCore.QTimer.singleShot(int(timeout_ms), self._dismiss)
+
+ def mousePressEvent(self, event: QtGui.QMouseEvent) -> None:
+ self._dismiss()
+ super().mousePressEvent(event)
+
+ def _dismiss(self) -> None:
+ try:
+ self.closed.emit(self)
+ finally:
+ self.close()
+
+
+class ToastManager(QtCore.QObject):
+ """
+ Stacked top-right toast queue that follows the host window. Up to
+ `max_visible` toasts; older ones drop off when more arrive.
+ """
+
+ def __init__(self, window: QtWidgets.QMainWindow, max_visible: int = 4):
+ super().__init__(window)
+ self._window = window
+ self._max_visible = int(max_visible)
+ self._toasts: List[_Toast] = []
+ self._margin = 14
+ self._spacing = 8
+ window.installEventFilter(self)
+
+ def post(self, text: str, severity: str = "info", timeout_ms: int = 5000) -> None:
+ if not text:
+ return
+ toast = _Toast(self._window, text, severity, timeout_ms)
+ toast.closed.connect(self._on_closed)
+ self._toasts.append(toast)
+ # Cap visible count.
+ while len(self._toasts) > self._max_visible:
+ old = self._toasts.pop(0)
+ old.close()
+ toast.show()
+ toast.adjustSize()
+ self._reflow()
+
+ # Convenience wrappers.
+ def info(self, text: str, timeout_ms: int = 4000) -> None:
+ self.post(text, "info", timeout_ms)
+
+ def ok(self, text: str, timeout_ms: int = 3500) -> None:
+ self.post(text, "ok", timeout_ms)
+
+ def warn(self, text: str, timeout_ms: int = 6500) -> None:
+ self.post(text, "warn", timeout_ms)
+
+ def error(self, text: str, timeout_ms: int = 9000) -> None:
+ self.post(text, "error", timeout_ms)
+
+ def _on_closed(self, toast: _Toast) -> None:
+ try:
+ self._toasts.remove(toast)
+ except ValueError:
+ pass
+ self._reflow()
+
+ def _reflow(self) -> None:
+ if not self._window.isVisible():
+ return
+ rect = self._window.rect()
+ x_right = rect.right() - self._margin
+ y = rect.top() + self._margin
+        # Drop below the menu bar, if present, so toasts don't cover it.
+ try:
+ mw = self._window
+ if isinstance(mw, QtWidgets.QMainWindow) and mw.menuBar() is not None:
+ y += mw.menuBar().height() + 4
+ except Exception:
+ pass
+ for toast in self._toasts:
+ toast.adjustSize()
+ w = toast.width()
+ toast.move(x_right - w, y)
+ toast.raise_()
+ y += toast.height() + self._spacing
+
+ def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool:
+ if obj is self._window and event.type() in (
+ QtCore.QEvent.Type.Resize,
+ QtCore.QEvent.Type.Move,
+ QtCore.QEvent.Type.WindowStateChange,
+ ):
+ self._reflow()
+ return False
+
+
+# ============================================================================
+# Tutorial overlay
+# ============================================================================
+
+
+class TutorialStep:
+ """One step of the guided tutorial."""
+
+ def __init__(
+ self,
+ title: str,
+ body: str,
+ target_resolver: Optional[Callable[[QtWidgets.QWidget], Optional[QtWidgets.QWidget]]] = None,
+ before: Optional[Callable[[QtWidgets.QWidget], None]] = None,
+ ):
+ self.title = title
+ self.body = body
+ self.target_resolver = target_resolver
+ self.before = before # called before showing the step (e.g. switch tab)
+
+
+class TutorialOverlay(QtWidgets.QWidget):
+ """
+ Full-window translucent overlay. A "spotlight" cutout illuminates the
+ target widget, and a styled callout near it shows step text and
+ Back / Next / Skip controls. Esc skips. Arrow keys page through.
+ """
+
+ finished = QtCore.Signal()
+
+ def __init__(self, host: QtWidgets.QMainWindow, steps: List[TutorialStep]):
+ super().__init__(host)
+ self._host = host
+ self._steps = list(steps)
+ self._index = 0
+ self.setAttribute(QtCore.Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
+ self.setAttribute(QtCore.Qt.WidgetAttribute.WA_StyledBackground, True)
+ self.setAttribute(QtCore.Qt.WidgetAttribute.WA_NoSystemBackground, True)
+ self.setFocusPolicy(QtCore.Qt.FocusPolicy.StrongFocus)
+ self.setMouseTracking(True)
+ self._target_rect = QtCore.QRect()
+ self._build_callout()
+ host.installEventFilter(self)
+
+ # --- internal UI ---
+
+ def _build_callout(self) -> None:
+ self._callout = QtWidgets.QFrame(self)
+ self._callout.setObjectName("tutorialCallout")
+ self._callout.setStyleSheet(
+ "QFrame#tutorialCallout { background: #14202f; color: #e9f0fb; "
+ "border: 1px solid #2f8cff; border-radius: 12px; }"
+ "QFrame#tutorialCallout QLabel { background: transparent; color: #e9f0fb; }"
+ "QFrame#tutorialCallout QLabel#tutTitle { font-weight: 700; font-size: 12pt; }"
+ "QFrame#tutorialCallout QLabel#tutStep { color: #95a5c2; font-size: 8.5pt; }"
+ "QFrame#tutorialCallout QPushButton { background: #1c2a3e; color: #e9f0fb; "
+ "border: 1px solid #355080; border-radius: 6px; padding: 5px 12px; }"
+ "QFrame#tutorialCallout QPushButton:hover { background: #233553; }"
+ "QFrame#tutorialCallout QPushButton#tutPrimary { background: #2f8cff; "
+ "border: 1px solid #46a0ff; color: white; font-weight: 700; }"
+ "QFrame#tutorialCallout QPushButton#tutPrimary:hover { background: #46a0ff; }"
+ "QFrame#tutorialCallout QPushButton#tutSkip { color: #95a5c2; border-color: transparent; }"
+ )
+ lay = QtWidgets.QVBoxLayout(self._callout)
+ lay.setContentsMargins(18, 16, 18, 14)
+ lay.setSpacing(8)
+
+ self._lbl_step = QtWidgets.QLabel("Step 1 / N")
+ self._lbl_step.setObjectName("tutStep")
+ lay.addWidget(self._lbl_step)
+ self._lbl_title = QtWidgets.QLabel("Title")
+ self._lbl_title.setObjectName("tutTitle")
+ self._lbl_title.setWordWrap(True)
+ lay.addWidget(self._lbl_title)
+ self._lbl_body = QtWidgets.QLabel("Body")
+ self._lbl_body.setWordWrap(True)
+ self._lbl_body.setMinimumWidth(340)
+ self._lbl_body.setMaximumWidth(420)
+ lay.addWidget(self._lbl_body)
+
+ btn_row = QtWidgets.QHBoxLayout()
+ btn_row.setSpacing(6)
+ self._btn_skip = QtWidgets.QPushButton("Skip tutorial")
+ self._btn_skip.setObjectName("tutSkip")
+ self._btn_skip.clicked.connect(self._end)
+ btn_row.addWidget(self._btn_skip)
+ btn_row.addStretch(1)
+ self._btn_back = QtWidgets.QPushButton("◀ Back")
+ self._btn_back.clicked.connect(self._prev)
+ btn_row.addWidget(self._btn_back)
+ self._btn_next = QtWidgets.QPushButton("Next ▶")
+ self._btn_next.setObjectName("tutPrimary")
+ self._btn_next.setDefault(True)
+ self._btn_next.clicked.connect(self._next)
+ btn_row.addWidget(self._btn_next)
+ lay.addLayout(btn_row)
+ self._callout.adjustSize()
+
+ # --- lifecycle ---
+
+ def start(self) -> None:
+ self.setGeometry(self._host.rect())
+ self.show()
+ self.raise_()
+ self.setFocus(QtCore.Qt.FocusReason.OtherFocusReason)
+ self._render_step()
+
+ def _end(self) -> None:
+ self._host.removeEventFilter(self)
+ self.finished.emit()
+ self.close()
+ self.deleteLater()
+
+ def _next(self) -> None:
+ if self._index >= len(self._steps) - 1:
+ self._end()
+ return
+ self._index += 1
+ self._render_step()
+
+ def _prev(self) -> None:
+ if self._index <= 0:
+ return
+ self._index -= 1
+ self._render_step()
+
+ # --- rendering ---
+
+ def _render_step(self) -> None:
+ step = self._steps[self._index]
+ if step.before is not None:
+ try:
+ step.before(self._host)
+ except Exception:
+ pass
+ n = len(self._steps)
+ self._lbl_step.setText(f"Step {self._index + 1} / {n}")
+ self._lbl_title.setText(step.title)
+ self._lbl_body.setText(step.body)
+ self._btn_back.setEnabled(self._index > 0)
+ self._btn_next.setText("Got it!" if self._index == n - 1 else "Next ▶")
+
+ target = None
+ if step.target_resolver is not None:
+ try:
+ target = step.target_resolver(self._host)
+ except Exception:
+ target = None
+ self._target_rect = self._compute_target_rect(target)
+ self._position_callout()
+ self.update()
+
+ def _compute_target_rect(self, target: Optional[QtWidgets.QWidget]) -> QtCore.QRect:
+ if target is None or not target.isVisible():
+ return QtCore.QRect()
+ # Map the target widget rect into the overlay's coordinate space.
+ top_left = target.mapTo(self._host, QtCore.QPoint(0, 0))
+ rect = QtCore.QRect(top_left, target.size())
+ rect = rect.adjusted(-6, -6, 6, 6)
+ return rect
+
+ def _position_callout(self) -> None:
+ margin = 16
+ host_rect = self.rect()
+ self._callout.adjustSize()
+ co_w = self._callout.width()
+ co_h = self._callout.height()
+ if self._target_rect.isEmpty():
+ x = (host_rect.width() - co_w) // 2
+ y = (host_rect.height() - co_h) // 2
+ self._callout.move(x, y)
+ return
+ # Try to place to the right of the target, fallback below, then left, then above.
+ candidates = [
+ QtCore.QPoint(self._target_rect.right() + margin, self._target_rect.top()),
+ QtCore.QPoint(self._target_rect.left(), self._target_rect.bottom() + margin),
+ QtCore.QPoint(self._target_rect.left() - co_w - margin, self._target_rect.top()),
+ QtCore.QPoint(self._target_rect.left(), self._target_rect.top() - co_h - margin),
+ ]
+ chosen = candidates[0]
+ for cand in candidates:
+ if (
+ cand.x() >= margin and cand.y() >= margin
+ and cand.x() + co_w + margin <= host_rect.width()
+ and cand.y() + co_h + margin <= host_rect.height()
+ ):
+ chosen = cand
+ break
+ chosen.setX(max(margin, min(host_rect.width() - co_w - margin, chosen.x())))
+ chosen.setY(max(margin, min(host_rect.height() - co_h - margin, chosen.y())))
+ self._callout.move(chosen)
+
+ def paintEvent(self, event: QtGui.QPaintEvent) -> None:
+ painter = QtGui.QPainter(self)
+ painter.setRenderHint(QtGui.QPainter.RenderHint.Antialiasing, True)
+ # Backdrop
+ painter.fillRect(self.rect(), QtGui.QColor(8, 12, 22, 190))
+ if not self._target_rect.isEmpty():
+ # Punch a soft "spotlight" on the target.
+ path = QtGui.QPainterPath()
+ path.addRect(QtCore.QRectF(self.rect()))
+ spot = QtGui.QPainterPath()
+ spot.addRoundedRect(QtCore.QRectF(self._target_rect), 8, 8)
+ path = path.subtracted(spot)
+ painter.fillPath(path, QtGui.QColor(8, 12, 22, 215))
+ # Outline the target.
+ pen = QtGui.QPen(QtGui.QColor("#2f8cff"))
+ pen.setWidth(2)
+ painter.setPen(pen)
+ painter.setBrush(QtCore.Qt.BrushStyle.NoBrush)
+ painter.drawRoundedRect(self._target_rect, 8, 8)
+ # Optional: dashed glow inside.
+ glow = QtGui.QPen(QtGui.QColor(47, 140, 255, 120))
+ glow.setWidth(1)
+ glow.setStyle(QtCore.Qt.PenStyle.DashLine)
+ painter.setPen(glow)
+ painter.drawRoundedRect(self._target_rect.adjusted(3, 3, -3, -3), 6, 6)
+
+ def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
+ key = event.key()
+ if key in (QtCore.Qt.Key.Key_Escape,):
+ self._end()
+ elif key in (QtCore.Qt.Key.Key_Right, QtCore.Qt.Key.Key_PageDown, QtCore.Qt.Key.Key_Space, QtCore.Qt.Key.Key_Return, QtCore.Qt.Key.Key_Enter):
+ self._next()
+ elif key in (QtCore.Qt.Key.Key_Left, QtCore.Qt.Key.Key_PageUp, QtCore.Qt.Key.Key_Backspace):
+ self._prev()
+ else:
+ super().keyPressEvent(event)
+
+ def eventFilter(self, obj: QtCore.QObject, event: QtCore.QEvent) -> bool:
+ if obj is self._host and event.type() in (
+ QtCore.QEvent.Type.Resize,
+ QtCore.QEvent.Type.Move,
+ ):
+ self.setGeometry(self._host.rect())
+ self._render_step()
+ return False
+
+
+def build_default_tutorial(window: QtWidgets.QMainWindow) -> List[TutorialStep]:
+ """Default first-run tutorial covering the main workflow regions."""
+
+ def _switch_pre(_w: QtWidgets.QMainWindow) -> None:
+ try:
+ window.tabs.setCurrentIndex(0)
+ except Exception:
+ pass
+
+ def _switch_post(_w: QtWidgets.QMainWindow) -> None:
+ try:
+ window.tabs.setCurrentIndex(1)
+ except Exception:
+ pass
+
+ def _resolve_tabs(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]:
+ return getattr(window, "tabs", None)
+
+ def _resolve_pre_files(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]:
+ return getattr(window, "file_panel", None)
+
+ def _resolve_status(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]:
+ return getattr(window, "_status_bar", None)
+
+ def _resolve_post(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]:
+ return getattr(window, "post_tab", None)
+
+ def _resolve_temporal(_w: QtWidgets.QMainWindow) -> Optional[QtWidgets.QWidget]:
+ post = getattr(window, "post_tab", None)
+ if post is None:
+ return None
+ return getattr(post, "section_temporal", None)
+
+ return [
+ TutorialStep(
+ "Welcome to pyBer",
+ "pyBer takes raw fiber-photometry recordings (Doric / CSV) all the way to "
+ "PSTH, behavior alignment, and GLM/FLMM analysis.\n\n"
+ "Use ◀ ▶ or arrow keys to step through. Press F1 anytime to reopen this tour.",
+ target_resolver=None,
+ ),
+ TutorialStep(
+ "Two main tabs: Preprocessing -> Postprocessing",
+ "Preprocessing handles raw signal cleanup, artifact removal and export to "
+ "processed traces. Postprocessing consumes those traces for PSTH, peak/event "
+ "metrics, behavior alignment, and modeling. The flow is left to right.",
+ target_resolver=_resolve_tabs,
+ before=_switch_pre,
+ ),
+ TutorialStep(
+ "Step 1 - Drop your files here",
+ "Drag .doric, .csv or .h5 files (or a whole folder) onto the file queue. "
+ "Use Ctrl+O to browse, Ctrl+Shift+O for a folder, Delete to remove a selection.",
+ target_resolver=_resolve_pre_files,
+ before=_switch_pre,
+ ),
+ TutorialStep(
+ "Step 2 - Run quality check + export",
+ "Ctrl+Q runs QC on the active recording, Ctrl+Shift+Q does a batch QC, "
+ "Ctrl+E exports the current selection. Ctrl+Z / Ctrl+Y undo/redo "
+ "preprocessing actions.",
+ target_resolver=_resolve_status,
+ before=_switch_pre,
+ ),
+ TutorialStep(
+ "Step 3 - Switch to Postprocessing",
+ "Once you've exported processed traces, hop to the Postprocessing tab. "
+ "Drag the .csv/.h5 outputs onto its file dropzone, or open a saved project.",
+ target_resolver=_resolve_post,
+ before=_switch_post,
+ ),
+ TutorialStep(
+ "Active file vs Group",
+ "Postprocessing shows a single recording in 'Individual' mode or a group of "
+ "animals in 'Group' mode. Use Ctrl+G to toggle. Ctrl+Left / Ctrl+Right step "
+ "through loaded animals.",
+ target_resolver=_resolve_post,
+ before=_switch_post,
+ ),
+ TutorialStep(
+ "Temporal Modeling",
+ "The 'T' panel fits a Continuous GLM or trial-level FLMM. Choose a Scope:\n"
+ "- Active file (single)\n"
+ "- All loaded (concatenated)\n"
+ "- Per-file batch + group\n\n"
+ "Press Ctrl+Shift+F to fit. The Group tab aggregates per-file kernels "
+ "with mean +/- SEM across animals.",
+ target_resolver=_resolve_temporal,
+ before=_switch_post,
+ ),
+ TutorialStep(
+ "Keyboard shortcuts",
+ "Press Ctrl+/ anytime for the full cheat sheet. Highlights:\n"
+ "- F1 Help / replay tour Ctrl+, Preferences\n"
+ "- Ctrl+S Save project Ctrl+Shift+S Save as\n"
+ "- Ctrl+1 Preprocessing Ctrl+2 Postprocessing\n"
+ "- Ctrl+0 Reset focused plot view\n"
+ "- Esc Cancel current operation",
+ target_resolver=None,
+ ),
+ TutorialStep(
+ "You're set",
+ "That's the whirlwind tour. Tooltips on every input fill in the rest. "
+            "If something fails, a toast in the corner shows what happened - click "
+ "it to dismiss.\n\nHappy analyzing.",
+ target_resolver=None,
+ ),
+ ]
+
+
+# ============================================================================
+# Preferences dialog
+# ============================================================================
+
+
+class PreferencesDialog(QtWidgets.QDialog):
+ """
+ Compact preferences dialog. Reads/writes via QSettings so the values
+ survive restarts. Apply emits no signal; consumers (theme button, etc.)
+ pick up changes the next time they read the corresponding key.
+ """
+
+ KEYS = {
+ "theme": "app/theme", # "dark" | "light"
+ "autosave": "app/autosave_enabled", # bool
+ "autosave_min": "app/autosave_minutes", # int
+ "kernel_pre": "temporal_modeling/kernel_pre",
+ "kernel_post": "temporal_modeling/kernel_post",
+ "show_tutorial": "onboarding/show_on_startup",
+ "toast_timeout": "ui/toast_timeout_ms",
+ }
+
+ def __init__(self, parent: QtWidgets.QWidget, settings: QtCore.QSettings):
+ super().__init__(parent)
+ self.setWindowTitle("Preferences")
+ self.setModal(True)
+ self._settings = settings
+ self.resize(540, 420)
+
+ tabs = QtWidgets.QTabWidget(self)
+
+ # ---- Appearance ----
+ appearance = QtWidgets.QWidget()
+ a = QtWidgets.QFormLayout(appearance)
+ a.setContentsMargins(16, 16, 16, 16)
+ self.combo_theme = QtWidgets.QComboBox()
+ self.combo_theme.addItem("Dark", "dark")
+ self.combo_theme.addItem("Light", "light")
+ a.addRow("Theme", self.combo_theme)
+ self.spin_toast = QtWidgets.QSpinBox()
+ self.spin_toast.setRange(1500, 30000)
+ self.spin_toast.setSingleStep(500)
+ self.spin_toast.setSuffix(" ms")
+ a.addRow("Toast default duration", self.spin_toast)
+ self.chk_show_tutorial = QtWidgets.QCheckBox(
+ "Show first-run tutorial on next launch"
+ )
+ a.addRow("Onboarding", self.chk_show_tutorial)
+ tabs.addTab(appearance, "Appearance")
+
+ # ---- Defaults ----
+ defaults = QtWidgets.QWidget()
+ d = QtWidgets.QFormLayout(defaults)
+ d.setContentsMargins(16, 16, 16, 16)
+ self.spin_kernel_pre = QtWidgets.QDoubleSpinBox()
+ self.spin_kernel_pre.setRange(-30.0, 0.0)
+ self.spin_kernel_pre.setDecimals(2)
+ self.spin_kernel_pre.setSuffix(" s")
+ d.addRow("Default GLM kernel pre", self.spin_kernel_pre)
+ self.spin_kernel_post = QtWidgets.QDoubleSpinBox()
+ self.spin_kernel_post.setRange(0.1, 60.0)
+ self.spin_kernel_post.setDecimals(2)
+ self.spin_kernel_post.setSuffix(" s")
+ d.addRow("Default GLM kernel post", self.spin_kernel_post)
+ self.chk_autosave = QtWidgets.QCheckBox("Enable project autosave")
+ d.addRow("Autosave", self.chk_autosave)
+ self.spin_autosave = QtWidgets.QSpinBox()
+ self.spin_autosave.setRange(1, 60)
+ self.spin_autosave.setSuffix(" min")
+ d.addRow("Autosave interval", self.spin_autosave)
+ tabs.addTab(defaults, "Defaults")
+
+ # ---- Keyboard cheat sheet ----
+ keyboard = QtWidgets.QWidget()
+ k = QtWidgets.QVBoxLayout(keyboard)
+ k.setContentsMargins(16, 16, 16, 16)
+ cheat = QtWidgets.QTextBrowser()
+ cheat.setReadOnly(True)
+ cheat.setOpenExternalLinks(False)
+ cheat.setHtml(_keyboard_cheatsheet_html())
+ k.addWidget(cheat, 1)
+ tabs.addTab(keyboard, "Keyboard")
+
+ # ---- About ----
+ about = QtWidgets.QWidget()
+ ab = QtWidgets.QVBoxLayout(about)
+ ab.setContentsMargins(16, 16, 16, 16)
+ about_lbl = QtWidgets.QLabel(
+            "<p>Pipeline for raw photometry preprocessing, PSTH, behavior alignment, "
+            "and GLM/FLMM modeling.<br>"
+            "Bellone Lab toolkit.</p>"
+        )
+        about_lbl.setWordWrap(True)
+        ab.addWidget(about_lbl)
+        ab.addStretch(1)
+        tabs.addTab(about, "About")
+
+        # ---- Buttons ----
+        button_box = QtWidgets.QDialogButtonBox(
+            QtWidgets.QDialogButtonBox.StandardButton.Ok
+            | QtWidgets.QDialogButtonBox.StandardButton.Cancel
+            | QtWidgets.QDialogButtonBox.StandardButton.Apply
+        )
+        button_box.accepted.connect(self._on_accept)
+        button_box.rejected.connect(self.reject)
+        button_box.button(QtWidgets.QDialogButtonBox.StandardButton.Apply).clicked.connect(self._apply)
+
+        layout = QtWidgets.QVBoxLayout(self)
+        layout.setContentsMargins(10, 10, 10, 10)
+        layout.addWidget(tabs)
+        layout.addWidget(button_box)
+
+        self._load()
+
+    # --- internals ---
+
+    def _load(self) -> None:
+        s = self._settings
+        theme = str(s.value(self.KEYS["theme"], "dark") or "dark").lower()
+        idx = self.combo_theme.findData(theme)
+        if idx >= 0:
+            self.combo_theme.setCurrentIndex(idx)
+        self.spin_toast.setValue(int(s.value(self.KEYS["toast_timeout"], 5000) or 5000))
+        self.chk_show_tutorial.setChecked(_to_bool(s.value(self.KEYS["show_tutorial"], True), True))
+        self.spin_kernel_pre.setValue(float(s.value(self.KEYS["kernel_pre"], -1.0) or -1.0))
+        self.spin_kernel_post.setValue(float(s.value(self.KEYS["kernel_post"], 3.0) or 3.0))
+        self.chk_autosave.setChecked(_to_bool(s.value(self.KEYS["autosave"], True), True))
+        self.spin_autosave.setValue(int(s.value(self.KEYS["autosave_min"], 5) or 5))
+
+    def _apply(self) -> None:
+        s = self._settings
+        s.setValue(self.KEYS["theme"], self.combo_theme.currentData())
+        s.setValue(self.KEYS["toast_timeout"], int(self.spin_toast.value()))
+        s.setValue(self.KEYS["show_tutorial"], bool(self.chk_show_tutorial.isChecked()))
+        s.setValue(self.KEYS["kernel_pre"], float(self.spin_kernel_pre.value()))
+        s.setValue(self.KEYS["kernel_post"], float(self.spin_kernel_post.value()))
+        s.setValue(self.KEYS["autosave"], bool(self.chk_autosave.isChecked()))
+        s.setValue(self.KEYS["autosave_min"], int(self.spin_autosave.value()))
+
+    def _on_accept(self) -> None:
+        self._apply()
+        self.accept()
+
+
+def _to_bool(value: object, default: bool = False) -> bool:
+    if isinstance(value, bool):
+        return value
+    text = str(value or "").strip().lower()
+    if text in {"1", "true", "yes", "on"}:
+        return True
+    if text in {"0", "false", "no", "off"}:
+        return False
+    return default
+
+
+# ============================================================================
+# Keyboard shortcuts
+# ============================================================================
+
+
+_GLOBAL_SHORTCUTS: List[Tuple[str, str, str]] = [
+    # (sequence, callback_attr_or_method, description)
+    ("F1", "_show_tutorial_again", "Show / replay the tutorial"),
+    ("Ctrl+/", "_show_keyboard_cheatsheet", "Open the keyboard cheat sheet"),
+    ("Ctrl+,", "_open_preferences", "Open Preferences"),
+    ("Ctrl+1", "_focus_pre_tab", "Switch to Preprocessing tab"),
+    ("Ctrl+2", "_focus_post_tab", "Switch to Postprocessing tab"),
+    ("Ctrl+Tab", "_cycle_main_tab", "Cycle main tabs"),
+    ("Ctrl+0", "_reset_focused_plot_view", "Reset focused plot view"),
+    ("Ctrl+Right", "_step_active_file_next", "Next loaded file"),
+    ("Ctrl+Left", "_step_active_file_prev", "Previous loaded file"),
+    ("Ctrl+G", "_toggle_individual_group", "Toggle Individual / Group"),
+    ("Ctrl+Shift+F", "_fit_temporal_model", "Fit temporal model"),
+    ("Ctrl+Shift+B", "_fit_temporal_all_files", "Fit GLM on every file (batch)"),
+    ("F5", "_recompute_psth", "Recompute PSTH"),
+    ("Ctrl+Shift+E", "_run_postprocess_export", "Run Postprocessing export"),
+    ("Esc", "_cancel_current_operation", "Cancel current operation"),
+]
+
+
+def register_global_shortcuts(window: QtWidgets.QMainWindow) -> Dict[str, str]:
+    """
+    Bind every shortcut in `_GLOBAL_SHORTCUTS` to `window` with application-wide
+    context. Resolves the callback by attribute name; if the host doesn't define
+    the attribute, the shortcut becomes a no-op (so partial wiring is safe).
+
+    Returns a {sequence: description} dict so the cheat sheet can stay accurate.
+    """
+    descriptions: Dict[str, str] = {}
+    for seq, attr, desc in _GLOBAL_SHORTCUTS:
+        descriptions[seq] = desc
+        sc = QtGui.QShortcut(QtGui.QKeySequence(seq), window)
+        sc.setContext(QtCore.Qt.ShortcutContext.ApplicationShortcut)
+
+        def _make_handler(attr_name: str) -> Callable[[], None]:
+            def _handler() -> None:
+                # Don't fire when typing in a line edit / spin box.
+                w = QtWidgets.QApplication.focusWidget()
+                if isinstance(w, (QtWidgets.QLineEdit, QtWidgets.QPlainTextEdit, QtWidgets.QTextEdit)):
+                    return
+                # QSpinBox / QDoubleSpinBox: allow Esc, Ctrl+Right, etc. to pass.
+                fn = getattr(window, attr_name, None)
+                if not callable(fn):
+                    return
+                try:
+                    fn()
+                except Exception:
+                    pass
+
+            return _handler
+
+        sc.activated.connect(_make_handler(attr))
+        # Keep a reference so it isn't GC'd.
+        if not hasattr(window, "_pyber_global_shortcuts"):
+            window._pyber_global_shortcuts = []
+        window._pyber_global_shortcuts.append(sc)
+    return descriptions
+
+
+def _keyboard_cheatsheet_html() -> str:
+    rows = [
+        ("Application", []),
+        (None, [
+            ("F1", "Help / replay tutorial"),
+            ("Ctrl+/", "Open this cheat sheet"),
+            ("Ctrl+,", "Preferences"),
+            ("Ctrl+1 / Ctrl+2", "Switch to Preprocessing / Postprocessing"),
+            ("Ctrl+Tab", "Cycle main tabs"),
+            ("Esc", "Cancel current operation"),
+        ]),
+        ("Files / project", []),
+        (None, [
+            ("Ctrl+O", "Open files (Preprocessing)"),
+            ("Ctrl+Shift+O", "Open folder (Preprocessing)"),
+            ("Ctrl+S", "Save project / config"),
+            ("Ctrl+Shift+S", "Save as"),
+            ("Ctrl+L", "Load preprocessing config"),
+            ("Delete", "Remove selected files"),
+            ("Ctrl+Right / Ctrl+Left", "Next / previous loaded file"),
+            ("Ctrl+G", "Toggle Individual / Group"),
+        ]),
+        ("Preprocessing", []),
+        (None, [
+            ("Ctrl+Q", "Run QC on active file"),
+            ("Ctrl+Shift+Q", "Batch QC"),
+            ("Ctrl+E", "Export current selection"),
+            ("Ctrl+K", "Toggle Artifacts panel"),
+            ("Ctrl+F", "Toggle Filtering panel"),
+            ("Ctrl+B", "Toggle Baseline panel"),
+            ("Ctrl+M", "Toggle Output panel"),
+            ("Ctrl+D", "Toggle Data panel"),
+            ("Ctrl+P", "Toggle parameter popups"),
+            ("Ctrl+Enter", "Trigger preview"),
+            ("A / C / S", "Assign pending box -> Artifact / Cut / Section"),
+        ]),
+        ("Postprocessing / Modeling", []),
+        (None, [
+            ("F5", "Recompute PSTH"),
+            ("Ctrl+Shift+E", "Run postprocessing export"),
+            ("Ctrl+Shift+F", "Fit temporal model (current scope)"),
+            ("Ctrl+Shift+B", "Fit GLM on every file (batch)"),
+            ("Ctrl+0", "Reset focused plot view"),
+        ]),
+    ]
+    parts = [""]
+    for header, items in rows:
+        if header is not None:
+            parts.append(f"<h4>{header}</h4>")
+        for key_html, desc in items:
+            parts.append(f"<tr><td>{key_html}</td><td>{desc}</td></tr>")
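`PreferencesDialog._load` funnels every boolean read through `_to_bool` because `QSettings` can hand booleans back as the strings `"true"`/`"false"` depending on the storage backend (the INI format stores them as text). A standalone sketch of that coercion, mirroring the patch's `_to_bool` (renamed `to_bool` here to stand alone):

```python
# Standalone mirror of the patch's `_to_bool`: normalize QSettings values
# that may arrive as real bools, numeric strings, or "true"/"false" text.

def to_bool(value: object, default: bool = False) -> bool:
    if isinstance(value, bool):
        return value  # already a real bool, pass through
    text = str(value or "").strip().lower()  # None -> "" -> default
    if text in {"1", "true", "yes", "on"}:
        return True
    if text in {"0", "false", "no", "off"}:
        return False
    return default  # unrecognized text falls back to the caller's default


print(to_bool("true"), to_bool("0"), to_bool(None, True), to_bool("maybe", True))
# True False True True
```

The `default` argument matters: an absent or garbled setting degrades to the dialog's shipped default instead of silently flipping a feature off.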