Support non-numpy array backends #886
Force-pushed ea348fa to 771a8a9
This is now ready for review. There are a lot of changes, but most of them are essentially … Bilby can once again be installed without … I've managed to keep test changes minimal:
mj-will left a comment
Some initial comments but I'll need to have another look.
```diff
 import os

 import numpy as np
+os.environ["SCIPY_ARRAY_API"] = "1"  # noqa # flag for scipy backend switching
```
I worry slightly about having this hard coded. Does it introduce more overhead when using just numpy?
I agree. When we get close to merging I'll take this out.
Does this now need to be taken out? (Reading the Mattermost, it seems the priority is to get this merged.)
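One way to avoid hard-coding it would be to gate the flag on an explicit opt-in, so numpy-only users skip scipy's array-API dispatch entirely. A minimal sketch, assuming a hypothetical `BILBY_ARRAY_BACKEND` environment variable (not part of this PR):

```python
import os

# Hypothetical opt-in: only enable scipy's array-API dispatch when a
# non-numpy backend has been requested. scipy reads SCIPY_ARRAY_API when
# it is first imported, so this has to run before any scipy import.
if os.environ.get("BILBY_ARRAY_BACKEND", "numpy") != "numpy":
    os.environ.setdefault("SCIPY_ARRAY_API", "1")
```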
```diff
     This maps to the inverse CDF. This has been analytically solved for this case.
     """
-    return gammaincinv(self.k, val) * self.theta
+    return xp.asarray(gammaincinv(self.k, val)) * self.theta
```
Does this mean this is falling back to numpy?
Yeah, I should update/recheck this. jax doesn't have good support for this, but it looks like tensorflow has a version that numpyro uses (jax-ml/jax#5350). cupy does have this function, so this workaround may have just been for jax. I could add a BackendNotImplementedError.
Would this be a candidate for a small patch that uses the TF version for jax until jax supports it natively?
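A sketch of what that dispatch could look like; the helper signature and the error class are illustrative, though `cupyx.scipy.special` does provide `gammaincinv` as noted above:

```python
import numpy as np
from scipy.special import gammaincinv as scipy_gammaincinv


class BackendNotImplementedError(NotImplementedError):
    """Stand-in for the error type proposed above."""


def gammaincinv(k, val, xp=np):
    """Dispatch gammaincinv on the array backend (illustrative)."""
    if xp.__name__ == "numpy":
        return scipy_gammaincinv(k, val)
    if xp.__name__ == "cupy":
        # cupy ships a native implementation
        from cupyx.scipy.special import gammaincinv as cupy_gammaincinv
        return cupy_gammaincinv(k, val)
    # jax has no native version yet (jax-ml/jax#5350); a patch wrapping
    # the tensorflow-probability implementation could slot in here
    raise BackendNotImplementedError(
        f"gammaincinv is not implemented for backend {xp.__name__}"
    )
```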
```diff
     )
 )

+    betaln,
```

Not anything good.
```python
# return self.check_ln_prob(sample, ln_prob,
#                           normalized=normalized)
```
Is the removal of this intentional?
I'm fairly sure it was, but I'll double check. I think check_ln_prob was problematic in some way.
```python
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
    joint[self[key].dist.distname] = [key]
elif isinstance(self[key], JointPrior):
    joint[self[key].dist.distname].append(key)
for names in joint.values():
    # this is needed to unpack how joint prior rescaling works
    # as an example of a joint prior over {a, b, c, d} we might
    # get the following based on the order within the joint prior
    # {a: [], b: [], c: [1, 2, 3, 4], d: []}
    # -> [1, 2, 3, 4]
    # -> {a: 1, b: 2, c: 3, d: 4}
    values = list()
    for key in names:
        values = np.concatenate([values, result[key]])
    for key, value in zip(names, values):
        result[key] = value


def safe_flatten(value):
    """
    this is gross but can be removed whenever we switch to returning
    arrays, flatten converts 0-d arrays to 1-d so has to be special
    cased
    """
    if isinstance(value, (float, int)):
        return value
```
Is removing this intentional?
Yeah, this is in line with one of the other open PRs to update this logic. I'll dig it out in my next pass.
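For reference, a standalone toy run of the unpacking logic from the hunk above, assuming `result` maps each key to the values returned by the joint rescale:

```python
import numpy as np

# mirrors the comment in the diff: the joint prior returns all values
# through one key, the others are empty
result = {"a": [], "b": [], "c": [1, 2, 3, 4], "d": []}
names = ["a", "b", "c", "d"]

values = list()
for key in names:
    values = np.concatenate([values, result[key]])
for key, value in zip(names, values):
    result[key] = value

print(result)  # -> {'a': 1.0, 'b': 2.0, 'c': 3.0, 'd': 4.0} (as floats)
```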
```python
# delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex
# theta, phi = zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)
```

Suggested change: delete these commented-out lines.
ColmTalbot left a comment

Thanks for the initial comments @mj-will, I'll take a pass at them ASAP.
```diff
     The natural logarithm of the bessel function
     """
-    return np.log(i0e(value)) + np.abs(value)
+    xp = array_module(value)
```
Comment to self: use xp_wrap here.
Does this need to be actioned?
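For context, a sketch of how this might look once dispatched per backend; `ln_i0`, the namespace check, and the `array_module` body are assumptions based on the diff (`xp_wrap`, mentioned above, would presumably absorb this boilerplate):

```python
import numpy as np


def array_module(value):
    """Assumed behaviour of the PR's helper: return the array namespace."""
    if type(value).__module__.startswith("jax"):
        import jax.numpy as jnp
        return jnp
    return np


def ln_i0(value):
    """Log of the modified Bessel function I0, via the scaled i0e for stability."""
    xp = array_module(value)
    if xp.__name__.startswith("jax"):
        from jax.scipy.special import i0e
    else:
        from scipy.special import i0e
    # i0e(x) = exp(-|x|) * i0(x), so log(i0e) + |x| recovers log(i0)
    return xp.log(i0e(value)) + xp.abs(value)
```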
GregoryAshton left a comment

Okay, I got through about 60% of the diff and I'm pausing here, so I'll submit the questions so far.
```python
from .utils import BackendNotImplementedError


def erfinv_import(xp):
```
All of these functions would benefit from a docstring explaining that they perform the import appropriate to the array backend.
Done (for the one remaining function)
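Something along these lines, say; the dispatch body is a guess at the pattern used by the remaining function:

```python
def erfinv_import(xp):
    """
    Return an erfinv implementation matching the array backend.

    Parameters
    ----------
    xp : module
        The array namespace, e.g. numpy or jax.numpy.

    Returns
    -------
    callable
        A backend-appropriate erfinv function.
    """
    if xp.__name__.startswith("jax"):
        from jax.scipy.special import erfinv
    else:
        from scipy.special import erfinv
    return erfinv
```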
```diff
-    _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu)
-    return _cdf
+    with np.errstate(divide="ignore"):
+        return -val / self.mu - xp.log(xp.asarray(self.mu)) + xp.log(val >= self.minimum)
```
Ah okay, are the bounds being implemented here? I don't see the upper bound being implemented, though.
I think this is carried over from the existing implementation.
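For reference, the lower bound enters through the final term: the boolean comparison promotes to {0, 1}, so its log is -inf below the bound and 0 above it:

```python
import numpy as np

val = np.array([-1.0, 0.5, 2.0])
minimum = 0.0
with np.errstate(divide="ignore"):
    # -inf where val < minimum, 0 elsewhere
    print(np.log(val >= minimum))  # [-inf   0.   0.]
```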
```diff
-        signal[mode] = waveform_polarizations[mode] * det_response
-    signal_ifo = sum(signal.values()) * mask
+        signal[mode] = waveform_polarizations[mode] * mask * det_response
```
It looks like this is changing the way the mask is used: from operating on a view to operating on the full array but zeroing the False cases. Is that correct?
There was a problem hiding this comment.
I think it isn't, and I've just moved this multiplication by the mask up by a line.
I'm not sure why, but I don't think it should make a big difference.
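A quick toy check (illustrative names) that per-mode masking agrees with masking the summed signal, which is why the reordering shouldn't change results:

```python
import numpy as np

rng = np.random.default_rng(0)
mask = np.array([1.0, 0.0, 1.0])
plus, cross = rng.random(3), rng.random(3)

# distributivity: masking each mode is the same as masking the sum
assert np.allclose(
    (plus + cross) * mask,        # mask applied after summing modes
    plus * mask + cross * mask,   # mask applied per mode
)
```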
Python 3.10 doesn't have support for a vmappable version of logsumexp through scipy, leading to this job failing (https://github.com/bilby-dev/bilby/actions/runs/25883935510/job/76070707573?pr=886). How do people feel about dropping support for Python 3.10 in Bilby 3.10? NumPy dropped support about a year ago.
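If dropping 3.10 is contentious, one hypothetical stopgap, mirroring the import-dispatch pattern used elsewhere in this PR, would be to take logsumexp from jax under the jax backend so the vmappable path doesn't depend on the scipy version available on 3.10:

```python
def logsumexp_import(xp):
    """Illustrative: return a vmappable logsumexp for the given backend."""
    if xp.__name__.startswith("jax"):
        from jax.scipy.special import logsumexp
    else:
        from scipy.special import logsumexp
    return logsumexp
```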
This required making some changes to the tests for conditional dicts, as I've changed the output types and the backend introspection doesn't work on dict_items for some reason.
I've been working on this PR on and off for a few months; it isn't ready yet, but I wanted to share it in case other people had early opinions.
The goal is to make it easier to interface with models/samplers implemented in, e.g., JAX that support GPU/TPU acceleration and JIT compilation.
The general guiding principles are:

- the array-api specification and scipy interoperability

The primary changes so far are:
Changed behaviour:
Remaining issues:
- the `bilby.gw.jaxstuff` file should be removed and relevant functionality moved elsewhere; it's currently just used for testing