Probabilistic Position Reconstruction

A machine-learning system for probabilistic position reconstruction in physics experiments, using normalizing flows to model particle positions.

The module automatically:

  1. Loads the specified configuration
  2. Loads data with an extensible data loader system
  3. Preprocesses the data for training
  4. Constructs and trains a position reconstruction model
  5. Saves the trained model to the specified output directory

Installation

pip install -e .

Usage

Command Line Usage

The module can be run directly from the command line:

python -m probabilistic_posrec --config config_file.toml --data-dir /path/to/data --output-name-suffix test_run

The primary way to use this module is through a configuration file. This fits the workflow of large experiments, where configurations are expected to be stable and shareable. An example configuration file, config.toml, can be found in the root directory of this repository.

Individual configuration values can, however, be overridden with the following command line arguments (a combined example follows the list):

--config                Path to the TOML configuration file (required)
--loader-type           Type of data loader to use: h5 or simulation
--data-dir              Directories containing data files (comma-separated)
--file-pattern          Pattern to match data files
--random-sample         Load a random sample of files instead of all files
--sample-size           Number of files to randomly sample
--exclude-runs          Comma-separated list of run numbers to exclude
--output-dir            Directory to save trained models
--randomize-off-pmt     Randomly turn off PMTs
--output-name-suffix    Suffix string for output file names
--pmt-positions-file    Path to CSV file with PMT positions (for simulation loader)
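
For example, several settings can be overridden in a single invocation (paths and values below are illustrative, and --random-sample is assumed to be a simple boolean flag):

python -m probabilistic_posrec --config config_file.toml \
    --loader-type h5 \
    --data-dir /path/to/data1,/path/to/data2 \
    --random-sample --sample-size 20 \
    --output-dir models \
    --output-name-suffix quick_test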

Data Loaders

The system supports two types of data loaders:

1. HDF5 File Loader (default)

Loads position reconstruction data from HDF5 files. This is the default loader and is suited to data produced by more complex simulations.

2. Simulation Loader

Generates simulated S2 patterns using a simplified light collection efficiency model. This is useful for development, testing, and evaluation.

To use the simulation loader:

python -m probabilistic_posrec --config config_file.toml --loader-type simulation

The simulation loader requires a CSV file with PMT positions, which can be specified in the configuration file or via the --pmt-positions-file command line argument.

Coordinate Transformations

The system supports two transformation types for mapping the bounded position coordinates (x, y) to the unbounded latent space of the normalizing flow:

1. Hyperbolic Tangent Transform (tanh, default)

Uses arctanh to map bounded position coordinates to unbounded space. This is the original implementation and is suitable for most use cases.

[model]
transform_type = "tanh"
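
A minimal sketch of this style of mapping is shown below, assuming the coordinates are simply scaled into (-1, 1) by r_max + buffer before arctanh is applied; the module's exact use of scale and eps may differ.

import jax.numpy as jnp

def unbound_positions(xy, r_max=66.4, buffer=20.0, eps=1e-7):
    # Bounded position -> unbounded latent; eps keeps arctanh away from +/-1.
    scaled = jnp.clip(xy / (r_max + buffer), -1.0 + eps, 1.0 - eps)
    return jnp.arctanh(scaled)

def rebound_positions(z, r_max=66.4, buffer=20.0):
    # Inverse map: unbounded latent -> bounded position.
    return jnp.tanh(z) * (r_max + buffer)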

2. Unit Ball Transform (unit_ball)

Uses a chi-squared-based bijection (StandardNormalToUnitBall) that maps a standard normal distribution to a uniform distribution on the unit ball. This transformation:

  • Computes the chi-squared CDF of the squared magnitude ||x||²
  • Transforms to the correct radial distribution via the power transformation
  • Provides a probabilistically rigorous mapping between Gaussian and unit ball geometries

This transform requires the numerax package for computing the inverse incomplete gamma function (gammap_inverse), which is used in the inverse chi-squared CDF calculation.

[model]
transform_type = "unit_ball"
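
As an illustration, in two dimensions the chi-squared CDF has the closed form 1 - exp(-s/2), so the forward and inverse maps can be sketched without the incomplete gamma function. This is only a sketch of the core bijection (it ignores eps handling and is not robust at the origin); the actual implementation uses numerax for the general case.

import jax.numpy as jnp

def gaussian_to_unit_disk(z):
    # z: (..., 2) standard-normal samples.
    s = jnp.sum(z**2, axis=-1, keepdims=True)    # ||z||^2 follows a chi-squared law with 2 dof
    u = 1.0 - jnp.exp(-0.5 * s)                  # chi-squared(2) CDF -> Uniform(0, 1)
    r = jnp.sqrt(u)                              # power transform: uniform radial law on the disk
    return r * z / jnp.sqrt(s)                   # rescale the direction to the new radius

def unit_disk_to_gaussian(x):
    # Inverse: point inside the unit disk -> standard-normal space.
    r2 = jnp.sum(x**2, axis=-1, keepdims=True)   # squared radius, equals u above
    s = -2.0 * jnp.log1p(-r2)                    # inverse chi-squared(2) CDF
    return jnp.sqrt(s / r2) * x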

Note: For multi-variable inference (n_dims > 2), only the first two dimensions (the positions) are transformed. Additional target variables are passed through unchanged, as they are assumed to already be unbounded.

Creating a Custom Data Loader

The Loader class is designed to be extended for custom data loading. Here's an example of how to create a custom loader:

from probabilistic_posrec import Loader
from typing import Tuple
from jaxtyping import ArrayLike

class CustomLoader(Loader):
    def load(self) -> None:
        # Load data from your source
        pass
        
    def _get_hitpatterns_and_positions(self) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
        # Return (hitpatterns, x_positions, y_positions)
        pass
        
    def _prepare_hitpatterns(self, hitpatterns: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        # Return (processed_hitpatterns, areas)
        pass
        
    def preprocess(self) -> None:
        # Use common preprocessing workflow
        self.preprocess_common()
        
    def get_data_array(self) -> Tuple[ArrayLike, ArrayLike]:
        # Return (train_x, train_cond)
        return self.train_x, self.train_cond

Configuration

Configuration is managed via TOML files with several sections:

Model Configuration

[model]
# Normalizing flow parameters
flow_layers = 5
nn_width = 128
nn_depth = 3
activation = "leaky_relu"
spline_knots = 5
spline_interval = 5

# Data transformation parameters
r_max = 66.4
buffer = 20.0
scale = 1.0
eps = 1e-7
log_area_scale = 10.0

# Transform type: "tanh" (default) or "unit_ball"
transform_type = "tanh"

Training Configuration

[training]
# Learning rate and schedule
epoch_multiplier = 15
high_lr = 6e-04
wd = 8e-04
lr_multiplier = 0.15

# Training parameters
test_train_ratio = 0.0
multiscatter_frac = 0.1
noise_frac = 0.0
batch_size = 256
max_patience = 30
val_prop = 0.2
random_seed = 42
enable_sharding = true  # Enable data-parallel sharding across multiple devices

# Loader type: 'h5' or 'simulation'
loader_type = "h5"

Data Configuration

[data]
# Data loading configuration (for h5 loader)
data_dirs = ["data"]
file_pattern = "*.hdf5"
random_sample = false
sample_size = 50
exclude_runs = []

# PMT configuration
disabled_pmts = []
randomize_off_pmt = false
randomize_off_pmt_max = 50

Output Configuration

[output]
output_dir = "models"
output_name_suffix = ""

Simulation Configuration

[simulation]
# Simulation parameters (for simulation loader)
n_patterns = 1000000
tpc_radius = 66.4
secondary_yield = 350.0
spe_resolution = 0.05
lce_params = [0.01179, 2.39099, 10.3367, -6.77e-7, 9.86e-5]
pmt_positions_file = "PMT_POSITIONS.csv"
electron_min = 1
electron_max = 2000

Multi-Device Training with Data Parallelism

This module supports data-parallel training across multiple devices (GPUs or TPUs) for accelerated training. When enabled, the model parameters are replicated across all devices, and the training data is sharded across devices. Each device processes a portion of the data in parallel, resulting in faster training times.

To enable multi-device training, set the enable_sharding parameter to true in your configuration file:

[training]
# Other training parameters...
enable_sharding = true

Data-parallel sharding is automatically enabled when multiple devices are available and the enable_sharding flag is set to true. The number of devices is automatically detected using jax.devices().
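
A minimal sketch of how a batch might be laid out across devices under this scheme, assuming a simple one-dimensional data-parallel mesh (shapes are illustrative, and the module's actual sharding setup may differ):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()                                  # all visible GPUs/TPUs (or CPU)
mesh = Mesh(np.array(devices), axis_names=("data",))     # 1-D data-parallel mesh

batch_sharding = NamedSharding(mesh, PartitionSpec("data"))  # split the leading (batch) axis
replicated = NamedSharding(mesh, PartitionSpec())            # model parameters go on every device

# Illustrative batch of hit patterns: (batch_size, n_pmts); batch_size should divide
# evenly by the number of devices.
batch = np.zeros((256, 128), dtype=np.float32)
batch = jax.device_put(batch, batch_sharding)            # each device holds 256 / len(devices) rows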

Key Features

  • Modular architecture with clear separation of concerns
  • Extensible data loader system with a base Loader class
  • Support for both real (HDF5) and simulated data
  • Normalizing flow models for probabilistic position reconstruction
  • Support for multiscatter generation and PMT disabling
  • Comprehensive configuration system with TOML and CLI support
  • Data-parallel sharding across multiple devices for accelerated training
