diff --git a/README.md b/README.md index 29af903..2344e0c 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ MRI reconstruction playground for the MRI Metrics project. | `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Applies an axis-aligned rectangular low-pass mask with separate cutoffs along `kx` and `ky`. | | `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Applies a circular low-pass mask with a raised-cosine transition band to soften the cutoff. | | `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Applies a circular low-pass mask with a Kaiser transition band for adjustable cutoff smoothness. | +| `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | | `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. | | `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an off-center anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`. | | `GaussianNoiseDistortion` | `Gaussian noise` | Noise | Adds independent zero-mean Gaussian noise to the stored real and imaginary k-space channels. | diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index f14107e..989b58a 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -31,17 +31,18 @@ # "tv-pdhg", ] DISTORTIONS = [ - "Phase-encode ghosting", - "Segmented translation motion", - "Translation motion", - "Rotational motion", - "Off-center anisotropic Gaussian bias field", - "Gaussian bias field", - "Anisotropic LP", - "Hann taper LP", - "Kaiser taper LP", - "Gaussian noise", - "Isotropic LP", + # "Phase-encode ghosting", + # "Segmented translation motion", + # "Translation motion", + # "Rotational motion", + # "Off-center anisotropic Gaussian bias field", + # "Gaussian bias field", + # "Anisotropic LP", + # "Hann taper LP", + # "Kaiser taper LP", + "Radial high-pass emphasis", + # "Gaussian noise", + # "Isotropic LP", ] METRICS = [ "PSNR", @@ -152,6 +153,8 @@ def choose_distortion(name: str) -> BaseDistortion: transition_fraction=0.4, beta=8.6, ) + case "Radial high-pass emphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) case "Isotropic LP": return IsotropicResolutionReduction(radius_fraction=0.1) case "Off-center anisotropic Gaussian bias field": diff --git a/mri_recon/distortions/__init__.py b/mri_recon/distortions/__init__.py index 3af28ce..03159f3 100644 --- a/mri_recon/distortions/__init__.py +++ b/mri_recon/distortions/__init__.py @@ -16,4 +16,5 @@ HannTaperResolutionReduction, IsotropicResolutionReduction, KaiserTaperResolutionReduction, + RadialHighPassEmphasisDistortion, ) diff --git a/mri_recon/distortions/resolution.py b/mri_recon/distortions/resolution.py index cc38b86..1011abc 100644 --- a/mri_recon/distortions/resolution.py +++ b/mri_recon/distortions/resolution.py @@ -216,3 +216,51 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: profile="kaiser", beta=self.beta, ) + + +class RadialHighPassEmphasisDistortion(SelfAdjointMultiplicativeMaskDistortion): + """Radially boost high frequencies with a smooth monotone gain field. + + The mask equals ``1`` in the low-frequency core, rises smoothly across a + fixed transition band, and reaches ``1 + alpha`` at the sampled edge. This + behaves like a gentle high-frequency shelf rather than amplifying all + nonzero frequencies. + + :param float alpha: Non-negative gain added at the k-space edge. + :param float boost_start_radius: Normalized radius in ``[0, 1)`` where the + high-frequency shelf begins to rise. + :param float boost_end_radius: Normalized radius in ``(0, 1]`` where the + shelf reaches its full gain. + """ + + BOOST_START_RADIUS = 0.4 + BOOST_END_RADIUS = 0.9 + + def __init__( + self, + alpha: float = 0.4, + boost_start_radius: float = BOOST_START_RADIUS, + boost_end_radius: float = BOOST_END_RADIUS, + ) -> None: + super().__init__() + if alpha < 0.0: + raise ValueError("alpha must be non-negative") + if not 0.0 <= boost_start_radius < 1.0: + raise ValueError("boost_start_radius must be in [0, 1)") + if not 0.0 < boost_end_radius <= 1.0: + raise ValueError("boost_end_radius must be in (0, 1]") + if boost_start_radius >= boost_end_radius: + raise ValueError("boost_start_radius must be smaller than boost_end_radius") + + self.alpha = alpha + self.boost_start_radius = boost_start_radius + self.boost_end_radius = boost_end_radius + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + radius = _radial_frequency(shape).to(device) + transition = (radius - self.boost_start_radius) / ( + self.boost_end_radius - self.boost_start_radius + ) + transition = transition.clamp(0.0, 1.0) + transition = transition * transition * (3.0 - 2.0 * transition) + return 1.0 + self.alpha * transition diff --git a/tests/test_distortions.py b/tests/test_distortions.py index 5e6e971..06d09fb 100644 --- a/tests/test_distortions.py +++ b/tests/test_distortions.py @@ -18,6 +18,7 @@ KaiserTaperResolutionReduction, OffCenterAnisotropicGaussianKspaceBiasField, PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, RotationalMotionDistortion, SegmentedTranslationMotionDistortion, SelfAdjointMultiplicativeMaskDistortion, @@ -30,6 +31,7 @@ "Anisotropic LP", "Hann taper LP", "Kaiser taper LP", + "Radial high-pass emphasis", "Gaussian bias field", "Off-center anisotropic Gaussian bias field", "Phase-encode ghosting", @@ -75,6 +77,8 @@ def choose_distortion(name): transition_fraction=0.25, beta=8.6, ) + case "Radial high-pass emphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) case "Gaussian bias field": return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) case "Off-center anisotropic Gaussian bias field": @@ -249,6 +253,44 @@ def test_kaiser_taper_resolution_reduction_zero_transition_matches_hard_cutoff(d assert torch.equal(smooth.A(y), hard.A(y)) +def test_radial_high_pass_emphasis_distortion_boosts_edges_more_than_center(device): + distortion = RadialHighPassEmphasisDistortion(alpha=0.4) + shape = (1, 2, 33, 33) + center_y = shape[-2] // 2 + center_x = shape[-1] // 2 + + mask = distortion._mask(shape, torch.device(device)) + + assert mask[center_y, center_x] == pytest.approx(1.0) + assert mask[0, 0] == pytest.approx(1.0 + distortion.alpha) + assert torch.all(mask >= 1.0) + assert mask[center_y, center_x + 4] == pytest.approx(1.0) + assert torch.any((mask > 1.0) & (mask < 1.0 + distortion.alpha)) + + +def test_radial_high_pass_emphasis_distortion_zero_alpha_is_identity(device): + distortion = RadialHighPassEmphasisDistortion(alpha=0.0) + y = torch.randn((1, 2, 64, 64), device=device) + + assert torch.equal(distortion.A(y), y) + + +def test_radial_high_pass_emphasis_distortion_respects_custom_band(device): + distortion = RadialHighPassEmphasisDistortion( + alpha=0.4, + boost_start_radius=0.7, + boost_end_radius=0.95, + ) + shape = (1, 2, 65, 65) + center_y = shape[-2] // 2 + center_x = shape[-1] // 2 + + mask = distortion._mask(shape, torch.device(device)) + + assert mask[center_y, center_x + 12] == pytest.approx(1.0) + assert mask[0, 0] == pytest.approx(1.0 + distortion.alpha) + + def test_centered_isotropic_bias_matches_anisotropic_special_case(device): centered = GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) anisotropic = OffCenterAnisotropicGaussianKspaceBiasField( @@ -555,6 +597,7 @@ def test_segmented_translation_motion_keeps_zero_motion_segment_and_modulates_sh AnisotropicResolutionReduction, HannTaperResolutionReduction, KaiserTaperResolutionReduction, + RadialHighPassEmphasisDistortion, ], ) def test_resolution_reduction_classes_inherit_from_self_adjoint_multiplicative_mask(