Support non-numpy array backends #886

Open
ColmTalbot wants to merge 117 commits into bilby-dev:main from ColmTalbot:bilback

Conversation

Collaborator

@ColmTalbot ColmTalbot commented Jan 7, 2025

I've been working on this PR on and off for a few months. It isn't ready yet, but I wanted to share it in case other people have early opinions.

The goal is to make it easier to interface with models/samplers implemented in e.g., JAX, that support GPU/TPU acceleration and JIT compilation.

The general guiding principles are:

  • when possible maintain existing behaviour with numpy/builtin arguments
  • work introspectively: infer the target backend from the input types so users don't need to specify it
  • write as little backend specific code as possible, mostly through using the array-api specification and scipy interoperability
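As a rough illustration of the "introspective" principle above, the sketch below infers the array namespace from the input and then writes the numerical code against that namespace only. The helper name `infer_namespace` and the example function `gaussian_ln_prob` are hypothetical and not taken from this PR (the diff's own `array_module(value)` call appears to play a similar role); the only standard piece assumed is the array API `__array_namespace__` hook.

```python
import numpy as np


def infer_namespace(*arrays):
    """Return the array namespace of the first array-API-aware input.

    Hypothetical helper: objects exposing ``__array_namespace__`` (the array
    API standard hook) report their own namespace; plain Python scalars,
    lists, and older NumPy arrays fall back to NumPy.
    """
    for array in arrays:
        if hasattr(array, "__array_namespace__"):
            return array.__array_namespace__()
    return np


def gaussian_ln_prob(val, mu, sigma):
    """Log-density of a normal distribution, written against ``xp`` only."""
    xp = infer_namespace(val)
    val = xp.asarray(val)
    norm = 0.5 * xp.log(xp.asarray(2 * xp.pi * sigma**2))
    return -0.5 * ((val - mu) / sigma) ** 2 - norm
```

Called with a float or a NumPy array this behaves like ordinary NumPy code; an array from another array-API-compliant library would be routed to that library's functions without the caller naming a backend.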

The primary changes so far are:

  • making most priors backend independent, there are a few holdouts where the underlying scipy functionality isn't compatible yet
  • core likelihoods mostly work with data from any backend
  • GW likelihoods work with any backend supported by the source function
  • the GW detector objects don't work via introspection, they need to be manually set
  • GW geometry (currently in bilby_cython) is handled via multiple-dispatch and added back into bilby
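The dispatch idea in the last bullet can be sketched with the standard library. This is only an illustration: `functools.singledispatch` dispatches on the first argument alone (true multiple dispatch considers every argument), and the function name `projected_delay` is made up here rather than taken from bilby or bilby_cython.

```python
from functools import singledispatch

import numpy as np

SPEED_OF_LIGHT = 299792458.0  # m/s


@singledispatch
def projected_delay(separation, direction):
    """Generic path: use whatever namespace the array advertises."""
    xp = separation.__array_namespace__()
    return xp.sum(separation * direction) / SPEED_OF_LIGHT


@projected_delay.register(list)
@projected_delay.register(tuple)
def _(separation, direction):
    """Builtin sequences are converted to NumPy and re-dispatched."""
    return projected_delay(
        np.asarray(separation, dtype=float), np.asarray(direction, dtype=float)
    )


@projected_delay.register(np.ndarray)
def _(separation, direction):
    """NumPy path: a compiled routine could be called here when installed."""
    return np.dot(separation, direction) / SPEED_OF_LIGHT
```

The appeal of this layout is that each backend gets its own small implementation while callers keep using a single function name.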

Changed behaviour:

Remaining issues:

  • Saving/loading non-numpy arrays in result files may not work
  • I added some additional parameter conversions that I will remove
  • the bilby.gw.jaxstuff file should be removed and relevant functionality be moved elsewhere, it's currently just used for testing
  • the ROQ likelihood hasn't been ported
  • add more testing with JAX
  • translate some of the hyperparameter functionality, cf. GWPopulation

@ColmTalbot ColmTalbot added the enhancement label Jan 7, 2025
@ColmTalbot ColmTalbot marked this pull request as draft January 7, 2025 19:38
@ColmTalbot ColmTalbot force-pushed the bilback branch 2 times, most recently from ea348fa to 771a8a9 January 22, 2026 17:00
@ColmTalbot ColmTalbot marked this pull request as ready for review January 23, 2026 15:24
@ColmTalbot ColmTalbot changed the title DRAFT: Support non-numpy array backends Support non-numpy array backends Jan 23, 2026
@ColmTalbot ColmTalbot added the >100 lines, refactoring, and to discuss labels Jan 23, 2026
@ColmTalbot
Collaborator Author

This is now ready for review.
There are some things that won't work with JAX at the moment, e.g., various combinations of likelihood marginalization/acceleration.
I think we should accept this at the moment, for at least a bilby v3 alpha/beta release, and keep chipping away at the various subcases over time.

There are a lot of changes, but most of them are essentially np -> xp.
Some things required refactoring to avoid modifying slices of arrays as JAX doesn't like that.
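To make the slice-modification point concrete, here is a small sketch of the kind of rewrite involved, using an exponential CDF as the example. The function names are illustrative and the `xp=np` argument is an assumption for the sake of a runnable example; JAX arrays are immutable, so the masked assignment in the first function has no direct equivalent there (JAX offers `x.at[idx].set(...)`, but a single `where` expression is backend-neutral).

```python
import numpy as np


def exponential_cdf_inplace(val, mu, minimum=0.0):
    """NumPy-only style: allocate, then write into a boolean-masked slice."""
    cdf = np.zeros_like(val)
    mask = val >= minimum
    cdf[mask] = 1.0 - np.exp(-val[mask] / mu)
    return cdf


def exponential_cdf_functional(val, mu, minimum=0.0, xp=np):
    """Backend-neutral style: build the result in one expression."""
    in_support = val >= minimum
    zeros = xp.zeros_like(val)
    # Clamp first so the discarded branch never overflows for very negative val.
    safe = xp.where(in_support, val, zeros)
    return xp.where(in_support, 1.0 - xp.exp(-safe / mu), zeros)
```

Both functions return the same values for NumPy input; only the second avoids mutating an array after it has been created.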

Bilby can once again be installed without bilby.cython.
This should improve our general portability, but when bilby_cython is installed it will be used.

I've managed to keep test changes minimal:

  • I updated the joint prior test to make it more stringent (keys more randomly ordered).
  • I refactored some expensive prior initialization that was dramatically slowing things down.
  • I improved the logic for figuring out when ROQs are available to help my local testing.
  • Some mocks of numpy had to be updated.

@mj-will mj-will added this to the 3.0.0 milestone Jan 27, 2026
Collaborator

@mj-will mj-will left a comment

Some initial comments but I'll need to have another look.

Comment thread bilby/compat/patches.py Outdated
Comment thread bilby/compat/utils.py
Comment thread bilby/compat/utils.py Outdated
Comment thread bilby/compat/utils.py Outdated
import os

import numpy as np
os.environ["SCIPY_ARRAY_API"] = "1" # noqa # flag for scipy backend switching
Collaborator

I worry slightly about having this hard coded. Does it introduce more overhead when using just numpy?

Collaborator Author

I agree. When we get close to merging I'll take this out.

Collaborator

@GregoryAshton GregoryAshton May 15, 2026

Does this now need to be taken out? (Reading the mattermost it seems the priority is to get this merged)
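One possible middle ground for the hard-coding concern, sketched under assumptions: set the flag only when the user has not already chosen a value. The helper name `enable_scipy_array_api` is hypothetical. SciPy reads `SCIPY_ARRAY_API` at import time, so a helper like this would have to run before `scipy` is first imported. It does not address the overhead question raised above.

```python
import os


def enable_scipy_array_api(default="1"):
    """Opt in to SciPy's array API support unless the user already chose.

    Must run before ``scipy`` is first imported, since SciPy reads the
    variable at import time. Returns the value that ends up in effect.
    """
    return os.environ.setdefault("SCIPY_ARRAY_API", default)
```

With this pattern, a user who exports `SCIPY_ARRAY_API=0` before starting Python keeps that setting.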

This maps to the inverse CDF. This has been analytically solved for this case.
"""
return gammaincinv(self.k, val) * self.theta
return xp.asarray(gammaincinv(self.k, val)) * self.theta
Collaborator

Does this mean this is falling back to numpy?

Collaborator Author

Yeah, I should update/recheck this. At least JAX doesn't have good support for this function, although it looks like TensorFlow has a version that NumPyro uses (jax-ml/jax#5350). CuPy does have this function, so this workaround may have just been for JAX. I could add a BackendNotImplementedError.

Collaborator

Would this be a candidate for a small patch that uses the TF version for jax until jax supports it natively?
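For reference, the "raise instead of silently falling back" option mentioned above might look roughly like the sketch below. The exception class here is a local stand-in (the diff imports `BackendNotImplementedError` from `bilby/compat/utils.py`, but its definition is not shown), and `gamma_rescale` is a free function written for illustration rather than the prior's actual method.

```python
import numpy as np


class BackendNotImplementedError(NotImplementedError):
    """Stand-in for the exception the diff imports from bilby.compat.utils."""


def gamma_rescale(val, k, theta):
    """Inverse CDF of a gamma distribution with shape k and scale theta."""
    get_namespace = getattr(val, "__array_namespace__", None)
    xp = get_namespace() if get_namespace is not None else np
    if xp is not np:
        # Fail loudly rather than converting to NumPy behind the user's back.
        name = getattr(xp, "__name__", repr(xp))
        raise BackendNotImplementedError(
            f"gammaincinv is not available for backend {name!r}"
        )
    from scipy.special import gammaincinv  # imported lazily, NumPy path only

    return gammaincinv(k, val) * theta
```

The trade-off is explicitness versus convenience: the `xp.asarray(gammaincinv(...))` line in the diff keeps every backend working at the cost of a hidden round trip through NumPy, while raising makes the missing support visible.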

Comment thread bilby/core/prior/analytical.py Outdated
)
)

betaln,
Collaborator

Not sure what this is.

Collaborator Author

Not anything good.

Suggested change
betaln,

Comment thread bilby/core/prior/dict.py
Comment on lines +852 to +853
# return self.check_ln_prob(sample, ln_prob,
# normalized=normalized)
Collaborator

Is the removal of this intentional?

Collaborator Author

I'm fairly sure it was, but I'll double check. I think check_ln_prob was problematic in some way.

Comment thread bilby/core/prior/dict.py
Comment on lines -877 to -902
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
joint[self[key].dist.distname] = [key]
elif isinstance(self[key], JointPrior):
joint[self[key].dist.distname].append(key)
for names in joint.values():
# this is needed to unpack how joint prior rescaling works
# as an example of a joint prior over {a, b, c, d} we might
# get the following based on the order within the joint prior
# {a: [], b: [], c: [1, 2, 3, 4], d: []}
# -> [1, 2, 3, 4]
# -> {a: 1, b: 2, c: 3, d: 4}
values = list()
for key in names:
values = np.concatenate([values, result[key]])
for key, value in zip(names, values):
result[key] = value

def safe_flatten(value):
"""
this is gross but can be removed whenever we switch to returning
arrays, flatten converts 0-d arrays to 1-d so has to be special
cased
"""
if isinstance(value, (float, int)):
return value
Collaborator

Is removing this intentional?

Collaborator Author

Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.

Comment thread bilby/gw/utils.py Outdated
Comment on lines +250 to +251
# delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex
# theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)
Collaborator

Suggest we remove this.

Collaborator Author

Suggested change
# delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex
# theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)

Collaborator Author

@ColmTalbot ColmTalbot left a comment

Thanks for the initial comments, @mj-will. I'll take a pass at them ASAP.

Comment thread bilby/gw/utils.py Outdated
The natural logarithm of the bessel function
"""
return np.log(i0e(value)) + np.abs(value)
xp = array_module(value)
Collaborator Author

Comment to self: use xp_wrap here.

Collaborator

Does this need to be actioned?
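For readers unfamiliar with the expression in the hunk above: `i0e` is SciPy's exponentially scaled Bessel function, `i0e(x) = exp(-|x|) * I0(x)`, so adding `|x|` back in log space recovers `ln I0(x)` without overflowing for large arguments. The sketch below is the NumPy form of the removed line; a backend-neutral version would swap `np` for the inferred namespace, and what `xp_wrap` does is not shown in this thread.

```python
import numpy as np
from scipy.special import i0e


def ln_i0(value):
    """Natural log of the modified Bessel function I0, stable for large |value|."""
    # i0e(x) = exp(-|x|) * I0(x), so ln I0(x) = ln i0e(x) + |x|.
    return np.log(i0e(value)) + np.abs(value)
```

Evaluating `np.log(scipy.special.i0(1.0e4))` directly overflows to infinity, whereas the scaled form stays finite.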

Collaborator

@GregoryAshton GregoryAshton left a comment

Okay, I got through about 60% of the diff and I'm pausing here, so I'll submit the questions so far.

Comment thread bilby/compat/patches.py
Comment thread bilby/compat/patches.py Outdated
from .utils import BackendNotImplementedError


def erfinv_import(xp):
Collaborator

All of these functions would benefit from a docstring to explain they do the import given the type of array backend.

Collaborator Author

Done (for the one remaining function)

Comment thread bilby/core/prior/analytical.py
Comment thread bilby/core/prior/analytical.py
_cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu)
return _cdf
with np.errstate(divide="ignore"):
return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum)
Collaborator

Ah okay - are the bounds being implemented here? But, I don't see the upper bound being implemented.

Collaborator Author

I think this is carried over from the existing implementation.
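The bound handling in the hunk above relies on a small trick: the log of a boolean mask is 0 where the condition holds and -inf where it does not, so adding it to the log-density excludes out-of-support points without any masked assignment. A minimal standalone sketch follows, assuming NumPy as the default namespace and a made-up function name; an exponential distribution's support is unbounded above, so in this sketch only the lower bound is needed.

```python
import numpy as np


def exponential_ln_prob(val, mu, minimum=0.0, xp=np):
    """Log-density of an exponential distribution with mean mu for val >= minimum."""
    val = xp.asarray(val)
    # log(True) = 0 keeps in-support values; log(False) = -inf excludes the rest.
    # np.errstate only silences NumPy's divide-by-zero warning for log(0).
    with np.errstate(divide="ignore"):
        support = xp.log(xp.asarray(val >= minimum, dtype=val.dtype))
    return -val / mu - xp.log(xp.asarray(mu)) + support
```

For example, `exponential_ln_prob([-1.0, 0.0, 2.0], mu=2.0)` gives `-inf` for the first entry and finite values for the other two.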

Comment thread bilby/core/likelihood.py
Comment thread bilby/gw/detector/interferometer.py Outdated

signal[mode] = waveform_polarizations[mode] * det_response
signal_ifo = sum(signal.values()) * mask
signal[mode] = waveform_polarizations[mode] * mask * det_response
Collaborator

It looks like this is changing the way the mask is being used. From operating on a view to operating on the full array but zeroing the False cases. Is that correct?

Collaborator Author

I think it isn't, and I've just moved this multiplication by the mask up by a line.
I'm not sure why, but I don't think it should make a big difference.
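The algebra behind "moved up by a line" can be checked in isolation: multiplying each mode by the mask before summing gives the same array as multiplying the sum by the mask, and in both cases the mask zeroes entries of a full-length array rather than selecting a view. The data and response values below are made up for the check; this says nothing about the surrounding interferometer code.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
mask = np.arange(n) >= 3  # e.g. keep only bins above a minimum frequency
polarizations = {
    mode: rng.normal(size=n) + 1j * rng.normal(size=n) for mode in ("plus", "cross")
}
response = {"plus": 0.3, "cross": -0.7}  # made-up detector responses

# Mask applied once, after summing the modes.
summed_then_masked = sum(polarizations[m] * response[m] for m in polarizations) * mask

# Mask applied to every mode, before summing.
masked_then_summed = sum(polarizations[m] * mask * response[m] for m in polarizations)

assert np.allclose(summed_then_masked, masked_then_summed)
```

One visible difference is that the per-mode entries of `signal` are themselves masked in the second form, which would matter only if they are read elsewhere.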

@mj-will mj-will removed the to discuss label May 14, 2026
@ColmTalbot
Collaborator Author

Python 3.10 doesn't have support for a vmappable version of logsumexp through scipy leading to this job failing (https://github.com/bilby-dev/bilby/actions/runs/25883935510/job/76070707573?pr=886).

How do people feel about dropping support for Python 3.10 in Bilby 3.0? NumPy dropped support about a year ago.

3 participants