A machine learning-based system for probabilistic position reconstruction using normalizing flows to model particle positions in physics experiments.
The module automatically:
- Loads the specified configuration
- Loads data with an extensible data loader system
- Preprocesses the data for training
- Constructs and trains a position reconstruction model
- Saves the trained model to the specified output directory
```shell
pip install -e .
```

The module can be run directly from the command line:
```shell
python -m probabilistic_posrec --config config_file.toml --data-dir /path/to/data --output-name-suffix test_run
```

The primary way to use this module is to load a configuration file. This fits the workflow of large experiments, where configurations are expected to be stable and shareable. An example configuration file can be found in `config.toml` in the root directory of this repository.
However, individual settings from the configuration file can be overridden with the following command line arguments:
- `--config`: Path to the TOML configuration file (required)
- `--loader-type`: Type of data loader to use: `h5` or `simulation`
- `--data-dir`: Directories containing data files (comma-separated)
- `--file-pattern`: Pattern to match data files
- `--random-sample`: Load a random sample of files instead of all files
- `--sample-size`: Number of files to randomly sample
- `--exclude-runs`: Comma-separated list of run numbers to exclude
- `--output-dir`: Directory to save trained models
- `--randomize-off-pmt`: Randomly turn off PMTs
- `--output-name-suffix`: Suffix string for output file names
- `--pmt-positions-file`: Path to CSV file with PMT positions (for the simulation loader)
The system supports two types of data loaders:
`h5`: Loads position reconstruction data from HDF5 files. This is the default loader and is suitable for simulated data produced by more complex simulations.
`simulation`: Generates simulated S2 patterns using a simplified light collection efficiency model. This is useful for development, testing, and evaluation.
To use the simulation loader:
```shell
python -m probabilistic_posrec --config config_file.toml --loader-type simulation
```

The simulation loader requires a CSV file with PMT positions, which can be specified in the configuration file or via the `--pmt-positions-file` command line argument.
The system supports two transformation types for unbounding the bounded position coordinates (x, y) to unbounded latent space for the normalizing flow:
`tanh`: Uses `arctanh` to map the bounded position coordinates to unbounded space. This is the original implementation and is suitable for most use cases.
```toml
[model]
transform_type = "tanh"
```

`unit_ball`: Uses a chi-squared-based bijection (`StandardNormalToUnitBall`) that maps a standard normal distribution to a uniform distribution on the unit ball. This transformation:
- Computes the chi-squared CDF of the squared magnitude ||x||²
- Transforms to the correct radial distribution via the power transformation
- Provides a probabilistically rigorous mapping between Gaussian and unit ball geometries
This transform requires the `numerax` package for computing the inverse incomplete gamma function (`gammap_inverse`), which is used in the inverse chi-squared CDF calculation.
```toml
[model]
transform_type = "unit_ball"
```

Note: with multi-variable inference (`n_dims > 2`), only the first two dimensions (the positions) are transformed. Additional target variables are passed through unchanged, as they are assumed to already be unbounded.
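For intuition, the forward direction of this mapping can be sketched in plain NumPy for the two-dimensional case, where the chi-squared CDF has the closed form F(s) = 1 - exp(-s/2). This is an illustrative sketch, not the package's `StandardNormalToUnitBall` implementation, and the function name is hypothetical:

```python
import numpy as np

def gaussian_to_unit_disk(x: np.ndarray) -> np.ndarray:
    """Illustrative sketch: map a standard-normal 2D point into the unit disk.

    For x ~ N(0, I_2), s = ||x||^2 follows a chi-squared distribution with
    2 degrees of freedom, whose CDF is F(s) = 1 - exp(-s / 2). F(s) is then
    Uniform(0, 1), and the power transformation r = sqrt(F(s)) produces the
    radial law of a uniform distribution on the unit disk, P(R <= r) = r^2.
    """
    s = np.sum(x**2)            # squared magnitude ||x||^2
    u = 1.0 - np.exp(-s / 2.0)  # chi-squared(2) CDF: uniform in (0, 1)
    r = np.sqrt(u)              # power transform to the disk's radial law
    return x / np.sqrt(s) * r   # keep the direction, rescale the radius

# Every Gaussian sample lands strictly inside the unit disk:
y = gaussian_to_unit_disk(np.array([3.0, -4.0]))
assert np.linalg.norm(y) < 1.0
```

The inverse direction, going from the unit ball back to the Gaussian, needs the inverse chi-squared CDF; for general dimensionality that has no closed form, which is why `numerax`'s `gammap_inverse` is required.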
The Loader class is designed to be extended for custom data loading. Here's an example of how to create a custom loader:
```python
from probabilistic_posrec import Loader
from typing import Tuple
from jaxtyping import ArrayLike


class CustomLoader(Loader):
    def load(self) -> None:
        # Load data from your source
        pass

    def _get_hitpatterns_and_positions(self) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
        # Return (hitpatterns, x_positions, y_positions)
        pass

    def _prepare_hitpatterns(self, hitpatterns: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        # Return (processed_hitpatterns, areas)
        pass

    def preprocess(self) -> None:
        # Use the common preprocessing workflow
        self.preprocess_common()

    def get_data_array(self) -> Tuple[ArrayLike, ArrayLike]:
        # Return (train_x, train_cond)
        return self.train_x, self.train_cond
```

Configuration is managed via TOML files with several sections:
```toml
[model]
# Normalizing flow parameters
flow_layers = 5
nn_width = 128
nn_depth = 3
activation = "leaky_relu"
spline_knots = 5
spline_interval = 5

# Data transformation parameters
r_max = 66.4
buffer = 20.0
scale = 1.0
eps = 1e-7
log_area_scale = 10.0

# Transform type: "tanh" (default) or "unit_ball"
transform_type = "tanh"

[training]
# Learning rate and schedule
epoch_multiplier = 15
high_lr = 6e-04
wd = 8e-04
lr_multiplier = 0.15

# Training parameters
test_train_ratio = 0.0
multiscatter_frac = 0.1
noise_frac = 0.0
batch_size = 256
max_patience = 30
val_prop = 0.2
random_seed = 42
enable_sharding = true  # Enable data-parallel sharding across multiple devices

# Loader type: 'h5' or 'simulation'
loader_type = "h5"

[data]
# Data loading configuration (for h5 loader)
data_dirs = ["data"]
file_pattern = "*.hdf5"
random_sample = false
sample_size = 50
exclude_runs = []

# PMT configuration
disabled_pmts = []
randomize_off_pmt = false
randomize_off_pmt_max = 50

[output]
output_dir = "models"
output_name_suffix = ""

[simulation]
# Simulation parameters (for simulation loader)
n_patterns = 1000000
tpc_radius = 66.4
secondary_yield = 350.0
spe_resolution = 0.05
lce_params = [0.01179, 2.39099, 10.3367, -6.77e-7, 9.86e-5]
pmt_positions_file = "PMT_POSITIONS.csv"
electron_min = 1
electron_max = 2000
```

This module supports data-parallel training across multiple devices (GPUs or TPUs). When enabled, the model parameters are replicated across all devices and the training data is sharded across them; each device processes its portion of the data in parallel, reducing training time.
To enable multi-device training, set the `enable_sharding` parameter to true in your configuration file:

```toml
[training]
# Other training parameters...
enable_sharding = true
```

Data-parallel sharding is automatically enabled when multiple devices are available and the `enable_sharding` flag is set to true. The number of devices is automatically detected using `jax.devices()`.
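As a rough sketch of the mechanism, the batch placement described above can be expressed with JAX's `jax.sharding` API; the array shapes and names here are illustrative, not the module's actual training code:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D device mesh over whatever devices jax.devices() reports
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Split the leading (batch) axis of the data across the mesh;
# model parameters would instead be replicated (PartitionSpec())
data_sharding = NamedSharding(mesh, P("batch"))

batch = jnp.zeros((256, 127))                 # e.g. 256 hit patterns, 127 PMTs
batch = jax.device_put(batch, data_sharding)  # one shard per device
```

With a single device this degenerates to ordinary placement; with N devices each one holds 256/N rows, so the batch size must be divisible by the device count.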
- Modular architecture with clear separation of concerns
- Extensible data loader system with a base `Loader` class
- Support for both real (HDF5) and simulated data
- Normalizing flow models for probabilistic position reconstruction
- Support for multiscatter generation and PMT disabling
- Comprehensive configuration system with TOML and CLI support
- Data-parallel sharding across multiple devices for accelerated training