From b540bfa92b83b2104c5961df9ee4b0915357cee9 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Tue, 15 Apr 2025 17:58:55 -0400 Subject: [PATCH 01/57] histogram rebase to dev --- src/spac/visualization.py | 167 ++++++++++++--------- tests/test_visualization/test_histogram.py | 40 ++++- 2 files changed, 136 insertions(+), 71 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 9003c163..8dacf1ec 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -405,7 +405,7 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, - x_log_scale=False, y_log_scale=False, **kwargs): + x_log_scale=False, y_log_scale=False, facet=False, **kwargs): """ Plot the histogram of cells based on a specific feature from adata.X or annotation from adata.obs. @@ -447,6 +447,9 @@ def histogram(adata, feature=None, annotation=None, layer=None, y_log_scale : bool, default False If True, the y-axis will be set to log scale. + facet : bool, default False + If True, group by function outputs facet plots + **kwargs Additional keyword arguments passed to seaborn histplot function. 
Key arguments include: @@ -658,7 +661,6 @@ def calculate_histogram(data, bins, bin_edges=None): kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - sns.histplot(data=hist_data, x='bin_center', weights='count', hue=group_by, ax=ax, **kwargs) # If plotting feature specify which layer @@ -666,61 +668,82 @@ def calculate_histogram(data, bins, bin_edges=None): ax.set_title(f'Layer: {layer}') axs.append(ax) else: - fig, ax_array = plt.subplots( - n_groups, 1, figsize=(5, 5 * n_groups) - ) + if not facet: + fig, ax_array = plt.subplots( + n_groups, 1, figsize=(5, 5 * n_groups) + ) - # Convert a single Axes object to a list - # Ensure ax_array is always iterable - if n_groups == 1: - ax_array = [ax_array] - else: - ax_array = ax_array.flatten() - - for i, ax_i in enumerate(ax_array): - group_data = plot_data[plot_data[group_by] == - groups[i]][data_column] - hist_data = calculate_histogram(group_data, kwargs['bins']) - - sns.histplot(data=hist_data, x="bin_center", ax=ax_i, - weights='count', **kwargs) - # If plotting feature specify which layer - if feature: - ax_i.set_title(f'{groups[i]} with Layer: {layer}') + # Convert a single Axes object to a list + # Ensure ax_array is always iterable + if n_groups == 1: + ax_array = [ax_array] else: - ax_i.set_title(f'{groups[i]}') + ax_array = ax_array.flatten() + + for i, ax_i in enumerate(ax_array): + group_data = plot_data[plot_data[group_by] == + groups[i]][data_column] + hist_data = calculate_histogram(group_data, kwargs['bins']) + + sns.histplot(data=hist_data, x="bin_center", ax=ax_i, + weights='count', **kwargs) + # If plotting feature specify which layer + if feature: + ax_i.set_title(f'{groups[i]} with Layer: {layer}') + else: + ax_array = ax_array.flatten() + + # Set axis scales if y_log_scale is True + if y_log_scale: + ax_i.set_yscale('log') + + # Adjust x-axis label if x_log_scale is True + if x_log_scale: + xlabel = f'log({data_column})' + else: + xlabel = data_column + 
ax_i.set_xlabel(xlabel) + + # Adjust y-axis label based on 'stat' parameter + stat = kwargs.get('stat', 'count') + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability' + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + ax_i.set_ylabel(ylabel) + + axs.append(ax_i) + else: + hist = sns.FacetGrid(plot_data, col=group_by) + # Map the histogram function to the grid + hist.map(sns.histplot, data_column, **kwargs) - # Set axis scales if y_log_scale is True - if y_log_scale: - ax_i.set_yscale('log') + # Set rotation of label + hist.set_xticklabels(rotation=20, ha='right') + + # Titles for each facet + hist.set_titles("{col_name}") + + # Ajust top margin + hist.figure.subplots_adjust(left=.1, + top=0.85, + bottom=0.15, + hspace=0.3) + + fig = hist.figure + axs.extend(hist.axes.flat) - # Adjust x-axis label if x_log_scale is True - if x_log_scale: - xlabel = f'log({data_column})' - else: - xlabel = data_column - ax_i.set_xlabel(xlabel) - - # Adjust y-axis label based on 'stat' parameter - stat = kwargs.get('stat', 'count') - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability' - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - ax_i.set_ylabel(ylabel) - - axs.append(ax_i) else: # Precompute histogram data for single plot hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): ax.set_xlim(hist_data['bin_left'].min(), - hist_data['bin_right'].max()) + hist_data['bin_right'].max()) sns.histplot( data=hist_data, @@ -735,35 +758,39 @@ def calculate_histogram(data, bins, bin_edges=None): ax.set_title(f'Layer: {layer}') axs.append(ax) - # Set axis scales if y_log_scale is True - if y_log_scale: - ax.set_yscale('log') + axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] + for ax in axes: + # 
Set axis scales if y_log_scale is True + if y_log_scale: + ax.set_yscale('log') - # Adjust x-axis label if x_log_scale is True - if x_log_scale: - xlabel = f'log({data_column})' - else: - xlabel = data_column - ax.set_xlabel(xlabel) - - # Adjust y-axis label based on 'stat' parameter - stat = kwargs.get('stat', 'count') - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability' - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - ax.set_ylabel(ylabel) + # Adjust x-axis label if x_log_scale is True + if x_log_scale: + xlabel = f'log({data_column})' + else: + xlabel = data_column + ax.set_xlabel(xlabel) + + # Adjust y-axis label based on 'stat' parameter + stat = kwargs.get('stat', 'count') + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability' + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + ax.set_ylabel(ylabel) + ax.tick_params(axis='x', rotation=90, labelsize=10) if len(axs) == 1: return {"fig": fig, "axs": axs[0], "df": hist_data} else: return {"fig": fig, "axs": axs, "df": hist_data} + def heatmap(adata, column, layer=None, **kwargs): """ Plot the heatmap of the mean feature of cells that belong to a `column`. 
diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index f8ba95ea..999b66e5 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -222,7 +222,7 @@ def test_y_log_scale_axis(self): def test_y_log_scale_label(self): """Test that y-axis label is updated when y_log_scale is True.""" - fig, ax, dfd = histogram( + fig, ax, df = histogram( self.adata, feature='marker1', y_log_scale=True @@ -413,6 +413,44 @@ def test_default_bins_calculation(self): expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) self.assertEqual(n_bins, expected_bins) + def test_facet_plot(self): + """Test that facet plot works.""" + fig, ax, df = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ).values() + + # Check if axs is a collection (list/array of Axes) + self.assertIsInstance(ax, (list, np.ndarray), + "Output is not a multi-axis grid") + + # Check number of facets equals number of unique groups + unique_groups = self.adata.obs['annotation2'].dropna().unique() + self.assertEqual(len(ax), len(unique_groups), + f"Expected {len(unique_groups)}" + f" facet plots, got {len(ax)}.") + + # Validate each axis: title, xlabel, and ylabel + for i, axis in enumerate(ax): + # Check that title is set and matches the group + title = axis.get_title() + self.assertTrue(title, f"Facet {i} is missing a title.") + self.assertTrue(any(str(group) in title + for group in unique_groups), + f"Title '{title}' does not contain" + f"any expected group names.") + + # Check X and Y labels + self.assertIn('marker1', axis.get_xlabel(), + f"Facet {i} X-axis label" + f" '{axis.get_xlabel()}' is incorrect.") + self.assertIn(axis.get_ylabel(), + ['Count', 'Frequency', 'Density', 'Probability'], + f"Facet {i} Y-axis label" + f" '{axis.get_ylabel()}' is not a valid stat.") + if __name__ == '__main__': unittest.main() From 21779cda9338e7a2c0bb9d0fe1e7f3f65a8dc9bb Mon Sep 17 
00:00:00 2001 From: ying39purdue Date: Tue, 1 Apr 2025 22:40:08 -0400 Subject: [PATCH 02/57] add facet plots on histogram group by --- src/spac/visualization.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 8dacf1ec..090716bf 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -570,7 +570,8 @@ def cal_bin_num( ): bins = max(int(2*(num_rows ** (1/3))), 1) print(f'Automatically calculated number of bins is: {bins}') - return(bins) + + return (bins) num_rows = plot_data.shape[0] @@ -632,6 +633,7 @@ def calculate_histogram(data, bins, bin_edges=None): if group_by: groups = df[group_by].dropna().unique().tolist() n_groups = len(groups) + if n_groups == 0: raise ValueError("There must be at least one group to create a" " histogram.") @@ -667,6 +669,7 @@ def calculate_histogram(data, bins, bin_edges=None): if feature: ax.set_title(f'Layer: {layer}') axs.append(ax) + else: if not facet: fig, ax_array = plt.subplots( @@ -716,7 +719,6 @@ def calculate_histogram(data, bins, bin_edges=None): if y_log_scale: ylabel = f'log({ylabel})' ax_i.set_ylabel(ylabel) - axs.append(ax_i) else: hist = sns.FacetGrid(plot_data, col=group_by) From 09a9012159be7a4fd62cd3a275dd04386469f318 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Tue, 1 Apr 2025 22:50:35 -0400 Subject: [PATCH 03/57] add unittest for facet output --- tests/test_visualization/test_histogram.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 999b66e5..7256078d 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -412,6 +412,18 @@ def test_default_bins_calculation(self): # Using 2 * (n ** 1/3) heuristic for default bins expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) self.assertEqual(n_bins, expected_bins) + + def test_facet_plot(self): + 
"""Test that facet plot works.""" + fig, ax = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ) + + # Check if axs is a collection (list/array of Axes) + self.assertIsInstance(ax, (list, np.ndarray), "Output is not a multi-axis grid") def test_facet_plot(self): """Test that facet plot works.""" From b1693aa5be9decb0f5a13457abfc206bcc09abc5 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Wed, 2 Apr 2025 12:46:07 -0400 Subject: [PATCH 04/57] formatting of facet plot, and unittest --- src/spac/visualization.py | 5 +++++ tests/test_visualization/test_histogram.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 090716bf..838178c6 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -688,8 +688,13 @@ def calculate_histogram(data, bins, bin_edges=None): groups[i]][data_column] hist_data = calculate_histogram(group_data, kwargs['bins']) +<<<<<<< HEAD sns.histplot(data=hist_data, x="bin_center", ax=ax_i, weights='count', **kwargs) +======= + sns.histplot(data=group_data, x=data_column, + ax=ax_i, **kwargs) +>>>>>>> 11127cf (formatting of facet plot, and unittest) # If plotting feature specify which layer if feature: ax_i.set_title(f'{groups[i]} with Layer: {layer}') diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 7256078d..bf35dc27 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -423,7 +423,8 @@ def test_facet_plot(self): ) # Check if axs is a collection (list/array of Axes) - self.assertIsInstance(ax, (list, np.ndarray), "Output is not a multi-axis grid") + self.assertIsInstance(ax, (list, np.ndarray), + "Output is not a multi-axis grid") def test_facet_plot(self): """Test that facet plot works.""" From a9af69508045dc9115439871fb99c90e59c7f516 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Wed, 2 Apr 2025 
12:47:18 -0400 Subject: [PATCH 05/57] formatting of facet plot, and unittest --- src/spac/visualization.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 838178c6..090716bf 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -688,13 +688,8 @@ def calculate_histogram(data, bins, bin_edges=None): groups[i]][data_column] hist_data = calculate_histogram(group_data, kwargs['bins']) -<<<<<<< HEAD sns.histplot(data=hist_data, x="bin_center", ax=ax_i, weights='count', **kwargs) -======= - sns.histplot(data=group_data, x=data_column, - ax=ax_i, **kwargs) ->>>>>>> 11127cf (formatting of facet plot, and unittest) # If plotting feature specify which layer if feature: ax_i.set_title(f'{groups[i]} with Layer: {layer}') From b5b5699d7b63e484a18b56373b42092a87582050 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Tue, 8 Apr 2025 15:52:20 -0400 Subject: [PATCH 06/57] remove axis rotation and formatting --- src/spac/visualization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 090716bf..475957f8 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -785,7 +785,6 @@ def calculate_histogram(data, bins, bin_edges=None): if y_log_scale: ylabel = f'log({ylabel})' ax.set_ylabel(ylabel) - ax.tick_params(axis='x', rotation=90, labelsize=10) if len(axs) == 1: return {"fig": fig, "axs": axs[0], "df": hist_data} From ff759aa27b29ae76a08c4e958c11c89b86029ae5 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Thu, 10 Apr 2025 23:44:22 -0400 Subject: [PATCH 07/57] Unittest addition of element numbers, title, and axis labels check --- tests/test_visualization/test_histogram.py | 31 +++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index bf35dc27..702d429a 100644 --- 
a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -412,7 +412,7 @@ def test_default_bins_calculation(self): # Using 2 * (n ** 1/3) heuristic for default bins expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) self.assertEqual(n_bins, expected_bins) - + def test_facet_plot(self): """Test that facet plot works.""" fig, ax = histogram( @@ -421,10 +421,35 @@ def test_facet_plot(self): group_by='annotation2', facet=True, ) - + # Check if axs is a collection (list/array of Axes) self.assertIsInstance(ax, (list, np.ndarray), - "Output is not a multi-axis grid") + "Output is not a multi-axis grid") + + # Check number of facets equals number of unique groups + unique_groups = self.adata.obs['annotation2'].dropna().unique() + self.assertEqual(len(ax), len(unique_groups), + f"Expected {len(unique_groups)}" + f" facet plots, got {len(ax)}.") + + # Validate each axis: title, xlabel, and ylabel + for i, axis in enumerate(ax): + # Check that title is set and matches the group + title = axis.get_title() + self.assertTrue(title, f"Facet {i} is missing a title.") + self.assertTrue(any(str(group) in title + for group in unique_groups), + f"Title '{title}' does not contain" + f"any expected group names.") + + # Check X and Y labels + self.assertIn('marker1', axis.get_xlabel(), + f"Facet {i} X-axis label" + f" '{axis.get_xlabel()}' is incorrect.") + self.assertIn(axis.get_ylabel(), + ['Count', 'Frequency', 'Density', 'Probability'], + f"Facet {i} Y-axis label" + f" '{axis.get_ylabel()}' is not a valid stat.") def test_facet_plot(self): """Test that facet plot works.""" From 7019cae6239b05648aa58faabad09cb5786ec090 Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Tue, 15 Apr 2025 18:10:40 -0400 Subject: [PATCH 08/57] unittest rebase with dev --- tests/test_visualization/test_histogram.py | 38 ---------------------- 1 file changed, 38 deletions(-) diff --git a/tests/test_visualization/test_histogram.py 
b/tests/test_visualization/test_histogram.py index 702d429a..999b66e5 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -413,44 +413,6 @@ def test_default_bins_calculation(self): expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) self.assertEqual(n_bins, expected_bins) - def test_facet_plot(self): - """Test that facet plot works.""" - fig, ax = histogram( - self.adata, - feature='marker1', - group_by='annotation2', - facet=True, - ) - - # Check if axs is a collection (list/array of Axes) - self.assertIsInstance(ax, (list, np.ndarray), - "Output is not a multi-axis grid") - - # Check number of facets equals number of unique groups - unique_groups = self.adata.obs['annotation2'].dropna().unique() - self.assertEqual(len(ax), len(unique_groups), - f"Expected {len(unique_groups)}" - f" facet plots, got {len(ax)}.") - - # Validate each axis: title, xlabel, and ylabel - for i, axis in enumerate(ax): - # Check that title is set and matches the group - title = axis.get_title() - self.assertTrue(title, f"Facet {i} is missing a title.") - self.assertTrue(any(str(group) in title - for group in unique_groups), - f"Title '{title}' does not contain" - f"any expected group names.") - - # Check X and Y labels - self.assertIn('marker1', axis.get_xlabel(), - f"Facet {i} X-axis label" - f" '{axis.get_xlabel()}' is incorrect.") - self.assertIn(axis.get_ylabel(), - ['Count', 'Frequency', 'Density', 'Probability'], - f"Facet {i} Y-axis label" - f" '{axis.get_ylabel()}' is not a valid stat.") - def test_facet_plot(self): """Test that facet plot works.""" fig, ax, df = histogram( From 44b35666f9da471f43d4e7832b2b18a23407343a Mon Sep 17 00:00:00 2001 From: ying39purdue Date: Wed, 16 Apr 2025 09:39:55 -0400 Subject: [PATCH 09/57] correct facet function return --- src/spac/visualization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 475957f8..8498eaef 
100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -739,6 +739,7 @@ def calculate_histogram(data, bins, bin_edges=None): fig = hist.figure axs.extend(hist.axes.flat) + hist_data = plot_data else: # Precompute histogram data for single plot From a0f3fcc305a54d967167410d37d0bcdf3e176580 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 24 Mar 2026 19:44:42 -0400 Subject: [PATCH 10/57] feat(histogram): enhance faceted layout customization - Add customizable FacetGrid parameters: facet_ncol, facet_vertical_threshold, facet_height, and facet_aspect - Implement automatic grid layout that switches between vertical (<=4 groups) and compact grid layout (>4 groups) - Improve visual styling with background color and grid lines on facets - Fix axis label handling for faceted plots using supxlabel/supylabel - Adjust margins and spacing for better readability across different layouts - Update docstring with new FacetGrid customization parameters --- src/spac/visualization.py | 91 ++++++++++++++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 11 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 8498eaef..9fc8a9b3 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -483,6 +483,16 @@ def histogram(adata, feature=None, annotation=None, layer=None, while `bins=[0, 1, 2, 3]` will create bins [0,1), [1,2), [2,3]. If not provided, the binning will be determined automatically. Note, don't pass a numpy array, only python lists or strs/numbers. + When `facet=True`, these optional keys can be passed via `kwargs` + to customize the FacetGrid layout: + - `facet_ncol`: int or None, number of facet columns. + If None, the function uses one column for small group counts and + switches to a compact grid for many groups. + - `facet_vertical_threshold`: int, max number of groups that should + stay in a vertical single-column layout when `facet_ncol` is None. + Default is 4. 
+ - `facet_height`: float, facet height in inches. Default is 3.2. + - `facet_aspect`: float, facet width/height ratio. Default is 1.25. Returns ------- @@ -720,22 +730,70 @@ def calculate_histogram(data, bins, bin_edges=None): ylabel = f'log({ylabel})' ax_i.set_ylabel(ylabel) axs.append(ax_i) - else: - hist = sns.FacetGrid(plot_data, col=group_by) + + else: # Facet option + # Set default values for facet parameters if not provided in kwargs + facet_ncol = kwargs.get('facet_ncol', None) + facet_vertical_threshold = kwargs.get( + 'facet_vertical_threshold', 4 + ) + facet_height = kwargs.get('facet_height', 3.2) + facet_aspect = kwargs.get('facet_aspect', 1.25) + + # Default: vertical layout for a few groups, grid for many. + if facet_ncol is None: + if n_groups <= facet_vertical_threshold: + facet_ncol = 1 + else: + facet_ncol = int(np.ceil(np.sqrt(n_groups))) + + facet_ncol = max(1, min(int(facet_ncol), n_groups)) + + # Create the FacetGrid for the histogram + hist = sns.FacetGrid( + plot_data, + col=group_by, + col_wrap=facet_ncol, + height=facet_height, + aspect=facet_aspect, + sharex=True, + sharey=True + ) + + # Remove facet-specific keys from kwargs to avoid passing them to histplot + facet_only_keys = { + 'facet_ncol', + 'facet_vertical_threshold', + 'facet_height', + 'facet_aspect', + } + hist_kwargs = { + k: v for k, v in kwargs.items() + if k not in facet_only_keys + } + # Map the histogram function to the grid - hist.map(sns.histplot, data_column, **kwargs) + hist.map_dataframe(sns.histplot, x=data_column, **hist_kwargs) + + # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column + for ax_i in hist.axes.flat: + ax_i.tick_params(axis='x', labelbottom=True) + ax_i.tick_params(axis='y', labelleft=True) - # Set rotation of label - hist.set_xticklabels(rotation=20, ha='right') + # Set background color and grid for better readability + for ax_i in hist.axes.flat: + ax_i.set_facecolor('#f2f2f2') + ax_i.grid(True, 
which='major', axis='both') # Titles for each facet hist.set_titles("{col_name}") - # Ajust top margin + # Adjust margins for readability across layouts. hist.figure.subplots_adjust(left=.1, - top=0.85, - bottom=0.15, - hspace=0.3) + top=0.9, + bottom=0.12, + hspace=0.35, + wspace=0.2) fig = hist.figure axs.extend(hist.axes.flat) @@ -772,7 +830,10 @@ def calculate_histogram(data, bins, bin_edges=None): xlabel = f'log({data_column})' else: xlabel = data_column - ax.set_xlabel(xlabel) + if facet: + ax.set_xlabel('') + else: + ax.set_xlabel(xlabel) # Adjust y-axis label based on 'stat' parameter stat = kwargs.get('stat', 'count') @@ -785,7 +846,15 @@ def calculate_histogram(data, bins, bin_edges=None): ylabel = ylabel_map.get(stat, 'Count') if y_log_scale: ylabel = f'log({ylabel})' - ax.set_ylabel(ylabel) + if facet: + ax.set_ylabel('') + else: + ax.set_ylabel(ylabel) + + # Set a common x and y label for the entire figure if facet is True + if facet and fig is not None: + fig.supxlabel(xlabel) + fig.supylabel(ylabel) if len(axs) == 1: return {"fig": fig, "axs": axs[0], "df": hist_data} From bc2e95e7f116db9fe97c5034243bae786d2ee449 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 24 Mar 2026 21:08:34 -0400 Subject: [PATCH 11/57] fix(histogram): enforce shared bins across facets --- src/spac/visualization.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 9fc8a9b3..3dd2584a 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -749,6 +749,14 @@ def calculate_histogram(data, bins, bin_edges=None): facet_ncol = max(1, min(int(facet_ncol), n_groups)) + # Compute global bins so all facets use consistent boundaries. 
+ if pd.api.types.is_numeric_dtype(plot_data[data_column]): + global_bin_edges = np.histogram_bin_edges( + plot_data[data_column], bins=kwargs['bins'] + ) + else: + global_bin_edges = plot_data[data_column].unique() + # Create the FacetGrid for the histogram hist = sns.FacetGrid( plot_data, @@ -771,6 +779,7 @@ def calculate_histogram(data, bins, bin_edges=None): k: v for k, v in kwargs.items() if k not in facet_only_keys } + hist_kwargs['bins'] = global_bin_edges # Map the histogram function to the grid hist.map_dataframe(sns.histplot, x=data_column, **hist_kwargs) From 674b84ed083aa62027e5e6c2bae2603c044a2414 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 24 Mar 2026 23:16:25 -0400 Subject: [PATCH 12/57] fix(histogram): convert numpy array to list for seaborn bins compatibility --- src/spac/visualization.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 3dd2584a..694bf444 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -779,7 +779,11 @@ def calculate_histogram(data, bins, bin_edges=None): k: v for k, v in kwargs.items() if k not in facet_only_keys } - hist_kwargs['bins'] = global_bin_edges + # For numeric data, pass global bin edges to ensure consistent binning across facets. + if pd.api.types.is_numeric_dtype(plot_data[data_column]): + hist_kwargs['bins'] = global_bin_edges.tolist() + else: + hist_kwargs.pop('bins', None) # Map the histogram function to the grid hist.map_dataframe(sns.histplot, x=data_column, **hist_kwargs) From 814998fb7bb6107f86e166b7385e6a04ed35e3fa Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 24 Mar 2026 23:18:42 -0400 Subject: [PATCH 13/57] docs(histogram): add TODO comments for binning logic review - Add TODO comments documenting the need to review binning logic in the histogram function. Notes indicate potential double-binning behavior that may need refactoring to pass global_bin_edges to seaborn directly. 
- Also adds clarifying comment about figure and axes output. --- src/spac/visualization.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 694bf444..a3da49d6 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -673,6 +673,13 @@ def calculate_histogram(data, bins, bin_edges=None): kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") + ''' + TODO: Recheck the binning logic. + I think we may need to pass the global_bin_edges to seaborn. + I think the current implementation is actually doing a 'double-binning', + which may not be desirable. + ''' + sns.histplot(data=hist_data, x='bin_center', weights='count', hue=group_by, ax=ax, **kwargs) # If plotting feature specify which layer @@ -808,6 +815,7 @@ def calculate_histogram(data, bins, bin_edges=None): hspace=0.35, wspace=0.2) + # Pass the figure and axes to the output for further customization fig = hist.figure axs.extend(hist.axes.flat) hist_data = plot_data From 6b69eb514029e09384f3706ec4a97c09f3f02625 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 26 Mar 2026 00:20:13 -0400 Subject: [PATCH 14/57] fix(histogram): close unused internal figure --- src/spac/visualization.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index a3da49d6..fa64caf6 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -560,10 +560,14 @@ def histogram(adata, feature=None, annotation=None, layer=None, else: df[data_column] = np.log1p(df[data_column]) + # If ax is not provided, create a new figure and axes. 
+ # Keep track of whether we created the figure internally + created_internal_fig = False if ax is not None: fig = ax.get_figure() else: fig, ax = plt.subplots() + created_internal_fig = True axs = [] @@ -688,6 +692,11 @@ def calculate_histogram(data, bins, bin_edges=None): axs.append(ax) else: + # Only close figures created in this function. If caller provided + # an external ax, keep its parent figure open. + if created_internal_fig: + plt.close(fig) + if not facet: fig, ax_array = plt.subplots( n_groups, 1, figsize=(5, 5 * n_groups) From 60eb8506bdde1a4362a464eeefc6e5073556133d Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sun, 29 Mar 2026 20:19:50 -0400 Subject: [PATCH 15/57] refactor(histogram): extract bin edge computation into _compute_global_bin_edges - Add new helper function _compute_global_bin_edges() to consolidate consistent bin-edge calculation for numeric and categorical data - Replace duplicate bin-edge logic in together and facet cases - Improves code maintainability and reduces duplication in histogram workflow --- src/spac/visualization.py | 52 ++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index fa64caf6..b69f4eb7 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -403,6 +403,40 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): return fig, ax + +def _compute_global_bin_edges(data_series, bins) -> Union[np.ndarray, pd.Index]: + """ + Compute consistent bin edges across all data for aligned histograms/facets. + + This helper ensures that when creating multiple histograms (e.g., in + together mode or facet plots), all subplots use the same bin boundaries + for proper visual comparison. + + Parameters + ---------- + data_series : pd.Series + The data to compute bin edges for. + bins : int or sequence + Number of bins (for numeric data) or bin specification. 
+ + Returns + ------- + array-like + Bin edges for numeric data (array of boundary values), + or unique categories for categorical data (array of category labels). + + Notes + ----- + For numeric data, uses numpy's histogram_bin_edges to compute consistent + boundaries. For categorical data, returns all unique categories present + in the data. + """ + if pd.api.types.is_numeric_dtype(data_series): + return np.histogram_bin_edges(data_series, bins=bins) + else: + return data_series.unique() + + def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, x_log_scale=False, y_log_scale=False, facet=False, **kwargs): @@ -654,12 +688,9 @@ def calculate_histogram(data, bins, bin_edges=None): if together: # Compute global bin edges based on the entire dataset - if pd.api.types.is_numeric_dtype(plot_data[data_column]): - global_bin_edges = np.histogram_bin_edges( - plot_data[data_column], bins=kwargs['bins'] - ) - else: - global_bin_edges = plot_data[data_column].unique() + global_bin_edges = _compute_global_bin_edges( + plot_data[data_column], kwargs['bins'] + ) hist_data = [] # Compute histograms for each group separately and combine them @@ -766,12 +797,9 @@ def calculate_histogram(data, bins, bin_edges=None): facet_ncol = max(1, min(int(facet_ncol), n_groups)) # Compute global bins so all facets use consistent boundaries. 
- if pd.api.types.is_numeric_dtype(plot_data[data_column]): - global_bin_edges = np.histogram_bin_edges( - plot_data[data_column], bins=kwargs['bins'] - ) - else: - global_bin_edges = plot_data[data_column].unique() + global_bin_edges = _compute_global_bin_edges( + plot_data[data_column], kwargs['bins'] + ) # Create the FacetGrid for the histogram hist = sns.FacetGrid( From 248fb459cbee4bc2591fff0c2c1fb8a892cef637 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sun, 29 Mar 2026 20:41:51 -0400 Subject: [PATCH 16/57] fix(histogram): revert group titles for facet plots without feature specification --- src/spac/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index b69f4eb7..bd4eb378 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -751,7 +751,7 @@ def calculate_histogram(data, bins, bin_edges=None): if feature: ax_i.set_title(f'{groups[i]} with Layer: {layer}') else: - ax_array = ax_array.flatten() + ax_i.set_title(f'{groups[i]}') # Set axis scales if y_log_scale is True if y_log_scale: From 16e3dc035805269856f480ff3cbc7252a4d4cb79 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Wed, 1 Apr 2026 01:00:15 -0400 Subject: [PATCH 17/57] refactor(histogram): modularize facet layout logic - extract helpers for internal layout kwarg parsing, facet geometry derivation, and axis label resolution - simplify histogram flow by removing duplicated label logic and centralizing kwargs sanitization - keep facet API boundary clean by documenting only facet_ncol as user-facing and target_fig_width/target_fig_height as internal hints --- src/spac/visualization.py | 180 ++++++++++++++++++++++---------------- 1 file changed, 103 insertions(+), 77 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index bd4eb378..d6598fbc 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -437,6 +437,83 @@ def _compute_global_bin_edges(data_series, 
bins) -> Union[np.ndarray, pd.Index]: return data_series.unique() +def _parse_histogram_layout_kwargs(kwargs): + """Extract histogram-internal layout hints and strip non-histplot keys.""" + facet_ncol = kwargs.pop('facet_ncol', None) + target_fig_width = kwargs.pop('target_fig_width', None) + target_fig_height = kwargs.pop('target_fig_height', None) + + return facet_ncol, target_fig_width, target_fig_height + + +def _derive_facet_geometry( + n_groups, + facet_ncol, + target_fig_width, + target_fig_height, + vertical_threshold=4, + default_height=3.2, + default_aspect=1.25, + min_panel_width=1.8, + min_panel_height=1.6, + min_aspect=0.6, + max_aspect=2.0, +): + """Derive FacetGrid column count and panel geometry.""" + if facet_ncol is not None: + try: + facet_ncol = int(facet_ncol) + except (TypeError, ValueError): + facet_ncol = None + + if facet_ncol is None: + if n_groups <= vertical_threshold: + facet_ncol = 1 + else: + facet_ncol = int(np.ceil(np.sqrt(n_groups))) + + facet_ncol = max(1, min(int(facet_ncol), n_groups)) + facet_height = default_height + facet_aspect = default_aspect + + if target_fig_width is not None and target_fig_height is not None: + try: + target_fig_width = float(target_fig_width) + target_fig_height = float(target_fig_height) + except (TypeError, ValueError): + target_fig_width = None + target_fig_height = None + + if ( + target_fig_width is not None and + target_fig_height is not None and + target_fig_width > 0 and + target_fig_height > 0 + ): + nrow = int(np.ceil(n_groups / facet_ncol)) + panel_width = max(target_fig_width / facet_ncol, min_panel_width) + panel_height = max(target_fig_height / nrow, min_panel_height) + facet_height = panel_height + facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) + + return facet_ncol, facet_height, facet_aspect + + +def _resolve_histogram_axis_labels(data_column, x_log_scale, y_log_scale, stat): + """Resolve histogram axis labels from scaling and stat settings.""" + 
xlabel = f'log({data_column})' if x_log_scale else data_column + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability' + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + return xlabel, ylabel + + def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, x_log_scale=False, y_log_scale=False, facet=False, **kwargs): @@ -517,16 +594,14 @@ def histogram(adata, feature=None, annotation=None, layer=None, while `bins=[0, 1, 2, 3]` will create bins [0,1), [1,2), [2,3]. If not provided, the binning will be determined automatically. Note, don't pass a numpy array, only python lists or strs/numbers. - When `facet=True`, these optional keys can be passed via `kwargs` - to customize the FacetGrid layout: + When `facet=True`, this optional key can be passed via `kwargs` + to customize FacetGrid layout: - `facet_ncol`: int or None, number of facet columns. If None, the function uses one column for small group counts and switches to a compact grid for many groups. - - `facet_vertical_threshold`: int, max number of groups that should - stay in a vertical single-column layout when `facet_ncol` is None. - Default is 4. - - `facet_height`: float, facet height in inches. Default is 3.2. - - `facet_aspect`: float, facet width/height ratio. Default is 1.25. + Internal-only sizing hints used by template wrappers: + - `target_fig_width`: float, intended final figure width in inches. + - `target_fig_height`: float, intended final figure height in inches. Returns ------- @@ -628,6 +703,12 @@ def cal_bin_num( if 'bins' not in kwargs: kwargs['bins'] = cal_bin_num(num_rows) + # Parse histogram-internal layout kwargs and remove them from kwargs + # so they never leak to seaborn's histplot calls. 
+ facet_ncol, target_fig_width, target_fig_height = ( + _parse_histogram_layout_kwargs(kwargs) + ) + # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): """ @@ -752,49 +833,15 @@ def calculate_histogram(data, bins, bin_edges=None): ax_i.set_title(f'{groups[i]} with Layer: {layer}') else: ax_i.set_title(f'{groups[i]}') - - # Set axis scales if y_log_scale is True - if y_log_scale: - ax_i.set_yscale('log') - - # Adjust x-axis label if x_log_scale is True - if x_log_scale: - xlabel = f'log({data_column})' - else: - xlabel = data_column - ax_i.set_xlabel(xlabel) - - # Adjust y-axis label based on 'stat' parameter - stat = kwargs.get('stat', 'count') - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability' - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - ax_i.set_ylabel(ylabel) axs.append(ax_i) else: # Facet option - # Set default values for facet parameters if not provided in kwargs - facet_ncol = kwargs.get('facet_ncol', None) - facet_vertical_threshold = kwargs.get( - 'facet_vertical_threshold', 4 + facet_ncol, facet_height, facet_aspect = _derive_facet_geometry( + n_groups=n_groups, + facet_ncol=facet_ncol, + target_fig_width=target_fig_width, + target_fig_height=target_fig_height, ) - facet_height = kwargs.get('facet_height', 3.2) - facet_aspect = kwargs.get('facet_aspect', 1.25) - - # Default: vertical layout for a few groups, grid for many. - if facet_ncol is None: - if n_groups <= facet_vertical_threshold: - facet_ncol = 1 - else: - facet_ncol = int(np.ceil(np.sqrt(n_groups))) - - facet_ncol = max(1, min(int(facet_ncol), n_groups)) # Compute global bins so all facets use consistent boundaries. 
global_bin_edges = _compute_global_bin_edges( @@ -812,17 +859,7 @@ def calculate_histogram(data, bins, bin_edges=None): sharey=True ) - # Remove facet-specific keys from kwargs to avoid passing them to histplot - facet_only_keys = { - 'facet_ncol', - 'facet_vertical_threshold', - 'facet_height', - 'facet_aspect', - } - hist_kwargs = { - k: v for k, v in kwargs.items() - if k not in facet_only_keys - } + hist_kwargs = kwargs.copy() # For numeric data, pass global bin edges to ensure consistent binning across facets. if pd.api.types.is_numeric_dtype(plot_data[data_column]): hist_kwargs['bins'] = global_bin_edges.tolist() @@ -877,36 +914,25 @@ def calculate_histogram(data, bins, bin_edges=None): ax.set_title(f'Layer: {layer}') axs.append(ax) + stat = kwargs.get('stat', 'count') + xlabel, ylabel = _resolve_histogram_axis_labels( + data_column=data_column, + x_log_scale=x_log_scale, + y_log_scale=y_log_scale, + stat=stat, + ) + axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] for ax in axes: # Set axis scales if y_log_scale is True if y_log_scale: ax.set_yscale('log') - # Adjust x-axis label if x_log_scale is True - if x_log_scale: - xlabel = f'log({data_column})' - else: - xlabel = data_column if facet: ax.set_xlabel('') - else: - ax.set_xlabel(xlabel) - - # Adjust y-axis label based on 'stat' parameter - stat = kwargs.get('stat', 'count') - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability' - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - if facet: ax.set_ylabel('') else: + ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) # Set a common x and y label for the entire figure if facet is True From a2112c298c49bc069951fa0b7147b8c95d4dc0b7 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Fri, 3 Apr 2026 15:00:17 -0400 Subject: [PATCH 18/57] feat(histogram): wire facet layout from template - add Facet and Facet_Ncol parsing and validation in template - normalize 
facet_ncol handling in histogram layout kwargs - adjust facet_ncol hints accordingly in histogram docstrings --- src/spac/templates/histogram_template.py | 26 +++++++++++++++++++- src/spac/visualization.py | 31 +++++++++++++++++++++--- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 0a3924d4..07e82e47 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -113,6 +113,8 @@ def run_from_json( stat = params.get("Stat", "count") x_rotate = params.get("X_Axis_Label_Rotation", 0) histplot_by = params.get("Plot_By", "Annotation") + facet = params.get("Facet", False) + facet_ncol = params.get("Facet_Ncol", "auto") # Close all existing figures to prevent extra plots plt.close('all') @@ -186,6 +188,24 @@ def run_from_json( f'Received "{x_rotate}".' ) + # Validate facet_ncol, allowing for "auto" or positive integers + facet_ncol = text_to_value( + facet_ncol, + default_none_text="auto", + value_to_convert_to="auto" + ) + if facet_ncol != "auto": + facet_ncol = text_to_value( + facet_ncol, + to_int=True, + param_name="Facet_Ncol" + ) + if facet_ncol <= 0: + raise ValueError( + f'Facet_Ncol must be a positive integer or "auto". ' + f'Received "{facet_ncol}".' 
+ ) + # Initialize the x-variable before the loop if histplot_by == "Annotation": x_var = annotation @@ -202,11 +222,15 @@ def run_from_json( ax=None, x_log_scale=take_X_log, y_log_scale=take_Y_log, + facet=facet, multiple=multiple, shrink=shrink, bins=bins, alpha=alpha, - stat=stat + stat=stat, + facet_ncol=facet_ncol, + target_fig_width=fig_width, + target_fig_height=fig_height, ) fig = result["fig"] diff --git a/src/spac/visualization.py b/src/spac/visualization.py index d6598fbc..7af5f844 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -438,11 +438,35 @@ def _compute_global_bin_edges(data_series, bins) -> Union[np.ndarray, pd.Index]: def _parse_histogram_layout_kwargs(kwargs): - """Extract histogram-internal layout hints and strip non-histplot keys.""" + """Extract histogram-internal layout hints and strip non-histplot keys. + + This parser is intentionally permissive for direct API usage: values that + cannot be interpreted are normalized to None so downstream auto-layout can + take over. + """ facet_ncol = kwargs.pop('facet_ncol', None) target_fig_width = kwargs.pop('target_fig_width', None) target_fig_height = kwargs.pop('target_fig_height', None) + # Normalize only; template-level validation handles strict checks. + if isinstance(facet_ncol, str): + facet_ncol_str = facet_ncol.strip().lower() + if facet_ncol_str in {'auto', 'none', ''}: + facet_ncol = None + else: + try: + facet_ncol = int(facet_ncol) + except ValueError: + facet_ncol = None + + if facet_ncol is not None: + try: + facet_ncol = int(facet_ncol) + except (TypeError, ValueError): + facet_ncol = None + if facet_ncol <= 0: + facet_ncol = None + return facet_ncol, target_fig_width, target_fig_height @@ -596,8 +620,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, Note, don't pass a numpy array, only python lists or strs/numbers. 
When `facet=True`, this optional key can be passed via `kwargs` to customize FacetGrid layout: - - `facet_ncol`: int or None, number of facet columns. - If None, the function uses one column for small group counts and + - `facet_ncol`: positive int or "auto", number of facet columns. + If "auto", the function uses one column for small group counts and switches to a compact grid for many groups. Internal-only sizing hints used by template wrappers: - `target_fig_width`: float, intended final figure width in inches. @@ -836,6 +860,7 @@ def calculate_histogram(data, bins, bin_edges=None): axs.append(ax_i) else: # Facet option + # Derive facet geometry based on group count and layout hints facet_ncol, facet_height, facet_aspect = _derive_facet_geometry( n_groups=n_groups, facet_ncol=facet_ncol, From 3c0050bacd452aec11f322f36785eea37a8d85ff Mon Sep 17 00:00:00 2001 From: Boqiang Date: Fri, 3 Apr 2026 15:03:58 -0400 Subject: [PATCH 19/57] feat(histogram): add element control and style validation - expose Element in histogram template inputs - validate multiple, element, and stat values in template - document shrink and alpha keyword behavior in histogram docstring --- src/spac/templates/histogram_template.py | 29 ++++++++++++++++++++++++ src/spac/visualization.py | 4 ++++ 2 files changed, 33 insertions(+) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 07e82e47..a8243527 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -107,6 +107,7 @@ def run_from_json( take_X_log = params.get("Take_X_Log", False) take_Y_log = params.get("Take_Y_log", False) multiple = params.get("Multiple", "dodge") + element = params.get("Element", "bars") shrink = params.get("Shrink_Number", 1) bins = params.get("Bins", "auto") alpha = params.get("Bin_Transparency", 0.75) @@ -182,6 +183,33 @@ def run_from_json( "Setting bin number calculation to auto." 
) + # Validate enum-like plotting controls after bins validation. + allowed_multiple = {"layer", "dodge", "stack", "fill"} + allowed_element = {"bars", "step", "poly"} + allowed_stat = { + "count", "frequency", "density", "probability", + "proportion", "percent" + } + multiple = str(multiple).strip().lower() + element = str(element).strip().lower() + stat = str(stat).strip().lower() + if multiple not in allowed_multiple: + raise ValueError( + f'Multiple must be one of {sorted(allowed_multiple)}. ' + f'Received "{multiple}".' + ) + if element not in allowed_element: + raise ValueError( + f'Element must be one of {sorted(allowed_element)}. ' + f'Received "{element}".' + ) + if stat not in allowed_stat: + raise ValueError( + f'Stat must be one of {sorted(allowed_stat)}. ' + f'Received "{stat}".' + ) + + # Validate x-axis label rotation if (x_rotate < 0) or (x_rotate > 360): raise ValueError( f'The X label rotation should fall within 0 to 360 degree. ' @@ -224,6 +252,7 @@ def run_from_json( y_log_scale=take_Y_log, facet=facet, multiple=multiple, + element=element, shrink=shrink, bins=bins, alpha=alpha, diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 7af5f844..fafd13a7 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -601,6 +601,10 @@ def histogram(adata, feature=None, annotation=None, layer=None, * "step": Creates a step line plot without bars. * "poly": Creates a polygon where the bottom edge represents the x-axis and the top edge the histogram's bins. + - `shrink`: Scale bar width relative to each bin's width. + Useful for reducing overlap with `multiple="dodge"`. + - `alpha`: Transparency for histogram artists. + 0 is fully transparent and 1 is fully opaque. - `log_scale`: Determines if the data should be plotted on a logarithmic scale. 
- `stat`: Determines the statistical transformation to use on the data From b29ef58f417b2e9c648655255352737403a8f987 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 7 Apr 2026 01:30:18 -0400 Subject: [PATCH 20/57] fix(histogram): tighten facet validation and labels - force `multiple="dodge"` when `together=False` in histogram template - reject invalid `together=True` with `facet=True` combinations - add shared x-label and facet-aware title behavior for grouped facets - support `proportion` and `percent` y-axis labels in histogram plotting --- src/spac/templates/histogram_template.py | 30 ++++++++++++++++++++++-- src/spac/visualization.py | 4 +++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index a8243527..5a77d4a3 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -183,6 +183,14 @@ def run_from_json( "Setting bin number calculation to auto." ) + # Validate multiple parameter based on together + if together is False and multiple: + multiple = "dodge" + logger.warning( + "Multiple should not be used when Together is False. " + "Setting Multiple to 'dodge'." + ) + # Validate enum-like plotting controls after bins validation. allowed_multiple = {"layer", "dodge", "stack", "fill"} allowed_element = {"bars", "step", "poly"} @@ -216,6 +224,12 @@ def run_from_json( f'Received "{x_rotate}".' ) + # Validate that together and facet are not both True + if together and facet: + raise ValueError( + 'Together and Facet cannot both be True. Please set one to False.' 
+ ) + # Validate facet_ncol, allowing for "auto" or positive integers facet_ncol = text_to_value( facet_ncol, @@ -302,8 +316,13 @@ def run_from_json( # Rotate x labels ax.tick_params(axis='x', rotation=x_rotate) + + # Process x-axis label for faceted plots + if facet: + x_label = f"log({x_var})" if take_X_log else x_var + fig.supxlabel(x_label, rotation=x_rotate) - # Set titles based on group_by + # Set titles based on group_by and facet if text_to_value(group_by): if together: for ax in axes: @@ -320,9 +339,16 @@ def run_from_json( "Number of axes does not match number of " "groups. Titles may not correspond correctly." ) + if facet: + fig.suptitle( + f'Histogram of "{x_var}" faceted by "{group_by}"' + ) + ax_title_prefix = f'Group' + else: + ax_title_prefix = f'Histogram of "{x_var}" for group' for ax, grp in zip(axes, unique_groups): ax.set_title( - f'Histogram of "{x_var}" for group: "{grp}"' + f'{ax_title_prefix}: "{grp}"' ) else: for ax in axes: diff --git a/src/spac/visualization.py b/src/spac/visualization.py index fafd13a7..28f6c5b1 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -530,7 +530,9 @@ def _resolve_histogram_axis_labels(data_column, x_log_scale, y_log_scale, stat): 'count': 'Count', 'frequency': 'Frequency', 'density': 'Density', - 'probability': 'Probability' + 'probability': 'Probability', + "proportion": "Proportion", + "percent": "Percent" } ylabel = ylabel_map.get(stat, 'Count') if y_log_scale: From 6991defdb87924c1adc262cc175ced5a753565ae Mon Sep 17 00:00:00 2001 From: Boqiang Date: Mon, 13 Apr 2026 23:38:47 -0400 Subject: [PATCH 21/57] chore(gitignore): ignore workspace directory --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8ec0749c..7f94f5ee 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ test/__pychache__ build/ *.egg-info/ +workspace/ \ No newline at end of file From 85c570a049ff3f9fe29b89d3311a8014f48de900 Mon Sep 17 00:00:00 2001 From: Boqiang Date: 
Tue, 14 Apr 2026 01:23:54 -0400 Subject: [PATCH 22/57] fix(histogram): normalize bins defaults to Rice-rule estimator - Catch None, "auto", "none" in kwargs['bins'] and route to cal_bin_num to prevent slow seaborn double-binning. - Clarify docstring accordingly - Remove resolved TODO about double-binning logic --- src/spac/visualization.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 28f6c5b1..ed2952fd 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -622,7 +622,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, Can be a number (indicating the number of bins) or a list (indicating bin edges). For example, `bins=10` will create 10 bins, while `bins=[0, 1, 2, 3]` will create bins [0,1), [1,2), [2,3]. - If not provided, the binning will be determined automatically. + If not provided, or if passed as `None`/`"auto"`/`"none"`, + the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. When `facet=True`, this optional key can be passed via `kwargs` to customize FacetGrid layout: @@ -728,9 +729,20 @@ def cal_bin_num( num_rows = plot_data.shape[0] - # Check if bins is being passed + # Check if bins is being passed or set to None or "auto" in kwargs. 
# If not, the in house algorithm will compute the number of bins + bins_kwarg = kwargs.get('bins', None) + use_default_bins = False if 'bins' not in kwargs: + use_default_bins = True + elif bins_kwarg is None: + use_default_bins = True + elif isinstance(bins_kwarg, str): + bins_kwarg_norm = bins_kwarg.strip().lower() + if bins_kwarg_norm in {'', 'auto', 'none'}: + use_default_bins = True + + if use_default_bins: kwargs['bins'] = cal_bin_num(num_rows) # Parse histogram-internal layout kwargs and remove them from kwargs @@ -819,13 +831,6 @@ def calculate_histogram(data, bins, bin_edges=None): kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - ''' - TODO: Recheck the binning logic. - I think we may need to pass the global_bin_edges to seaborn. - I think the current implementation is actually doing a 'double-binning', - which may not be desirable. - ''' - sns.histplot(data=hist_data, x='bin_center', weights='count', hue=group_by, ax=ax, **kwargs) # If plotting feature specify which layer From e5e908266850c60c56128a8ebf285dc22efe6051 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 14 Apr 2026 11:39:08 -0400 Subject: [PATCH 23/57] refactor(histogram): simplify default bins fallback logic - remove redundant `'bins' not in kwargs` branch - tidy related inline comments and docstring wording --- src/spac/visualization.py | 38 ++++++++++++-------------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index ed2952fd..4c26876b 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -405,8 +405,8 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): def _compute_global_bin_edges(data_series, bins) -> Union[np.ndarray, pd.Index]: - """ - Compute consistent bin edges across all data for aligned histograms/facets. + """Compute global bin edges for a data series based on its type and the + bins parameter. 
This helper ensures that when creating multiple histograms (e.g., in together mode or facet plots), all subplots use the same bin boundaries @@ -414,22 +414,14 @@ def _compute_global_bin_edges(data_series, bins) -> Union[np.ndarray, pd.Index]: Parameters ---------- - data_series : pd.Series - The data to compute bin edges for. - bins : int or sequence - Number of bins (for numeric data) or bin specification. + data_series (pd.Series): The data to compute bin edges for. + bins (int or sequence): Number of bins (for numeric data) or bin specification. Returns ------- - array-like + array-like: Bin edges for numeric data (array of boundary values), or unique categories for categorical data (array of category labels). - - Notes - ----- - For numeric data, uses numpy's histogram_bin_edges to compute consistent - boundaries. For categorical data, returns all unique categories present - in the data. """ if pd.api.types.is_numeric_dtype(data_series): return np.histogram_bin_edges(data_series, bins=bins) @@ -729,20 +721,14 @@ def cal_bin_num( num_rows = plot_data.shape[0] - # Check if bins is being passed or set to None or "auto" in kwargs. - # If not, the in house algorithm will compute the number of bins + # Check if bins is not being passed or set to None or "auto" in kwargs. 
+ # If so, the in house algorithm will compute the number of bins bins_kwarg = kwargs.get('bins', None) - use_default_bins = False - if 'bins' not in kwargs: - use_default_bins = True - elif bins_kwarg is None: - use_default_bins = True - elif isinstance(bins_kwarg, str): - bins_kwarg_norm = bins_kwarg.strip().lower() - if bins_kwarg_norm in {'', 'auto', 'none'}: - use_default_bins = True - - if use_default_bins: + if isinstance(bins_kwarg, str): + bins_kwarg = bins_kwarg.strip().lower() + if bins_kwarg in {'', 'auto', 'none'}: + bins_kwarg = None + if bins_kwarg is None: kwargs['bins'] = cal_bin_num(num_rows) # Parse histogram-internal layout kwargs and remove them from kwargs From a3228f89c88f47e33a1cf1877ee28700b4711eeb Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 14 Apr 2026 19:21:57 -0400 Subject: [PATCH 24/57] feat(histogram): enhance facet validation and parameter naming - Add validation to ensure group_by is specified when facet=True - Ensure facet and together cannot both be True - Rename target_fig_width/height to facet_fig_width/height for clarity - Update histogram docstring to document facet sizing parameters - Ensure facet_fig_width/height are used in figure sizing --- src/spac/templates/histogram_template.py | 20 ++++++--- src/spac/visualization.py | 57 ++++++++++++++---------- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 5a77d4a3..83b2234b 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -224,11 +224,17 @@ def run_from_json( f'Received "{x_rotate}".' ) - # Validate that together and facet are not both True - if together and facet: - raise ValueError( - 'Together and Facet cannot both be True. Please set one to False.' 
- ) + # Validate facet, group_by, and together parameters for logical consistency + if facet: + if group_by is None: + raise ValueError( + 'Facet is True but Group_by is not specified. ' + 'Please specify Group_by when using Facet.' + ) + if together: + raise ValueError( + 'Together and Facet cannot both be True. Please set one to False.' + ) # Validate facet_ncol, allowing for "auto" or positive integers facet_ncol = text_to_value( @@ -272,8 +278,8 @@ def run_from_json( alpha=alpha, stat=stat, facet_ncol=facet_ncol, - target_fig_width=fig_width, - target_fig_height=fig_height, + facet_fig_width=fig_width, + facet_fig_height=fig_height, ) fig = result["fig"] diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 4c26876b..ed082437 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -437,8 +437,8 @@ def _parse_histogram_layout_kwargs(kwargs): take over. """ facet_ncol = kwargs.pop('facet_ncol', None) - target_fig_width = kwargs.pop('target_fig_width', None) - target_fig_height = kwargs.pop('target_fig_height', None) + facet_fig_width = kwargs.pop('facet_fig_width', None) + facet_fig_height = kwargs.pop('facet_fig_height', None) # Normalize only; template-level validation handles strict checks. 
if isinstance(facet_ncol, str): @@ -459,14 +459,14 @@ def _parse_histogram_layout_kwargs(kwargs): if facet_ncol <= 0: facet_ncol = None - return facet_ncol, target_fig_width, target_fig_height + return facet_ncol, facet_fig_width, facet_fig_height def _derive_facet_geometry( n_groups, facet_ncol, - target_fig_width, - target_fig_height, + facet_fig_width, + facet_fig_height, vertical_threshold=4, default_height=3.2, default_aspect=1.25, @@ -492,23 +492,23 @@ def _derive_facet_geometry( facet_height = default_height facet_aspect = default_aspect - if target_fig_width is not None and target_fig_height is not None: + if facet_fig_width is not None and facet_fig_height is not None: try: - target_fig_width = float(target_fig_width) - target_fig_height = float(target_fig_height) + facet_fig_width = float(facet_fig_width) + facet_fig_height = float(facet_fig_height) except (TypeError, ValueError): - target_fig_width = None - target_fig_height = None + facet_fig_width = None + facet_fig_height = None if ( - target_fig_width is not None and - target_fig_height is not None and - target_fig_width > 0 and - target_fig_height > 0 + facet_fig_width is not None and + facet_fig_height is not None and + facet_fig_width > 0 and + facet_fig_height > 0 ): nrow = int(np.ceil(n_groups / facet_ncol)) - panel_width = max(target_fig_width / facet_ncol, min_panel_width) - panel_height = max(target_fig_height / nrow, min_panel_height) + panel_width = max(facet_fig_width / facet_ncol, min_panel_width) + panel_height = max(facet_fig_height / nrow, min_panel_height) facet_height = panel_height facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) @@ -617,14 +617,13 @@ def histogram(adata, feature=None, annotation=None, layer=None, If not provided, or if passed as `None`/`"auto"`/`"none"`, the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. 
- When `facet=True`, this optional key can be passed via `kwargs` + When `facet=True`, these optional key can be passed via `kwargs` to customize FacetGrid layout: - `facet_ncol`: positive int or "auto", number of facet columns. If "auto", the function uses one column for small group counts and switches to a compact grid for many groups. - Internal-only sizing hints used by template wrappers: - - `target_fig_width`: float, intended final figure width in inches. - - `target_fig_height`: float, intended final figure height in inches. + - `facet_fig_width`: float, intended final figure width in inches. + - `facet_fig_height`: float, intended final figure height in inches. Returns ------- @@ -731,9 +730,17 @@ def cal_bin_num( if bins_kwarg is None: kwargs['bins'] = cal_bin_num(num_rows) + # Input validation for facet + if facet: + if group_by is None: + raise ValueError("group_by must be specified when facet=True.") + if together: + raise ValueError("Cannot use together=True with facet=True," + " choose one.") + # Parse histogram-internal layout kwargs and remove them from kwargs # so they never leak to seaborn's histplot calls. - facet_ncol, target_fig_width, target_fig_height = ( + facet_ncol, facet_fig_width, facet_fig_height = ( _parse_histogram_layout_kwargs(kwargs) ) @@ -861,8 +868,8 @@ def calculate_histogram(data, bins, bin_edges=None): facet_ncol, facet_height, facet_aspect = _derive_facet_geometry( n_groups=n_groups, facet_ncol=facet_ncol, - target_fig_width=target_fig_width, - target_fig_height=target_fig_height, + facet_fig_width=facet_fig_width, + facet_fig_height=facet_fig_height, ) # Compute global bins so all facets use consistent boundaries. 
@@ -878,7 +885,7 @@ def calculate_histogram(data, bins, bin_edges=None): height=facet_height, aspect=facet_aspect, sharex=True, - sharey=True + sharey=True, ) hist_kwargs = kwargs.copy() @@ -913,6 +920,8 @@ def calculate_histogram(data, bins, bin_edges=None): # Pass the figure and axes to the output for further customization fig = hist.figure + fig.set_size_inches(facet_fig_width or fig.get_figwidth(), + facet_fig_height or fig.get_figheight()) axs.extend(hist.axes.flat) hist_data = plot_data From b990b8c2a5af8d03c3963dddaaa872f7e12fa797 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 16 Apr 2026 00:40:55 -0400 Subject: [PATCH 25/57] chore(gitignore): remove workspace entry --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 7f94f5ee..528a7356 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,4 @@ .vscode/launch.json test/__pychache__ build/ -*.egg-info/ -workspace/ \ No newline at end of file +*.egg-info/ \ No newline at end of file From ff67a193f3e98f1fb6a8b0cff2546f4c405e88aa Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 16 Apr 2026 20:05:27 -0400 Subject: [PATCH 26/57] feat(histogram): improve ax parameter handling and validation - Reject external ax for grouped-separate and facet modes - Refactor lazy figure creation (only when needed) - Update docstring to document ax parameter constraints - Add tearDown to test cleanup - Add guardrail validation tests --- src/spac/visualization.py | 26 ++++++----- tests/test_visualization/test_histogram.py | 52 +++++++++++++++++++++- 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index ed082437..12893883 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -568,6 +568,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, ax : matplotlib.axes.Axes, optional An existing Axes object to draw the plot onto, optional. 
+ Not supported for grouped-separate (`group_by` with `together=False`) + or facet layouts (`group_by` with `facet=True`). x_log_scale : bool, default False If True, the data will be transformed using np.log1p before plotting, @@ -691,14 +693,15 @@ def histogram(adata, feature=None, annotation=None, layer=None, else: df[data_column] = np.log1p(df[data_column]) - # If ax is not provided, create a new figure and axes. - # Keep track of whether we created the figure internally - created_internal_fig = False + # If ax is provided, validate input and get figure from it. + # If not, the figure will be created in the plotting branch. if ax is not None: + if group_by and not together: + raise ValueError( + "External ax is only supported for single-axes histogram " + "Please set together=True or remove external ax." + ) fig = ax.get_figure() - else: - fig, ax = plt.subplots() - created_internal_fig = True axs = [] @@ -803,6 +806,9 @@ def calculate_histogram(data, bins, bin_edges=None): " histogram.") if together: + if ax is None: + fig, ax = plt.subplots() + # Compute global bin edges based on the entire dataset global_bin_edges = _compute_global_bin_edges( plot_data[data_column], kwargs['bins'] @@ -832,11 +838,6 @@ def calculate_histogram(data, bins, bin_edges=None): axs.append(ax) else: - # Only close figures created in this function. If caller provided - # an external ax, keep its parent figure open. 
- if created_internal_fig: - plt.close(fig) - if not facet: fig, ax_array = plt.subplots( n_groups, 1, figsize=(5, 5 * n_groups) @@ -926,6 +927,9 @@ def calculate_histogram(data, bins, bin_edges=None): hist_data = plot_data else: + if ax is None: + fig, ax = plt.subplots() + # Precompute histogram data for single plot hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 999b66e5..a0de1655 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -38,6 +38,10 @@ def setUp(self): # Create default layer self.adata.layers['Default'] = X.astype(np.float32) + def tearDown(self): + # Closes all figures to prevent memory issues + plt.close('all') + def test_both_feature_and_annotation(self): err_msg = ("Cannot pass both feature and annotation," " choose one.") @@ -353,6 +357,7 @@ def test_layer(self): ) def test_ax_passed_as_argument(self): + # Supported mode 1: single-axes histogram with external ax. fig, ax = plt.subplots() returned_fig, returned_ax, df = histogram( self.adata, @@ -365,8 +370,53 @@ def test_ax_passed_as_argument(self): # Check that the passed fig is the one that is returned self.assertIs(fig, returned_fig) - # Check that returned_ax is an Axes object + # Supported mode 2: grouped+together histogram with external ax. + fig_grouped, ax_grouped = plt.subplots() + returned_grouped_fig, returned_grouped_ax, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=True, + ax=ax_grouped, + ).values() + + self.assertIs(fig_grouped, returned_grouped_fig) + self.assertIs(ax_grouped, returned_grouped_ax) + + # Check that returned axes are valid Axes objects. 
self.assertIsInstance(returned_ax, mpl.axes.Axes) + self.assertIsInstance(returned_grouped_ax, mpl.axes.Axes) + + def test_external_ax_guardrail_modes(self): + # Reject grouped-separate mode with external ax. + fig_1, ax_1 = plt.subplots() + with self.assertRaisesRegex( + ValueError, + "External ax is only supported for single-axes histogram" + ): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=False, + ax=ax_1, + ) + + # Reject facet mode with external ax. + fig_2, ax_2 = plt.subplots() + with self.assertRaisesRegex( + ValueError, + "External ax is only supported for single-axes histogram" + ): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ax=ax_2, + ) + + # Positive external-ax modes are covered in test_ax_passed_as_argument. def test_default_first_feature(self): with self.assertWarns(UserWarning) as warning: From 0fc8b32d4a9216b1adfda82fcba991cf0d1b3e67 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 16 Apr 2026 20:58:08 -0400 Subject: [PATCH 27/57] test(histogram): update facet label validation assertions Refactor label validation for facet mode: - Check that individual facet axes have empty labels - Verify figure-level supxlabel/supylabel are set correctly --- tests/test_visualization/test_histogram.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index a0de1655..5131b03b 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -492,14 +492,17 @@ def test_facet_plot(self): f"Title '{title}' does not contain" f"any expected group names.") - # Check X and Y labels - self.assertIn('marker1', axis.get_xlabel(), - f"Facet {i} X-axis label" - f" '{axis.get_xlabel()}' is incorrect.") - self.assertIn(axis.get_ylabel(), - ['Count', 'Frequency', 'Density', 'Probability'], - f"Facet {i} Y-axis label" - f" 
'{axis.get_ylabel()}' is not a valid stat.") + # In facet mode, labels are figure-level (supxlabel/supylabel). + self.assertEqual(axis.get_xlabel(), '', + f"Facet {i} x-label should be empty.") + self.assertEqual(axis.get_ylabel(), '', + f"Facet {i} y-label should be empty.") + + # Figure-level labels should be set in facet mode. + self.assertIsNotNone(fig._supxlabel) + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supxlabel.get_text(), 'marker1') + self.assertEqual(fig._supylabel.get_text(), 'Count') if __name__ == '__main__': From ac5490ef4fd502047a79d6e5dc5229ba6cf52d1f Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 16 Apr 2026 22:56:52 -0400 Subject: [PATCH 28/57] test(histogram): refactor and expand facet plot test coverage - Split monolithic facet test into focused test methods - Add smoke test for structure and bar patch presence - Separate titles/label validation into dedicated test - Add density stat label test case - Improve test docstrings and assertions clarity --- tests/test_visualization/test_histogram.py | 52 ++++++++++++++++++---- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 5131b03b..622a08c0 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -463,8 +463,8 @@ def test_default_bins_calculation(self): expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) self.assertEqual(n_bins, expected_bins) - def test_facet_plot(self): - """Test that facet plot works.""" + def test_facet_plot_smoke_and_structure(self): + """Facet path returns expected structure and plotted content.""" fig, ax, df = histogram( self.adata, feature='marker1', @@ -472,27 +472,47 @@ def test_facet_plot(self): facet=True, ).values() - # Check if axs is a collection (list/array of Axes) + # Basic structure checks + self.assertIsNotNone(fig) + self.assertIsNotNone(df) self.assertIsInstance(ax, 
(list, np.ndarray), - "Output is not a multi-axis grid") + "Facet output should be a multi-axis collection.") - # Check number of facets equals number of unique groups + # Check the number of facet axes matches group count unique_groups = self.adata.obs['annotation2'].dropna().unique() self.assertEqual(len(ax), len(unique_groups), f"Expected {len(unique_groups)}" f" facet plots, got {len(ax)}.") - # Validate each axis: title, xlabel, and ylabel + # Lightweight bar-level presence checks only. + for i, axis in enumerate(ax): + self.assertGreater( + len(axis.patches), + 0, + f"Facet {i} should contain at least one bar patch." + ) + + def test_facet_plot_titles_and_label_policy(self): + """Facet titles map to groups and labels follow figure-level policy.""" + fig, ax, df = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ).values() + + # Ensure ax is iterable for consistent handling + ax = ax if isinstance(ax, (list, np.ndarray)) else [ax] + unique_groups = self.adata.obs['annotation2'].dropna().unique() + + # Titles must map to expected groups and labels must be per-figure. for i, axis in enumerate(ax): - # Check that title is set and matches the group title = axis.get_title() self.assertTrue(title, f"Facet {i} is missing a title.") self.assertTrue(any(str(group) in title for group in unique_groups), f"Title '{title}' does not contain" f"any expected group names.") - - # In facet mode, labels are figure-level (supxlabel/supylabel). 
self.assertEqual(axis.get_xlabel(), '', f"Facet {i} x-label should be empty.") self.assertEqual(axis.get_ylabel(), '', @@ -504,6 +524,20 @@ def test_facet_plot(self): self.assertEqual(fig._supxlabel.get_text(), 'marker1') self.assertEqual(fig._supylabel.get_text(), 'Count') + def test_facet_plot_density_stat_label_policy(self): + """Facet figure-level y label reflects non-default stat mapping.""" + fig, ax, df = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + stat='density', + ).values() + + # Check that figure-level y label reflects 'density' stat when specified. + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supylabel.get_text(), 'Density') + if __name__ == '__main__': unittest.main() From eba316b559084916624a85bc611a36db62488720 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Fri, 17 Apr 2026 14:50:37 -0400 Subject: [PATCH 29/57] fix(histogram): preserve facet xlabel and apply rotation Refactor facet x-label rotation handling: - Apply rotation to existing label object - Add warning if figure-level label not found - Prevents label override --- src/spac/templates/histogram_template.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 83b2234b..b7c544ec 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -323,10 +323,16 @@ def run_from_json( # Rotate x labels ax.tick_params(axis='x', rotation=x_rotate) - # Process x-axis label for faceted plots + # Process figure-level xlabel for faceted plots if facet: - x_label = f"log({x_var})" if take_X_log else x_var - fig.supxlabel(x_label, rotation=x_rotate) + facet_xlabel = getattr(fig, '_supxlabel', None) + if facet_xlabel is None: + logger.warning( + "Facet xlabel not found. X label rotation will " + "not be applied." 
+ ) + else: + facet_xlabel.set_rotation(x_rotate) # Set titles based on group_by and facet if text_to_value(group_by): From 1e738a2deb5debd7068de394ab920f0c5674cb6a Mon Sep 17 00:00:00 2001 From: Boqiang Date: Fri, 17 Apr 2026 14:51:33 -0400 Subject: [PATCH 30/57] test(histogram): add facet mode parameters to template test --- tests/templates/test_histogram_template.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/templates/test_histogram_template.py b/tests/templates/test_histogram_template.py index 5a8e49e8..426e386f 100644 --- a/tests/templates/test_histogram_template.py +++ b/tests/templates/test_histogram_template.py @@ -50,6 +50,8 @@ def setUp(self) -> None: params = { "Upstream_Analysis": self.in_file, "Annotation": "cell_type", + "Group_by": "cell_type", + "Together": False, "Table_to_Visualize": "Original", "Feature_s_to_Plot": ["All"], "Figure_Title": "Test Histogram", @@ -59,6 +61,8 @@ def setUp(self) -> None: "Figure_DPI": 72, "Font_Size": 10, "Number_of_Bins": 20, + "Facet": True, + "Facet_Ncol": 1, "Output_Directory": self.tmp_dir.name, "outputs": { "dataframe": {"type": "file", "name": "dataframe.csv"}, From 685d4f6780798ee67557c9d86143563a5370bfa6 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Fri, 17 Apr 2026 21:07:09 -0400 Subject: [PATCH 31/57] test(histogram): add facet validation and categorical tests - Add test for facet mode requiring group_by parameter - Add test for facet mode conflicting with together=True - Add test for facet mode with categorical annotations --- tests/test_visualization/test_histogram.py | 50 ++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 622a08c0..2b419d22 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -463,6 +463,32 @@ def test_default_bins_calculation(self): expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) 
self.assertEqual(n_bins, expected_bins) + def test_facet_requires_group_by(self): + """Test that facet mode requires group_by parameter""" + with self.assertRaisesRegex( + ValueError, + "group_by must be specified when facet=True." + ): + histogram( + self.adata, + feature='marker1', + facet=True, + ) + + def test_facet_conflicts_with_together_true(self): + """Test that facet mode conflicts with together=True""" + with self.assertRaisesRegex( + ValueError, + "Cannot use together=True with facet=True" + ): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=True, + facet=True, + ) + def test_facet_plot_smoke_and_structure(self): """Facet path returns expected structure and plotted content.""" fig, ax, df = histogram( @@ -538,6 +564,30 @@ def test_facet_plot_density_stat_label_policy(self): self.assertIsNotNone(fig._supylabel) self.assertEqual(fig._supylabel.get_text(), 'Density') + def test_facet_plot_categorical_annotation(self): + """Test facet mode with categorical annotations""" + fig, axs, df = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + ).values() + + # Ensure axs is iterable for consistent handling + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + expected_groups = self.adata.obs['annotation2'].dropna().nunique() + self.assertEqual(len(axs), expected_groups) + + # Check that data is plotted in each facet + for axis in axs: + self.assertGreater(len(axis.patches), 0) + + # Check figure-level labels are set appropriately + self.assertIsNotNone(fig._supxlabel) + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supxlabel.get_text(), 'annotation1') + self.assertEqual(fig._supylabel.get_text(), 'Count') + if __name__ == '__main__': unittest.main() From dd4bd72ab7600a2e7b5beb98562ec940b182ff79 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Fri, 17 Apr 2026 22:00:17 -0400 Subject: [PATCH 32/57] refactor(histogram): relocate histogram helpers into function scope 
- Move _compute_global_bin_edges into histogram as local function - Move _resolve_histogram_axis_labels into histogram as local function - Rename _parse_histogram_layout_kwargs to _parse_facet_layout_hints --- src/spac/visualization.py | 106 +++++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 47 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 12893883..254fa04f 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -404,32 +404,7 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): return fig, ax -def _compute_global_bin_edges(data_series, bins) -> Union[np.ndarray, pd.Index]: - """Compute global bin edges for a data series based on its type and the - bins parameter. - - This helper ensures that when creating multiple histograms (e.g., in - together mode or facet plots), all subplots use the same bin boundaries - for proper visual comparison. - - Parameters - ---------- - data_series (pd.Series): The data to compute bin edges for. - bins (int or sequence): Number of bins (for numeric data) or bin specification. - - Returns - ------- - array-like: - Bin edges for numeric data (array of boundary values), - or unique categories for categorical data (array of category labels). - """ - if pd.api.types.is_numeric_dtype(data_series): - return np.histogram_bin_edges(data_series, bins=bins) - else: - return data_series.unique() - - -def _parse_histogram_layout_kwargs(kwargs): +def _parse_facet_layout_hints(kwargs): """Extract histogram-internal layout hints and strip non-histplot keys. 
This parser is intentionally permissive for direct API usage: values that @@ -515,23 +490,6 @@ def _derive_facet_geometry( return facet_ncol, facet_height, facet_aspect -def _resolve_histogram_axis_labels(data_column, x_log_scale, y_log_scale, stat): - """Resolve histogram axis labels from scaling and stat settings.""" - xlabel = f'log({data_column})' if x_log_scale else data_column - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability', - "proportion": "Proportion", - "percent": "Percent" - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - return xlabel, ylabel - - def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, x_log_scale=False, y_log_scale=False, facet=False, **kwargs): @@ -744,7 +702,7 @@ def cal_bin_num( # Parse histogram-internal layout kwargs and remove them from kwargs # so they never leak to seaborn's histplot calls. facet_ncol, facet_fig_width, facet_fig_height = ( - _parse_histogram_layout_kwargs(kwargs) + _parse_facet_layout_hints(kwargs) ) # Function to calculate histogram data @@ -796,6 +754,60 @@ def calculate_histogram(data, bins, bin_edges=None): 'count': counts.values }) + # Function to compute shared bin edges for grouped histograms + def compute_global_bin_edges(data_series, bins): + """Compute shared bin boundaries for grouped histogram paths. + + Parameters + ---------- + data_series : pandas.Series + Data column used to derive shared bins. + bins : int or sequence + Bin specification forwarded to numpy/seaborn logic. + + Returns + ------- + numpy.ndarray or pandas.Index + Numeric bin edges, or categorical labels for non-numeric data. 
+ """ + if pd.api.types.is_numeric_dtype(data_series): + return np.histogram_bin_edges(data_series, bins=bins) + return data_series.unique() + + # Function to get axis labels based on log scale and stat parameters + def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): + """Resolve x/y axis labels for histogram rendering. + + Parameters + ---------- + data_column : str + Source column used on the x axis. + x_log_scale : bool + Whether x data has log transform semantics. + y_log_scale : bool + Whether y axis is displayed on log scale. + stat : str + Histogram statistic mode (for example, count, density). + + Returns + ------- + tuple[str, str] + Resolved x-axis and y-axis labels. + """ + xlabel = f'log({data_column})' if x_log_scale else data_column + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability', + "proportion": "Proportion", + "percent": "Percent" + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + return xlabel, ylabel + # Plotting with or without grouping if group_by: groups = df[group_by].dropna().unique().tolist() @@ -810,7 +822,7 @@ def calculate_histogram(data, bins, bin_edges=None): fig, ax = plt.subplots() # Compute global bin edges based on the entire dataset - global_bin_edges = _compute_global_bin_edges( + global_bin_edges = compute_global_bin_edges( plot_data[data_column], kwargs['bins'] ) @@ -874,7 +886,7 @@ def calculate_histogram(data, bins, bin_edges=None): ) # Compute global bins so all facets use consistent boundaries. 
- global_bin_edges = _compute_global_bin_edges( + global_bin_edges = compute_global_bin_edges( plot_data[data_column], kwargs['bins'] ) @@ -950,7 +962,7 @@ def calculate_histogram(data, bins, bin_edges=None): axs.append(ax) stat = kwargs.get('stat', 'count') - xlabel, ylabel = _resolve_histogram_axis_labels( + xlabel, ylabel = resolve_hist_axis_labels( data_column=data_column, x_log_scale=x_log_scale, y_log_scale=y_log_scale, From 827d654c20e0cfae07b8f73312a04e8248ac27c5 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sat, 18 Apr 2026 15:13:11 -0400 Subject: [PATCH 33/57] docs(histogram): clarify facet figure size layout hints in template - Adds docstring comment about figure size behavior in facet mode --- src/spac/templates/histogram_template.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index b7c544ec..69f18d49 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -260,6 +260,9 @@ def run_from_json( else: x_var = feature + # In facet mode, Figure_Width/Height are passed as layout hints so + # visualization can derive panel geometry from total figure size: + # panel_width = Figure_Width / ncol, panel_height = Figure_Height / nrow. 
result = histogram( adata=adata, feature=feature, From 02dd9ba5cf2512c4ce9330809c7ac8edfd435277 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sat, 18 Apr 2026 15:15:53 -0400 Subject: [PATCH 34/57] refactor(histogram): improve code organization of facet geometry determination --- src/spac/visualization.py | 96 +++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 254fa04f..6724f9a6 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -404,16 +404,39 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): return fig, ax -def _parse_facet_layout_hints(kwargs): - """Extract histogram-internal layout hints and strip non-histplot keys. +def _derive_facet_geometry( + n_groups, + facet_ncol, + facet_fig_width, + facet_fig_height, + vertical_threshold=4, + default_height=3.2, + default_aspect=1.25, + min_panel_width=1.8, + min_panel_height=1.6, + min_aspect=0.6, + max_aspect=2.0, +): + """Normalize facet hints and derive FacetGrid layout geometry. - This parser is intentionally permissive for direct API usage: values that - cannot be interpreted are normalized to None so downstream auto-layout can - take over. + Returns + ------- + tuple + (facet_ncol, facet_height, facet_aspect, + facet_fig_width, facet_fig_height) + where facet_fig_width/height are normalized positive floats or None. 
""" - facet_ncol = kwargs.pop('facet_ncol', None) - facet_fig_width = kwargs.pop('facet_fig_width', None) - facet_fig_height = kwargs.pop('facet_fig_height', None) + def _normalize_positive_float(value): + """Normalize numeric layout hints to positive floats or None.""" + if isinstance(value, str): + value_str = value.strip().lower() + if value_str in {'auto', 'none', ''}: + return None + try: + value = float(value) + except (TypeError, ValueError): + return None + return value if value > 0 else None # Normalize only; template-level validation handles strict checks. if isinstance(facet_ncol, str): @@ -434,28 +457,8 @@ def _parse_facet_layout_hints(kwargs): if facet_ncol <= 0: facet_ncol = None - return facet_ncol, facet_fig_width, facet_fig_height - - -def _derive_facet_geometry( - n_groups, - facet_ncol, - facet_fig_width, - facet_fig_height, - vertical_threshold=4, - default_height=3.2, - default_aspect=1.25, - min_panel_width=1.8, - min_panel_height=1.6, - min_aspect=0.6, - max_aspect=2.0, -): - """Derive FacetGrid column count and panel geometry.""" - if facet_ncol is not None: - try: - facet_ncol = int(facet_ncol) - except (TypeError, ValueError): - facet_ncol = None + facet_fig_width = _normalize_positive_float(facet_fig_width) + facet_fig_height = _normalize_positive_float(facet_fig_height) if facet_ncol is None: if n_groups <= vertical_threshold: @@ -467,14 +470,6 @@ def _derive_facet_geometry( facet_height = default_height facet_aspect = default_aspect - if facet_fig_width is not None and facet_fig_height is not None: - try: - facet_fig_width = float(facet_fig_width) - facet_fig_height = float(facet_fig_height) - except (TypeError, ValueError): - facet_fig_width = None - facet_fig_height = None - if ( facet_fig_width is not None and facet_fig_height is not None and @@ -487,7 +482,13 @@ def _derive_facet_geometry( facet_height = panel_height facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) - return facet_ncol, facet_height, 
facet_aspect + return ( + facet_ncol, + facet_height, + facet_aspect, + facet_fig_width, + facet_fig_height, + ) def histogram(adata, feature=None, annotation=None, layer=None, @@ -699,11 +700,10 @@ def cal_bin_num( raise ValueError("Cannot use together=True with facet=True," " choose one.") - # Parse histogram-internal layout kwargs and remove them from kwargs - # so they never leak to seaborn's histplot calls. - facet_ncol, facet_fig_width, facet_fig_height = ( - _parse_facet_layout_hints(kwargs) - ) + # Remove histogram-internal layout hints so they never leak to seaborn. + facet_ncol = kwargs.pop('facet_ncol', None) + facet_fig_width = kwargs.pop('facet_fig_width', None) + facet_fig_height = kwargs.pop('facet_fig_height', None) # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): @@ -878,7 +878,13 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): else: # Facet option # Derive facet geometry based on group count and layout hints - facet_ncol, facet_height, facet_aspect = _derive_facet_geometry( + ( + facet_ncol, + facet_height, + facet_aspect, + facet_fig_width, + facet_fig_height, + ) = _derive_facet_geometry( n_groups=n_groups, facet_ncol=facet_ncol, facet_fig_width=facet_fig_width, From b8e3358a82962eec1d55d154062730762734176d Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sat, 18 Apr 2026 15:18:00 -0400 Subject: [PATCH 35/57] test(histogram): expand facet coverage with comprehensive test suite - Expand validation of bins calculation to default-like input - Add validation of facet layout hints: facet_ncol, facet_fig_width/height - Add validation of bins consistency across facets in both numerical and categorical cases --- tests/test_visualization/test_histogram.py | 230 ++++++++++++++++++++- 1 file changed, 223 insertions(+), 7 deletions(-) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 2b419d22..3ec36f69 100644 --- 
a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -451,17 +451,26 @@ def test_histogram_feature_integer_bins(self): self.assertIsInstance(ax, mpl.axes.Axes) def test_default_bins_calculation(self): + """No bins argument should use Rice-rule fallback.""" + expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) + # No bins parameter passed fig, ax, df = histogram(self.adata, feature='marker1').values() + self.assertEqual(len(ax.patches), expected_bins) - # Count the number of bins - bars = ax.patches - n_bins = len(bars) - - # Validate the number of bins based on default bin calculation logic - # Using 2 * (n ** 1/3) heuristic for default bins + def test_default_like_bins_calculation(self): + """Default-like bins values should use Rice-rule fallback.""" expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) - self.assertEqual(n_bins, expected_bins) + + for bins_value in [None, 'auto', 'none', '']: + with self.subTest(bins=bins_value): + fig, ax, df = histogram( + self.adata, + feature='marker1', + bins=bins_value, + ).values() + + self.assertEqual(len(ax.patches), expected_bins) def test_facet_requires_group_by(self): """Test that facet mode requires group_by parameter""" @@ -588,6 +597,213 @@ def test_facet_plot_categorical_annotation(self): self.assertEqual(fig._supxlabel.get_text(), 'annotation1') self.assertEqual(fig._supylabel.get_text(), 'Count') + def test_facet_ncol_layout_hints(self): + """Facet ncol supports positive int and documented auto behavior.""" + X = np.arange(1, 10, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + { + 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g3', 'g3', 'g3'], + }, + index=[f'cell_{i}' for i in range(9)], + ) + var = pd.DataFrame(index=['marker1']) + adata = anndata.AnnData(X, obs=obs, var=var) + + # Explicit two-column layout should create two facet columns. 
+ fig, axs, _ = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol=2, + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + x_positions = {round(axis.get_position().x0, 4) for axis in axs} + self.assertGreaterEqual(len(x_positions), 2) + + # Documented default-like input should use auto layout (one column for 3 groups). + fig, axs, _ = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol='auto', + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + x_positions = {round(axis.get_position().x0, 4) for axis in axs} + self.assertEqual(len(x_positions), 1) + + # Keep one lightweight guardrail check for invalid fallback behavior. + fig, axs, _ = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol='bad', + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + x_positions = {round(axis.get_position().x0, 4) for axis in axs} + self.assertEqual(len(x_positions), 1) + + def test_facet_figure_size_hints(self): + """Facet figure-size hints should accept valid values and sanitize invalid ones.""" + X = np.arange(1, 10, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + { + 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g3', 'g3', 'g3'], + }, + index=[f'cell_{i}' for i in range(9)], + ) + var = pd.DataFrame(index=['marker1']) + adata = anndata.AnnData(X, obs=obs, var=var) + + # Check that valid figure size hints are applied to the facet figure. 
+ fig, _, _ = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=11, + facet_fig_height=3.5, + ).values() + self.assertAlmostEqual(fig.get_figwidth(), 11.0, places=2) + self.assertAlmostEqual(fig.get_figheight(), 3.5, places=2) + + # Check that invalid size hints are sanitized + for width, height in [('wide', 'tall'), (-1, 0)]: + with self.subTest(facet_fig_width=width, facet_fig_height=height): + fig, _, _ = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=width, + facet_fig_height=height, + ).values() + self.assertGreater(fig.get_figwidth(), 0) + self.assertGreater(fig.get_figheight(), 0) + + def test_facet_plot_shared_bins_consistency_numeric(self): + """Numeric facets keep shared bins for int/default-like bins inputs.""" + # Unbalanced groups: each group occupies only part of the global range. + # If bins are computed per-group (bad path), centers/ticks may diverge. + X = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]], + dtype=np.float32) + obs = pd.DataFrame( + { + 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2'], + }, + index=[f'cell_{i}' for i in range(6)], + ) + var = pd.DataFrame(index=['marker1']) + adata = anndata.AnnData(X, obs=obs, var=var) + + # Test one explicit and one default-like bins path. 
+ for bins_value in [4, None]: + with self.subTest(bins=bins_value): + fig, axs, df = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + bins=bins_value, + ).values() + + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + + # Capture the bin centers and ticks from the first facet + first_xlim = np.round(np.array(axs[0].get_xlim()), 6) + first_xticks = np.round(np.array(axs[0].get_xticks()), 6) + first_yticks = np.round(np.array(axs[0].get_yticks()), 6) + first_centers = np.round( + np.array([ + patch.get_x() + patch.get_width() / 2 + for patch in axs[0].patches + ]), + 6 + ) + + # Check that all facets have the same bin centers and ticks + for axis in axs[1:]: + centers = np.round( + np.array([ + patch.get_x() + patch.get_width() / 2 + for patch in axis.patches + ]), + 6 + ) + self.assertTrue( + np.array_equal(centers, first_centers), + "Facet numeric bin centers should remain shared across panels." + ) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_xlim()), 6), first_xlim), + "Facet numeric x-limits should remain shared across panels." + ) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_xticks()), 6), first_xticks), + "Facet numeric x-ticks should remain shared across panels." + ) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), + "Facet numeric y-ticks should remain shared across panels." + ) + + def test_facet_plot_shared_bins_consistency_categorical(self): + """Facet categorical bins stay aligned even with missing labels.""" + # Build unbalanced facet groups where some labels are missing per group. 
+ X = np.arange(1, 10, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + { + 'annotation1': pd.Categorical( + ['A', 'A', 'B', 'A', 'C', 'C', 'B', 'C', 'A'], + categories=['A', 'B', 'C'], + ), + 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g3', 'g3', 'g3'], + }, + index=[f'cell_{i}' for i in range(9)], + ) + var = pd.DataFrame(index=['marker1']) + adata = anndata.AnnData(X, obs=obs, var=var) + + fig, axs, df = histogram( + adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + ).values() + + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + + # Guardrail: this fixture must include missing labels per group. + group_uniques = adata.obs.groupby('annotation2')['annotation1'].nunique() + self.assertTrue(any(group_uniques < 3)) + + # Check that bin centers are shared across facets + global_centers = set() + for axis in axs: + global_centers.update( + np.round( + [patch.get_x() + patch.get_width() / 2 + for patch in axis.patches], + 6, + ) + ) + self.assertEqual( + len(global_centers), + 3, + "Expected 3 categorical slots (A/B/C) to be preserved globally." + ) + + # Check that ticks are shared across facets + first_xticks = np.round(axs[0].get_xticks(), 6) + first_yticks = np.round(np.array(axs[0].get_yticks()), 6) + for axis in axs[1:]: + self.assertTrue(np.array_equal(np.round(axis.get_xticks(), 6), first_xticks)) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), + "Facet categorical y-ticks should remain shared across panels." + ) if __name__ == '__main__': unittest.main() From 083315b3a5fb1be653834b41fc74bbd08e562681 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sat, 18 Apr 2026 23:28:01 -0400 Subject: [PATCH 36/57] feat(utils): add normalize_positive_number helper function - New reusable utility for normalizing numeric parameters with default-like value handling. - Supports float/int conversion with guardrails for positive values. 
--- src/spac/utils.py | 111 +++++++++++++++--- .../test_normalize_positive_number.py | 59 ++++++++++ 2 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 tests/test_utils/test_normalize_positive_number.py diff --git a/src/spac/utils.py b/src/spac/utils.py index 2a616b68..87b839dc 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -5,14 +5,15 @@ import matplotlib.cm as cm import pandas as pd import logging -import warnings -import numbers -from scipy.stats import median_abs_deviation -from typing import List, Optional - -# Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s') +import warnings +import numbers +from scipy.stats import median_abs_deviation +from typing import Any, List, Optional + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) def regex_search_list( @@ -417,9 +418,9 @@ def check_distances(distances): f"Got {distances}") -def text_to_others( - parameter, - text="None", +def text_to_others( + parameter, + text="None", to_None=True, to_False=False, to_True=False, @@ -467,8 +468,90 @@ def check_same( if to_Float: parameter = float(parameter) - - return parameter + + return parameter + + +def normalize_positive_number( + parameter: Any, + var_name: str = "parameter", + convert_to: str = "float", + default_like_values=("auto", "none", ""), +): + """Normalize a value to a positive float/int or None. + + Parameters + ---------- + parameter : any + Value to normalize. Strings are matched against default-like tokens + before numeric conversion. + var_name : str, optional + Name used in log messages. + convert_to : str, optional + Target numeric type. Supported values are ``"float"`` and ``"int"``. + default_like_values : tuple of str, optional + Lowercased text tokens that should be treated as missing values. 
+ + Returns + ------- + float or int or None + Positive converted value, or ``None`` when the input is default-like, + invalid, or non-positive. + """ + if parameter is None or isinstance(parameter, bool): + logger.info( + "%s=%r is treated as missing input. Falling back to automatic " + "behavior.", + var_name, + parameter, + ) + return None + + if isinstance(parameter, str): + parameter_str = parameter.strip().lower() + if parameter_str in default_like_values: + logger.info( + "%s=%r is treated as default-like input. Falling back to " + "automatic behavior.", + var_name, + parameter, + ) + return None + + try: + if convert_to == "float": + parameter = float(parameter) + elif convert_to == "int": + parameter = int(parameter) + else: + logger.warning( + "%s uses unsupported conversion '%s'. Falling back to " + "automatic behavior.", + var_name, + convert_to, + ) + return None + except (TypeError, ValueError): + logger.warning( + "Could not convert %s=%r to %s. Falling back to automatic " + "behavior.", + var_name, + parameter, + convert_to, + ) + return None + + if parameter <= 0: + logger.warning( + "%s=%r is not a positive %s. 
Falling back to automatic " + "behavior.", + var_name, + parameter, + convert_to, + ) + return None + + return parameter def annotation_category_relations( @@ -1273,4 +1356,4 @@ def compute_summary_qc_stats( "upper_mad", "lower_mad", "upper_quantile", "lower_quantile" ] - ) \ No newline at end of file + ) diff --git a/tests/test_utils/test_normalize_positive_number.py b/tests/test_utils/test_normalize_positive_number.py new file mode 100644 index 00000000..2b3ac0df --- /dev/null +++ b/tests/test_utils/test_normalize_positive_number.py @@ -0,0 +1,59 @@ +import unittest + +from spac.utils import normalize_positive_number + + +class TestNormalizePositiveNumber(unittest.TestCase): + def test_float_conversion(self): + self.assertEqual( + normalize_positive_number("11.5", convert_to="float"), + 11.5, + ) + + def test_int_conversion(self): + self.assertEqual( + normalize_positive_number("3", convert_to="int"), + 3, + ) + + def test_default_like_values_return_none(self): + self.assertIsNone(normalize_positive_number("auto", convert_to="int")) + self.assertIsNone(normalize_positive_number("None", convert_to="float")) + self.assertIsNone(normalize_positive_number(None, convert_to="float")) + + def test_invalid_or_non_positive_values_return_none(self): + self.assertIsNone(normalize_positive_number("bad", convert_to="int")) + self.assertIsNone(normalize_positive_number("-1", convert_to="float")) + self.assertIsNone(normalize_positive_number(0, convert_to="float")) + + def test_sanitized_inputs_are_logged(self): + with self.assertLogs("spac.utils", level="INFO") as logs: + self.assertIsNone( + normalize_positive_number( + "auto", + var_name="facet_ncol", + convert_to="int", + ) + ) + + self.assertTrue( + any("facet_ncol='auto'" in message for message in logs.output) + ) + + def test_invalid_inputs_are_logged_as_warning(self): + with self.assertLogs("spac.utils", level="WARNING") as logs: + self.assertIsNone( + normalize_positive_number( + "bad", + var_name="facet_fig_width", + 
convert_to="float", + ) + ) + + self.assertTrue( + any("facet_fig_width='bad'" in message for message in logs.output) + ) + + +if __name__ == "__main__": + unittest.main() From 6914244845f1138e64d2d27a5b5ea15d1fb9c2ef Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sat, 18 Apr 2026 23:30:45 -0400 Subject: [PATCH 37/57] refactor(facet): refactor facet geometry using normalize_positive_number - Simplify _derive_facet_geometry by delegating numeric parameter normalization to the new helper. - Improves code clarity and reduces duplication. - Adds comprehensive test coverage for geometry derivation. --- src/spac/visualization.py | 154 +++++++++++------- .../test_derive_facet_geometry.py | 99 +++++++++++ 2 files changed, 190 insertions(+), 63 deletions(-) create mode 100644 tests/test_visualization/test_derive_facet_geometry.py diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 6724f9a6..4c7ba6db 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -17,6 +17,7 @@ from spac.utils import check_label from spac.utils import get_defined_color_map from spac.utils import compute_boxplot_metrics +from spac.utils import normalize_positive_number from functools import partial from spac.utils import color_mapping, spell_out_special_characters from spac.data_utils import select_values @@ -406,10 +407,10 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): def _derive_facet_geometry( n_groups, - facet_ncol, - facet_fig_width, - facet_fig_height, - vertical_threshold=4, + facet_ncol="auto", + facet_fig_width=None, + facet_fig_height=None, + vertical_threshold=3, default_height=3.2, default_aspect=1.25, min_panel_width=1.8, @@ -417,78 +418,106 @@ def _derive_facet_geometry( min_aspect=0.6, max_aspect=2.0, ): - """Normalize facet hints and derive FacetGrid layout geometry. + """Normalize facet layout hints and derive FacetGrid geometry. + + Parameters + ---------- + n_groups : int + Number of facet panels. 
Expected to be a positive integer supplied by + the grouped histogram path. + facet_ncol : int or "auto", optional + Requested facet column count. Positive integers are used directly. + Default-like string inputs (`"auto"`, `"none"`, `""`) and invalid + values fall back to automatic column selection. + facet_fig_width, facet_fig_height : float, optional + Optional total figure-size hints. Positive numeric values are kept. + Default-like or invalid values are sanitized to ``None``. Geometry is + derived from these hints only when both normalized values are present. + vertical_threshold : int, optional + Maximum group count that still prefers a single-column automatic + layout. + default_height : float, optional + Per-facet panel height used when explicit figure-size hints are not + available. + default_aspect : float, optional + Per-facet panel aspect ratio used when explicit figure-size hints are + not available. + min_panel_width, min_panel_height : float, optional + Lower bounds applied to per-panel dimensions when figure-size hints are + converted into FacetGrid geometry. + min_aspect, max_aspect : float, optional + Bounds applied to the derived panel aspect ratio. Returns ------- - tuple - (facet_ncol, facet_height, facet_aspect, - facet_fig_width, facet_fig_height) - where facet_fig_width/height are normalized positive floats or None. + dict + Dictionary with keys: + - ``facet_ncol``: normalized column count clamped to ``n_groups``; + - ``facet_height`` / ``facet_aspect``: FacetGrid-ready per-panel + geometry values; + - ``facet_fig_width`` / ``facet_fig_height``: normalized positive + floats or ``None``. + + Automatic layout uses one column when ``n_groups <= vertical_threshold`` + and otherwise uses ``ceil(sqrt(n_groups))`` columns. 
When both + normalized figure-size hints are present, the helper converts total + figure size into per-panel geometry using the derived grid shape, + applies the minimum panel-size guardrails, and clips aspect into the + configured range. """ - def _normalize_positive_float(value): - """Normalize numeric layout hints to positive floats or None.""" - if isinstance(value, str): - value_str = value.strip().lower() - if value_str in {'auto', 'none', ''}: - return None - try: - value = float(value) - except (TypeError, ValueError): - return None - return value if value > 0 else None - - # Normalize only; template-level validation handles strict checks. - if isinstance(facet_ncol, str): - facet_ncol_str = facet_ncol.strip().lower() - if facet_ncol_str in {'auto', 'none', ''}: - facet_ncol = None - else: - try: - facet_ncol = int(facet_ncol) - except ValueError: - facet_ncol = None - - if facet_ncol is not None: - try: - facet_ncol = int(facet_ncol) - except (TypeError, ValueError): - facet_ncol = None - if facet_ncol <= 0: - facet_ncol = None - - facet_fig_width = _normalize_positive_float(facet_fig_width) - facet_fig_height = _normalize_positive_float(facet_fig_height) + # Normalize facet_ncol and apply automatic logic if needed + facet_ncol = normalize_positive_number( + facet_ncol, + var_name="facet_ncol", + convert_to="int", + default_like_values=("auto", "none", ""), + ) if facet_ncol is None: if n_groups <= vertical_threshold: facet_ncol = 1 else: facet_ncol = int(np.ceil(np.sqrt(n_groups))) - + logging.info( + "Automatic facet_ncol selection: %s columns for %s groups " + "(vertical_threshold=%s).", + facet_ncol, + n_groups, + vertical_threshold, + ) facet_ncol = max(1, min(int(facet_ncol), n_groups)) + + # Use defaults if figure-size hints are not provided facet_height = default_height facet_aspect = default_aspect - if ( - facet_fig_width is not None and - facet_fig_height is not None and - facet_fig_width > 0 and - facet_fig_height > 0 - ): + # Normalize 
figure-size hints + facet_fig_width = normalize_positive_number( + facet_fig_width, + var_name="facet_fig_width", + convert_to="float", + ) + facet_fig_height = normalize_positive_number( + facet_fig_height, + var_name="facet_fig_height", + convert_to="float", + ) + + # Derive facet geometry from figure-size hints when both are provided + if facet_fig_width is not None and facet_fig_height is not None: nrow = int(np.ceil(n_groups / facet_ncol)) panel_width = max(facet_fig_width / facet_ncol, min_panel_width) panel_height = max(facet_fig_height / nrow, min_panel_height) facet_height = panel_height facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) - return ( - facet_ncol, - facet_height, - facet_aspect, - facet_fig_width, - facet_fig_height, - ) + return { + "facet_ncol": facet_ncol, + "facet_height": facet_height, + "facet_aspect": facet_aspect, + "facet_fig_width": facet_fig_width, + "facet_fig_height": facet_fig_height, + } def histogram(adata, feature=None, annotation=None, layer=None, @@ -878,18 +907,17 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): else: # Facet option # Derive facet geometry based on group count and layout hints - ( - facet_ncol, - facet_height, - facet_aspect, - facet_fig_width, - facet_fig_height, - ) = _derive_facet_geometry( + facet_layout = _derive_facet_geometry( n_groups=n_groups, facet_ncol=facet_ncol, facet_fig_width=facet_fig_width, facet_fig_height=facet_fig_height, ) + facet_ncol = facet_layout["facet_ncol"] + facet_height = facet_layout["facet_height"] + facet_aspect = facet_layout["facet_aspect"] + facet_fig_width = facet_layout["facet_fig_width"] + facet_fig_height = facet_layout["facet_fig_height"] # Compute global bins so all facets use consistent boundaries. 
global_bin_edges = compute_global_bin_edges( diff --git a/tests/test_visualization/test_derive_facet_geometry.py b/tests/test_visualization/test_derive_facet_geometry.py new file mode 100644 index 00000000..4f694e05 --- /dev/null +++ b/tests/test_visualization/test_derive_facet_geometry.py @@ -0,0 +1,99 @@ +import unittest + +from spac.visualization import _derive_facet_geometry + + +class TestDeriveFacetGeometry(unittest.TestCase): + def test_minimal_single_group_defaults(self): + """Single-group input should keep one column and default geometry.""" + facet_layout = _derive_facet_geometry( + n_groups=1, + default_height=3.2, + default_aspect=1.25, + ) + + self.assertEqual(facet_layout["facet_ncol"], 1) + self.assertEqual(facet_layout["facet_height"], 3.2) + self.assertEqual(facet_layout["facet_aspect"], 1.25) + self.assertIsNone(facet_layout["facet_fig_width"]) + self.assertIsNone(facet_layout["facet_fig_height"]) + + def test_auto_layout_uses_single_column_below_threshold(self): + """Check that auto layout selects 1 column when n_groups is at or below threshold.""" + facet_layout = _derive_facet_geometry( + n_groups=5, + facet_ncol="auto", + vertical_threshold=5, + ) + + self.assertEqual(facet_layout["facet_ncol"], 1) + + def test_auto_layout_uses_sqrt_rule_above_threshold(self): + """Auto layout should use sqrt rule when n_groups is above threshold.""" + facet_layout = _derive_facet_geometry( + n_groups=5, + facet_ncol="auto", + vertical_threshold=3, + ) + + self.assertEqual(facet_layout["facet_ncol"], 3) + + def test_explicit_column_count_and_figure_size_hints_drive_geometry(self): + """Explicit facet_ncol should be used directly to compute geometry.""" + facet_layout = _derive_facet_geometry( + n_groups=5, + facet_ncol=2, + vertical_threshold=5, + facet_fig_width=11, + facet_fig_height=4, + ) + + self.assertEqual(facet_layout["facet_ncol"], 2) + self.assertAlmostEqual(facet_layout["facet_fig_width"], 11.0) + 
self.assertAlmostEqual(facet_layout["facet_fig_height"], 4.0) + self.assertAlmostEqual(facet_layout["facet_height"], 1.6) + self.assertAlmostEqual(facet_layout["facet_aspect"], 2.0) + + def test_single_figure_size_hint_falls_back_to_defaults(self): + """A one-sided size hint should not partially derive facet geometry.""" + facet_layout = _derive_facet_geometry( + n_groups=4, + facet_fig_width=11, + ) + + self.assertEqual(facet_layout["facet_ncol"], 2) + self.assertEqual(facet_layout["facet_height"], 3.2) + self.assertEqual(facet_layout["facet_aspect"], 1.25) + self.assertEqual(facet_layout["facet_fig_width"], 11.0) + self.assertIsNone(facet_layout["facet_fig_height"]) + + def test_invalid_inputs_fall_back_to_auto_and_sanitize_size_hints(self): + """Invalid facet_ncol and figure size hints should be sanitized to auto behavior.""" + facet_layout = _derive_facet_geometry( + n_groups=3, + facet_ncol="bad", + facet_fig_width="wide", + facet_fig_height=0, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, + ) + + self.assertEqual(facet_layout["facet_ncol"], 1) + self.assertEqual(facet_layout["facet_height"], 3.2) + self.assertEqual(facet_layout["facet_aspect"], 1.25) + self.assertIsNone(facet_layout["facet_fig_width"]) + self.assertIsNone(facet_layout["facet_fig_height"]) + + def test_explicit_column_count_is_clamped_to_group_count(self): + """Explicit facet_ncol should be clamped to n_groups if it exceeds it.""" + facet_layout = _derive_facet_geometry( + n_groups=2, + facet_ncol=10, + ) + + self.assertEqual(facet_layout["facet_ncol"], 2) + + +if __name__ == "__main__": + unittest.main() From 6af08200a241cf2cc05ef65f997a283ce40cc405 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sun, 19 Apr 2026 20:27:22 -0400 Subject: [PATCH 38/57] fix(histogram-template): correct rotated tick label handling in facets - template: rotate actual tick labels (anchor/right) - remove _supxlabel rotation branch --- src/spac/templates/histogram_template.py | 15 ++++----------- 
1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 69f18d49..0bfd2f10 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -325,17 +325,10 @@ def run_from_json( # Rotate x labels ax.tick_params(axis='x', rotation=x_rotate) - - # Process figure-level xlabel for faceted plots - if facet: - facet_xlabel = getattr(fig, '_supxlabel', None) - if facet_xlabel is None: - logger.warning( - "Facet xlabel not found. X label rotation will " - "not be applied." - ) - else: - facet_xlabel.set_rotation(x_rotate) + if x_rotate: + for label in ax.get_xticklabels(): + label.set_rotation_mode('anchor') + label.set_horizontalalignment('right') # Set titles based on group_by and facet if text_to_value(group_by): From 5c93459826fb2bc8478ae9e38d0380d7bc33f696 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sun, 19 Apr 2026 20:56:11 -0400 Subject: [PATCH 39/57] fix(histogram): prune facet layout hints validation - validate facet hint inputs explicitly in histogram() - require facet_fig_width and facet_fig_height to be provided together - simplify _derive_facet_geometry() to consume pre-normalized hints - validate positive figure_width/height in histogram and its template - default automatic facet_ncol to None in histogram and its template - update facet geometry and histogram tests to match strict validation - simplify tests for facet hints to use shared self.adata fixtures --- src/spac/templates/histogram_template.py | 23 ++- src/spac/visualization.py | 133 +++++++++++------- .../test_derive_facet_geometry.py | 26 ++-- tests/test_visualization/test_histogram.py | 76 +++++----- 4 files changed, 155 insertions(+), 103 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 0bfd2f10..5d9d2df6 100644 --- a/src/spac/templates/histogram_template.py +++ 
b/src/spac/templates/histogram_template.py @@ -217,6 +217,23 @@ def run_from_json( f'Received "{stat}".' ) + # validate figure size parameters + fig_width = text_to_value( + fig_width, + to_float=True, + param_name="Figure_Width" + ) + fig_height = text_to_value( + fig_height, + to_float=True, + param_name="Figure_Height" + ) + if fig_width <= 0 or fig_height <= 0: + raise ValueError( + f'Figure_Width/Height should be a positive number.' + f'Received "{fig_width}"/"{fig_height}".' + ) + # Validate x-axis label rotation if (x_rotate < 0) or (x_rotate > 360): raise ValueError( @@ -236,13 +253,13 @@ def run_from_json( 'Together and Facet cannot both be True. Please set one to False.' ) - # Validate facet_ncol, allowing for "auto" or positive integers + # Validate and canonicalize facet_ncol, allowing for "auto" or positive integers facet_ncol = text_to_value( facet_ncol, default_none_text="auto", - value_to_convert_to="auto" + value_to_convert_to=None ) - if facet_ncol != "auto": + if facet_ncol is not None: facet_ncol = text_to_value( facet_ncol, to_int=True, diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 4c7ba6db..97e4e513 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -407,7 +407,7 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): def _derive_facet_geometry( n_groups, - facet_ncol="auto", + facet_ncol=None, facet_fig_width=None, facet_fig_height=None, vertical_threshold=3, @@ -418,21 +418,19 @@ def _derive_facet_geometry( min_aspect=0.6, max_aspect=2.0, ): - """Normalize facet layout hints and derive FacetGrid geometry. + """Derive FacetGrid geometry from pre-normalized facet layout hints. Parameters ---------- n_groups : int Number of facet panels. Expected to be a positive integer supplied by the grouped histogram path. - facet_ncol : int or "auto", optional + facet_ncol : int or None, optional Requested facet column count. Positive integers are used directly. 
- Default-like string inputs (`"auto"`, `"none"`, `""`) and invalid - values fall back to automatic column selection. + ``None`` falls back to automatic column selection. facet_fig_width, facet_fig_height : float, optional - Optional total figure-size hints. Positive numeric values are kept. - Default-like or invalid values are sanitized to ``None``. Geometry is - derived from these hints only when both normalized values are present. + Optional total figure-size hints. Geometry is derived from these hints only + when both values are present. vertical_threshold : int, optional Maximum group count that still prefers a single-column automatic layout. @@ -452,11 +450,9 @@ def _derive_facet_geometry( ------- dict Dictionary with keys: - - ``facet_ncol``: normalized column count clamped to ``n_groups``; - - ``facet_height`` / ``facet_aspect``: FacetGrid-ready per-panel - geometry values; - - ``facet_fig_width`` / ``facet_fig_height``: normalized positive - floats or ``None``. + - ``facet_ncol``: positive int, normalized column count clamped to ``n_groups``; + - ``facet_height``: float, FacetGrid-ready per-panel height in inches; + - ``facet_aspect``: float, FacetGrid-ready per-panel aspect ratio. Automatic layout uses one column when ``n_groups <= vertical_threshold`` and otherwise uses ``ceil(sqrt(n_groups))`` columns. When both @@ -466,13 +462,7 @@ def _derive_facet_geometry( configured range. 
""" - # Normalize facet_ncol and apply automatic logic if needed - facet_ncol = normalize_positive_number( - facet_ncol, - var_name="facet_ncol", - convert_to="int", - default_like_values=("auto", "none", ""), - ) + # Derive facet_ncol when not explicitly provided, and clamp to n_groups if facet_ncol is None: if n_groups <= vertical_threshold: facet_ncol = 1 @@ -491,18 +481,6 @@ def _derive_facet_geometry( facet_height = default_height facet_aspect = default_aspect - # Normalize figure-size hints - facet_fig_width = normalize_positive_number( - facet_fig_width, - var_name="facet_fig_width", - convert_to="float", - ) - facet_fig_height = normalize_positive_number( - facet_fig_height, - var_name="facet_fig_height", - convert_to="float", - ) - # Derive facet geometry from figure-size hints when both are provided if facet_fig_width is not None and facet_fig_height is not None: nrow = int(np.ceil(n_groups / facet_ncol)) @@ -515,8 +493,6 @@ def _derive_facet_geometry( "facet_ncol": facet_ncol, "facet_height": facet_height, "facet_aspect": facet_aspect, - "facet_fig_width": facet_fig_width, - "facet_fig_height": facet_fig_height, } @@ -728,11 +704,72 @@ def cal_bin_num( if together: raise ValueError("Cannot use together=True with facet=True," " choose one.") - - # Remove histogram-internal layout hints so they never leak to seaborn. 
- facet_ncol = kwargs.pop('facet_ncol', None) - facet_fig_width = kwargs.pop('facet_fig_width', None) - facet_fig_height = kwargs.pop('facet_fig_height', None) + + def _parse_optional_number( + name, + value, + default_like_values=None, + to_type="float", + to_range=None, + to_default_value=None + ): + """Parse an optional numeric value with default-like string handling.""" + def _is_default_like(value, default_like_values=None): + if isinstance(value, str) and default_like_values is not None: + return value.strip().lower() in default_like_values + return False + + if value is None or _is_default_like(value, default_like_values): + return to_default_value + try: + if to_type == "float": + parsed = float(value) + elif to_type == "int": + parsed = int(value) + except (TypeError, ValueError): + raise ValueError( + f'{name} must be a positive {to_type}' + f'{" or one of default values. " if default_like_values else ". "}' + f'Received "{value}".' + ) + if not math.isfinite(parsed): + raise ValueError( + f'{name} must be a finite {to_type}. Received "{value}".' + ) + if to_range == "positive": + to_range = [float('1e-10'), float('inf')] + if isinstance(to_range, list): + min_val, max_val = to_range + if parsed < min_val or parsed > max_val: + raise ValueError( + f'{name} must be a {to_type} in the range [{min_val}, {max_val}].' + f' Received "{value}".' + ) + return parsed + + # Parse facet layout hints so they never leak to seaborn. 
+ facet_ncol = _parse_optional_number( + "facet_ncol", + kwargs.pop('facet_ncol', None), + default_like_values={"", "auto", "none"}, + to_type="int", + to_range="positive", + ) + facet_fig_width = _parse_optional_number( + "facet_fig_width", + kwargs.pop('facet_fig_width', None), + to_range="positive", + ) + facet_fig_height = _parse_optional_number( + "facet_fig_height", + kwargs.pop('facet_fig_height', None), + to_range="positive", + ) + if (facet_fig_width is None) != (facet_fig_height is None): + raise ValueError( + "Both facet_fig_width and facet_fig_height must be provided together, " + "or both must be left as None." + ) # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): @@ -907,17 +944,13 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): else: # Facet option # Derive facet geometry based on group count and layout hints + # Keys include: facet_ncol, facet_height, facet_aspect facet_layout = _derive_facet_geometry( n_groups=n_groups, facet_ncol=facet_ncol, facet_fig_width=facet_fig_width, facet_fig_height=facet_fig_height, ) - facet_ncol = facet_layout["facet_ncol"] - facet_height = facet_layout["facet_height"] - facet_aspect = facet_layout["facet_aspect"] - facet_fig_width = facet_layout["facet_fig_width"] - facet_fig_height = facet_layout["facet_fig_height"] # Compute global bins so all facets use consistent boundaries. 
global_bin_edges = compute_global_bin_edges( @@ -928,9 +961,9 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): hist = sns.FacetGrid( plot_data, col=group_by, - col_wrap=facet_ncol, - height=facet_height, - aspect=facet_aspect, + col_wrap=facet_layout['facet_ncol'], + height=facet_layout['facet_height'], + aspect=facet_layout['facet_aspect'], sharex=True, sharey=True, ) @@ -967,8 +1000,10 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): # Pass the figure and axes to the output for further customization fig = hist.figure - fig.set_size_inches(facet_fig_width or fig.get_figwidth(), - facet_fig_height or fig.get_figheight()) + fig.set_size_inches( + facet_fig_width or fig.get_figwidth(), + facet_fig_height or fig.get_figheight(), + ) axs.extend(hist.axes.flat) hist_data = plot_data diff --git a/tests/test_visualization/test_derive_facet_geometry.py b/tests/test_visualization/test_derive_facet_geometry.py index 4f694e05..17cebfad 100644 --- a/tests/test_visualization/test_derive_facet_geometry.py +++ b/tests/test_visualization/test_derive_facet_geometry.py @@ -15,14 +15,12 @@ def test_minimal_single_group_defaults(self): self.assertEqual(facet_layout["facet_ncol"], 1) self.assertEqual(facet_layout["facet_height"], 3.2) self.assertEqual(facet_layout["facet_aspect"], 1.25) - self.assertIsNone(facet_layout["facet_fig_width"]) - self.assertIsNone(facet_layout["facet_fig_height"]) def test_auto_layout_uses_single_column_below_threshold(self): """Check that auto layout selects 1 column when n_groups is at or below threshold.""" facet_layout = _derive_facet_geometry( n_groups=5, - facet_ncol="auto", + facet_ncol=None, vertical_threshold=5, ) @@ -32,7 +30,7 @@ def test_auto_layout_uses_sqrt_rule_above_threshold(self): """Auto layout should use sqrt rule when n_groups is above threshold.""" facet_layout = _derive_facet_geometry( n_groups=5, - facet_ncol="auto", + facet_ncol=None, vertical_threshold=3, ) @@ -49,8 +47,6 @@ 
def test_explicit_column_count_and_figure_size_hints_drive_geometry(self): ) self.assertEqual(facet_layout["facet_ncol"], 2) - self.assertAlmostEqual(facet_layout["facet_fig_width"], 11.0) - self.assertAlmostEqual(facet_layout["facet_fig_height"], 4.0) self.assertAlmostEqual(facet_layout["facet_height"], 1.6) self.assertAlmostEqual(facet_layout["facet_aspect"], 2.0) @@ -59,21 +55,22 @@ def test_single_figure_size_hint_falls_back_to_defaults(self): facet_layout = _derive_facet_geometry( n_groups=4, facet_fig_width=11, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, ) self.assertEqual(facet_layout["facet_ncol"], 2) self.assertEqual(facet_layout["facet_height"], 3.2) self.assertEqual(facet_layout["facet_aspect"], 1.25) - self.assertEqual(facet_layout["facet_fig_width"], 11.0) - self.assertIsNone(facet_layout["facet_fig_height"]) - def test_invalid_inputs_fall_back_to_auto_and_sanitize_size_hints(self): - """Invalid facet_ncol and figure size hints should be sanitized to auto behavior.""" + def test_none_inputs_fall_back_to_auto_and_default_geometry(self): + """Missing pre-normalized hints should use auto layout and defaults.""" facet_layout = _derive_facet_geometry( n_groups=3, - facet_ncol="bad", - facet_fig_width="wide", - facet_fig_height=0, + facet_ncol=None, + facet_fig_width=None, + facet_fig_height=None, vertical_threshold=3, default_height=3.2, default_aspect=1.25, @@ -82,14 +79,13 @@ def test_invalid_inputs_fall_back_to_auto_and_sanitize_size_hints(self): self.assertEqual(facet_layout["facet_ncol"], 1) self.assertEqual(facet_layout["facet_height"], 3.2) self.assertEqual(facet_layout["facet_aspect"], 1.25) - self.assertIsNone(facet_layout["facet_fig_width"]) - self.assertIsNone(facet_layout["facet_fig_height"]) def test_explicit_column_count_is_clamped_to_group_count(self): """Explicit facet_ncol should be clamped to n_groups if it exceeds it.""" facet_layout = _derive_facet_geometry( n_groups=2, facet_ncol=10, + vertical_threshold=3, ) 
self.assertEqual(facet_layout["facet_ncol"], 2) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 3ec36f69..e64ed55b 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -599,19 +599,9 @@ def test_facet_plot_categorical_annotation(self): def test_facet_ncol_layout_hints(self): """Facet ncol supports positive int and documented auto behavior.""" - X = np.arange(1, 10, dtype=np.float32).reshape(-1, 1) - obs = pd.DataFrame( - { - 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g3', 'g3', 'g3'], - }, - index=[f'cell_{i}' for i in range(9)], - ) - var = pd.DataFrame(index=['marker1']) - adata = anndata.AnnData(X, obs=obs, var=var) - # Explicit two-column layout should create two facet columns. fig, axs, _ = histogram( - adata, + self.adata, feature='marker1', group_by='annotation2', facet=True, @@ -623,7 +613,7 @@ def test_facet_ncol_layout_hints(self): # Documented default-like input should use auto layout (one column for 3 groups). fig, axs, _ = histogram( - adata, + self.adata, feature='marker1', group_by='annotation2', facet=True, @@ -633,33 +623,29 @@ def test_facet_ncol_layout_hints(self): x_positions = {round(axis.get_position().x0, 4) for axis in axs} self.assertEqual(len(x_positions), 1) - # Keep one lightweight guardrail check for invalid fallback behavior. - fig, axs, _ = histogram( - adata, + # Invalid values should fail fast. 
+ with self.assertRaises(ValueError): + histogram( + self.adata, feature='marker1', group_by='annotation2', facet=True, facet_ncol='bad', - ).values() - axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] - x_positions = {round(axis.get_position().x0, 4) for axis in axs} - self.assertEqual(len(x_positions), 1) + ) + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol=0, + ) def test_facet_figure_size_hints(self): """Facet figure-size hints should accept valid values and sanitize invalid ones.""" - X = np.arange(1, 10, dtype=np.float32).reshape(-1, 1) - obs = pd.DataFrame( - { - 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g3', 'g3', 'g3'], - }, - index=[f'cell_{i}' for i in range(9)], - ) - var = pd.DataFrame(index=['marker1']) - adata = anndata.AnnData(X, obs=obs, var=var) - # Check that valid figure size hints are applied to the facet figure. fig, _, _ = histogram( - adata, + self.adata, feature='marker1', group_by='annotation2', facet=True, @@ -669,19 +655,37 @@ def test_facet_figure_size_hints(self): self.assertAlmostEqual(fig.get_figwidth(), 11.0, places=2) self.assertAlmostEqual(fig.get_figheight(), 3.5, places=2) - # Check that invalid size hints are sanitized + # Invalid hints should fail fast. 
for width, height in [('wide', 'tall'), (-1, 0)]: with self.subTest(facet_fig_width=width, facet_fig_height=height): - fig, _, _ = histogram( - adata, + with self.assertRaises(ValueError): + histogram( + self.adata, feature='marker1', group_by='annotation2', facet=True, facet_fig_width=width, facet_fig_height=height, - ).values() - self.assertGreater(fig.get_figwidth(), 0) - self.assertGreater(fig.get_figheight(), 0) + ) + + def test_facet_figure_size_hints_require_pair(self): + """One-sided facet figure-size hints should raise a ValueError.""" + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=11, + ) + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_height=3.5, + ) def test_facet_plot_shared_bins_consistency_numeric(self): """Numeric facets keep shared bins for int/default-like bins inputs.""" From 5143007dce85164f55e9178d8a440de5f662edeb Mon Sep 17 00:00:00 2001 From: Boqiang Date: Sun, 19 Apr 2026 22:23:15 -0400 Subject: [PATCH 40/57] refactor(utils): remove normalize_positive_number helper - delete normalize_positive_number from spac.utils - remove unused normalize_positive_number import in visualization - drop obsolete unit tests for removed helper --- src/spac/utils.py | 112 +++--------------- src/spac/visualization.py | 1 - .../test_normalize_positive_number.py | 59 --------- 3 files changed, 15 insertions(+), 157 deletions(-) delete mode 100644 tests/test_utils/test_normalize_positive_number.py diff --git a/src/spac/utils.py b/src/spac/utils.py index 87b839dc..08d89ee6 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -5,15 +5,15 @@ import matplotlib.cm as cm import pandas as pd import logging -import warnings -import numbers -from scipy.stats import median_abs_deviation -from typing import Any, List, Optional - -# Configure logging 
-logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) +import warnings +import numbers +from scipy.stats import median_abs_deviation +from typing import Any, List, Optional + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) def regex_search_list( @@ -418,9 +418,9 @@ def check_distances(distances): f"Got {distances}") -def text_to_others( - parameter, - text="None", +def text_to_others( + parameter, + text="None", to_None=True, to_False=False, to_True=False, @@ -468,90 +468,8 @@ def check_same( if to_Float: parameter = float(parameter) - - return parameter - - -def normalize_positive_number( - parameter: Any, - var_name: str = "parameter", - convert_to: str = "float", - default_like_values=("auto", "none", ""), -): - """Normalize a value to a positive float/int or None. - - Parameters - ---------- - parameter : any - Value to normalize. Strings are matched against default-like tokens - before numeric conversion. - var_name : str, optional - Name used in log messages. - convert_to : str, optional - Target numeric type. Supported values are ``"float"`` and ``"int"``. - default_like_values : tuple of str, optional - Lowercased text tokens that should be treated as missing values. - - Returns - ------- - float or int or None - Positive converted value, or ``None`` when the input is default-like, - invalid, or non-positive. - """ - if parameter is None or isinstance(parameter, bool): - logger.info( - "%s=%r is treated as missing input. Falling back to automatic " - "behavior.", - var_name, - parameter, - ) - return None - - if isinstance(parameter, str): - parameter_str = parameter.strip().lower() - if parameter_str in default_like_values: - logger.info( - "%s=%r is treated as default-like input. 
Falling back to " - "automatic behavior.", - var_name, - parameter, - ) - return None - - try: - if convert_to == "float": - parameter = float(parameter) - elif convert_to == "int": - parameter = int(parameter) - else: - logger.warning( - "%s uses unsupported conversion '%s'. Falling back to " - "automatic behavior.", - var_name, - convert_to, - ) - return None - except (TypeError, ValueError): - logger.warning( - "Could not convert %s=%r to %s. Falling back to automatic " - "behavior.", - var_name, - parameter, - convert_to, - ) - return None - - if parameter <= 0: - logger.warning( - "%s=%r is not a positive %s. Falling back to automatic " - "behavior.", - var_name, - parameter, - convert_to, - ) - return None - - return parameter + + return parameter def annotation_category_relations( @@ -1356,4 +1274,4 @@ def compute_summary_qc_stats( "upper_mad", "lower_mad", "upper_quantile", "lower_quantile" ] - ) + ) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 97e4e513..669242b5 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -17,7 +17,6 @@ from spac.utils import check_label from spac.utils import get_defined_color_map from spac.utils import compute_boxplot_metrics -from spac.utils import normalize_positive_number from functools import partial from spac.utils import color_mapping, spell_out_special_characters from spac.data_utils import select_values diff --git a/tests/test_utils/test_normalize_positive_number.py b/tests/test_utils/test_normalize_positive_number.py deleted file mode 100644 index 2b3ac0df..00000000 --- a/tests/test_utils/test_normalize_positive_number.py +++ /dev/null @@ -1,59 +0,0 @@ -import unittest - -from spac.utils import normalize_positive_number - - -class TestNormalizePositiveNumber(unittest.TestCase): - def test_float_conversion(self): - self.assertEqual( - normalize_positive_number("11.5", convert_to="float"), - 11.5, - ) - - def test_int_conversion(self): - self.assertEqual( - 
normalize_positive_number("3", convert_to="int"), - 3, - ) - - def test_default_like_values_return_none(self): - self.assertIsNone(normalize_positive_number("auto", convert_to="int")) - self.assertIsNone(normalize_positive_number("None", convert_to="float")) - self.assertIsNone(normalize_positive_number(None, convert_to="float")) - - def test_invalid_or_non_positive_values_return_none(self): - self.assertIsNone(normalize_positive_number("bad", convert_to="int")) - self.assertIsNone(normalize_positive_number("-1", convert_to="float")) - self.assertIsNone(normalize_positive_number(0, convert_to="float")) - - def test_sanitized_inputs_are_logged(self): - with self.assertLogs("spac.utils", level="INFO") as logs: - self.assertIsNone( - normalize_positive_number( - "auto", - var_name="facet_ncol", - convert_to="int", - ) - ) - - self.assertTrue( - any("facet_ncol='auto'" in message for message in logs.output) - ) - - def test_invalid_inputs_are_logged_as_warning(self): - with self.assertLogs("spac.utils", level="WARNING") as logs: - self.assertIsNone( - normalize_positive_number( - "bad", - var_name="facet_fig_width", - convert_to="float", - ) - ) - - self.assertTrue( - any("facet_fig_width='bad'" in message for message in logs.output) - ) - - -if __name__ == "__main__": - unittest.main() From 367a9221a930ea754bdccfb9f64be2ae650fbf07 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Mon, 20 Apr 2026 00:25:05 -0400 Subject: [PATCH 41/57] fix(histogram_template): allow auto figure size in facet mode --- src/spac/templates/histogram_template.py | 25 ++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 5d9d2df6..effd2808 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -98,8 +98,8 @@ def run_from_json( layer = params.get("Table_", "Original") group_by = params.get("Group_by", "None") together = 
params.get("Together", True) - fig_width = params.get("Figure_Width", 8) - fig_height = params.get("Figure_Height", 6) + fig_width = params.get("Figure_Width", "auto") + fig_height = params.get("Figure_Height", "auto") font_size = params.get("Font_Size", 12) fig_dpi = params.get("Figure_DPI", 300) legend_location = params.get("Legend_Location", "best") @@ -218,21 +218,29 @@ def run_from_json( ) # validate figure size parameters + # If "auto" is specified, in facet mode it will be passed as None, + # allowing it to be computed based on facet layout hints automatically. + # In non-facet mode, it will default to 8x6 inches. fig_width = text_to_value( fig_width, + default_none_text="auto", + value_to_convert_to=None if facet else 8, to_float=True, param_name="Figure_Width" ) fig_height = text_to_value( fig_height, + default_none_text="auto", + value_to_convert_to=None if facet else 6, to_float=True, param_name="Figure_Height" ) - if fig_width <= 0 or fig_height <= 0: - raise ValueError( - f'Figure_Width/Height should be a positive number.' - f'Received "{fig_width}"/"{fig_height}".' - ) + if fig_width and fig_height: + if fig_width <= 0 or fig_height <= 0: + raise ValueError( + f'Figure_Width/Height should be a positive number.' + f'Received "{fig_width}"/"{fig_height}".' 
+ ) # Validate x-axis label rotation if (x_rotate < 0) or (x_rotate > 360): @@ -307,7 +315,8 @@ def run_from_json( df_counts = result["df"] # Set figure size and dpi - fig.set_size_inches(fig_width, fig_height) + if fig_width and fig_height: + fig.set_size_inches(fig_width, fig_height) fig.set_dpi(fig_dpi) # Ensure axes is a list From c5bb404a6f68bf62882f34f462aa97b7d61c9c04 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Mon, 20 Apr 2026 00:28:19 -0400 Subject: [PATCH 42/57] feat(histogram): add long-label facet geometry adjustment - pass facet_tick_rotation from template to histogram - extend facet geometry helper with tick-length/rotation-based pressure heuristic - increase facet height and rebalance aspect when long rotated labels are present - keep explicit facet_fig_width/facet_fig_height authoritative - add focused histogram tests for rotation-zero parity and long-label geometry behavior --- src/spac/templates/histogram_template.py | 1 + src/spac/visualization.py | 81 +++++++++- tests/test_visualization/test_histogram.py | 168 +++++++++++++++++++-- 3 files changed, 240 insertions(+), 10 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index effd2808..0a117d67 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -308,6 +308,7 @@ def run_from_json( facet_ncol=facet_ncol, facet_fig_width=fig_width, facet_fig_height=fig_height, + facet_tick_rotation=x_rotate, ) fig = result["fig"] diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 669242b5..7d7ade6e 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -409,6 +409,8 @@ def _derive_facet_geometry( facet_ncol=None, facet_fig_width=None, facet_fig_height=None, + facet_tick_max_chars=0, + facet_tick_rotation=0.0, vertical_threshold=3, default_height=3.2, default_aspect=1.25, @@ -430,6 +432,14 @@ def _derive_facet_geometry( facet_fig_width, facet_fig_height : float, 
optional Optional total figure-size hints. Geometry is derived from these hints only when both values are present. + facet_tick_max_chars : int, optional + Maximum observed x tick-label length. Expected positive integer. + Used for adjusting default geometry heuristics when explicit figure-size + hints are absent. + ``0`` falls back to the original default geometry without long-label adjustments. + facet_tick_rotation : float, optional + Rotation angle in degrees for x tick labels. Used together with + ``facet_tick_max_chars`` to estimate label burden for default geometry. vertical_threshold : int, optional Maximum group count that still prefers a single-column automatic layout. @@ -458,7 +468,9 @@ def _derive_facet_geometry( normalized figure-size hints are present, the helper converts total figure size into per-panel geometry using the derived grid shape, applies the minimum panel-size guardrails, and clips aspect into the - configured range. + configured range. When figure-size hints are absent, the helper may + increase default facet height/aspect for long rotated categorical + labels to preserve usable bar area. """ # Derive facet_ncol when not explicitly provided, and clamp to n_groups @@ -488,6 +500,35 @@ def _derive_facet_geometry( facet_height = panel_height facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) + elif facet_tick_max_chars and facet_tick_max_chars > 0: + # For default geometry only, expand panel ratio when long rotated + # labels would otherwise dominate the available plotting area. 
+ rotation = float(facet_tick_rotation or 0.0) % 360.0 + rad = np.deg2rad(min(rotation, 180.0)) + rotation_factor = 1.0 + 0.8 * np.sin(rad) + burden = float(facet_tick_max_chars) * rotation_factor + long_label_threshold = 12.0 + + if burden > long_label_threshold: + pressure = min((burden - long_label_threshold) / long_label_threshold, 2.0) + facet_height = default_height * (1.0 + 0.35 * pressure) + facet_aspect = float( + np.clip( + # default_aspect * (1.0 + 0.75 * pressure), + default_aspect * (1.0 - 0.05 * pressure), + min_aspect, + max_aspect, + ) + ) + logging.info( + "Automatic facet geometry adjustment for long x tick labels: " + "max_chars=%s, rotation=%s, facet_height=%.2f, facet_aspect=%.2f.", + facet_tick_max_chars, + facet_tick_rotation, + facet_height, + facet_aspect, + ) + return { "facet_ncol": facet_ncol, "facet_height": facet_height, @@ -589,6 +630,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, switches to a compact grid for many groups. - `facet_fig_width`: float, intended final figure width in inches. - `facet_fig_height`: float, intended final figure height in inches. + - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. Returns ------- @@ -769,6 +811,11 @@ def _is_default_like(value, default_like_values=None): "Both facet_fig_width and facet_fig_height must be provided together, " "or both must be left as None." ) + facet_tick_rotation = _parse_optional_number( + "facet_tick_rotation", + kwargs.pop('facet_tick_rotation', None), + to_default_value=0.0, + ) % 360.0 # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): @@ -839,6 +886,31 @@ def compute_global_bin_edges(data_series, bins): return np.histogram_bin_edges(data_series, bins=bins) return data_series.unique() + # Function to compute maximum tick label length for categorical data + def compute_max_tick_label_length(data_series): + """Compute maximum tick label length for a categorical data series. 
+ + Parameters + ---------- + data_series : pandas.Series + Categorical data column used to compute maximum tick label length. + + Returns + ------- + int + Maximum number of characters in the tick labels derived from the + unique categories of the input series. + """ + if isinstance(data_series.dtype, pd.CategoricalDtype): + tick_labels = [ + str(label) for label in data_series.cat.categories + ] + else: + tick_labels = [ + str(label) for label in data_series.dropna().unique().tolist() + ] + return max((len(label) for label in tick_labels), default=0) + # Function to get axis labels based on log scale and stat parameters def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): """Resolve x/y axis labels for histogram rendering. @@ -942,6 +1014,11 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): axs.append(ax_i) else: # Facet option + # Compute max label length only when not explicitly provided. + facet_tick_max_chars = 0 + if not pd.api.types.is_numeric_dtype(plot_data[data_column]): + facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) + # Derive facet geometry based on group count and layout hints # Keys include: facet_ncol, facet_height, facet_aspect facet_layout = _derive_facet_geometry( @@ -949,6 +1026,8 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): facet_ncol=facet_ncol, facet_fig_width=facet_fig_width, facet_fig_height=facet_fig_height, + facet_tick_max_chars=facet_tick_max_chars, + facet_tick_rotation=facet_tick_rotation, ) # Compute global bins so all facets use consistent boundaries. 
diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index e64ed55b..b2f2cba3 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -627,10 +627,10 @@ def test_facet_ncol_layout_hints(self): with self.assertRaises(ValueError): histogram( self.adata, - feature='marker1', - group_by='annotation2', - facet=True, - facet_ncol='bad', + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol='bad', ) with self.assertRaises(ValueError): histogram( @@ -661,11 +661,11 @@ def test_facet_figure_size_hints(self): with self.assertRaises(ValueError): histogram( self.adata, - feature='marker1', - group_by='annotation2', - facet=True, - facet_fig_width=width, - facet_fig_height=height, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=width, + facet_fig_height=height, ) def test_facet_figure_size_hints_require_pair(self): @@ -687,6 +687,117 @@ def test_facet_figure_size_hints_require_pair(self): facet_fig_height=3.5, ) + def test_facet_tick_rotation_zero_matches_default_behavior(self): + """Explicit zero rotation should match omitted rotation behavior.""" + fig_default, _, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + ).values() + fig_zero, _, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + facet_tick_rotation=0, + ).values() + + self.assertAlmostEqual(fig_default.get_figwidth(), fig_zero.get_figwidth(), places=6) + self.assertAlmostEqual(fig_default.get_figheight(), fig_zero.get_figheight(), places=6) + + def test_facet_long_label_geometry_adjustment_without_size_hints(self): + """Long rotated categorical labels should increase default facet geometry.""" + X = np.arange(1, 13, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + { + 'annotation_short': pd.Categorical( + ['A', 'B', 'C', 'D'] * 3, + categories=['A', 'B', 'C', 'D'], 
+ ), + 'annotation_long': pd.Categorical( + [ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ] * 3, + categories=[ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ], + ), + 'annotation2': ['g1', 'g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g2', + 'g3', 'g3', 'g3', 'g3'], + }, + index=[f'cell_{i}' for i in range(12)], + ) + var = pd.DataFrame(index=['marker1']) + adata = anndata.AnnData(X, obs=obs, var=var) + + fig_short, _, _ = histogram( + adata, + annotation='annotation_short', + group_by='annotation2', + facet=True, + facet_ncol=2, + facet_tick_rotation=45, + ).values() + fig_long, _, _ = histogram( + adata, + annotation='annotation_long', + group_by='annotation2', + facet=True, + facet_ncol=2, + facet_tick_rotation=45, + ).values() + + self.assertGreater(fig_long.get_figwidth(), fig_short.get_figwidth()) + self.assertGreater(fig_long.get_figheight(), fig_short.get_figheight()) + + def test_facet_long_label_geometry_respects_explicit_size_hints(self): + """Explicit facet figure-size hints should remain authoritative.""" + X = np.arange(1, 13, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + { + 'annotation_long': pd.Categorical( + [ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ] * 3, + categories=[ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ], + ), + 'annotation2': ['g1', 'g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g2', + 'g3', 'g3', 'g3', 'g3'], + }, + index=[f'cell_{i}' for i in range(12)], + ) + var = pd.DataFrame(index=['marker1']) + adata = anndata.AnnData(X, obs=obs, var=var) + + fig, _, _ = histogram( + adata, + annotation='annotation_long', + group_by='annotation2', + facet=True, + facet_ncol=2, + facet_tick_rotation=60, + facet_fig_width=10, + facet_fig_height=4, + ).values() + + self.assertAlmostEqual(fig.get_figwidth(), 10.0, places=2) + 
self.assertAlmostEqual(fig.get_figheight(), 4.0, places=2) + def test_facet_plot_shared_bins_consistency_numeric(self): """Numeric facets keep shared bins for int/default-like bins inputs.""" # Unbalanced groups: each group occupies only part of the global range. @@ -809,5 +920,44 @@ def test_facet_plot_shared_bins_consistency_categorical(self): "Facet categorical y-ticks should remain shared across panels." ) + def test_facet_plot_categorical_annotation_ignores_bins(self): + """Facet categorical annotations should ignore custom bins values.""" + fig_small, axs_small, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + bins=2, + ).values() + fig_large, axs_large, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + bins=99, + ).values() + + axs_small = axs_small if isinstance(axs_small, (list, np.ndarray)) else [axs_small] + axs_large = axs_large if isinstance(axs_large, (list, np.ndarray)) else [axs_large] + + self.assertEqual(len(axs_small), len(axs_large)) + + axis_small = axs_small[0] + axis_large = axs_large[0] + small_centers = [ + patch.get_x() + patch.get_width() / 2 + for patch in axis_small.patches + ] + large_centers = [ + patch.get_x() + patch.get_width() / 2 + for patch in axis_large.patches + ] + self.assertEqual(small_centers, large_centers) + self.assertEqual( + [tick.get_text() for tick in axis_small.get_xticklabels()], + [tick.get_text() for tick in axis_large.get_xticklabels()], + ) + + if __name__ == '__main__': unittest.main() From 886fba29259929eb5ca778e3343d4ab586d81543 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Mon, 20 Apr 2026 21:18:46 -0400 Subject: [PATCH 43/57] fix(histogram): add max_groups guardrail This fix rejects grouping by annotations with too many categories - Add max_groups guadrail for group_by plotting - Add unittests covering max_groups guardrail --- src/spac/templates/histogram_template.py | 18 +++++- src/spac/visualization.py | 39 
++++++++++- tests/test_visualization/test_histogram.py | 75 ++++++++++++++++++++++ 3 files changed, 127 insertions(+), 5 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 0a117d67..2f30f6d0 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -96,7 +96,8 @@ def run_from_json( feature = text_to_value(params.get("Feature", "None")) annotation = text_to_value(params.get("Annotation", "None")) layer = params.get("Table_", "Original") - group_by = params.get("Group_by", "None") + group_by = text_to_value(params.get("Group_by", "None")) + max_groups = params.get("Max_Groups", 20) together = params.get("Together", True) fig_width = params.get("Figure_Width", "auto") fig_height = params.get("Figure_Height", "auto") @@ -249,6 +250,18 @@ def run_from_json( f'Received "{x_rotate}".' ) + # Validate max_groups parameter if group_by is specified + # max_groups can be a positive integer or "unlimited" (case-insensitive). + # If missing or None, it defaults to 20. 
+ if group_by: + if max_groups != "unlimited": + max_groups = text_to_value( + max_groups, + value_to_convert_to=20, + to_int=True, + param_name="Max_Groups", + ) + # Validate facet, group_by, and together parameters for logical consistency if facet: if group_by is None: @@ -293,7 +306,7 @@ def run_from_json( feature=feature, annotation=annotation, layer=text_to_value(layer, "Original"), - group_by=text_to_value(group_by), + group_by=group_by, together=together, ax=None, x_log_scale=take_X_log, @@ -305,6 +318,7 @@ def run_from_json( bins=bins, alpha=alpha, stat=stat, + max_groups=max_groups, facet_ncol=facet_ncol, facet_fig_width=fig_width, facet_fig_height=fig_height, diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 7d7ade6e..26eb6482 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -26,7 +26,7 @@ import time import json import re -from typing import Dict, List, Union +from typing import Dict, List, Union, Optional import matplotlib.colors as mcolors import matplotlib.patches as mpatch from functools import partial @@ -623,6 +623,13 @@ def histogram(adata, feature=None, annotation=None, layer=None, If not provided, or if passed as `None`/`"auto"`/`"none"`, the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. + + When `group_by` is provided, this optional key can be passed via `kwargs`: + - `max_groups`: positive int or "unlimited", maximum number of groups to plot. + Default is 20 when omitted. + Caution: `"unlimited"` disables this guardrail, but may lead to + performance issues or unreadable plots with many groups. + When `facet=True`, these optional key can be passed via `kwargs` to customize FacetGrid layout: - `facet_ncol`: positive int or "auto", number of facet columns. 
@@ -752,7 +759,8 @@ def _parse_optional_number( default_like_values=None, to_type="float", to_range=None, - to_default_value=None + to_default_value=None, + parse_rules : Optional[Dict[str, Union[int, float]]] = None, ): """Parse an optional numeric value with default-like string handling.""" def _is_default_like(value, default_like_values=None): @@ -762,6 +770,14 @@ def _is_default_like(value, default_like_values=None): if value is None or _is_default_like(value, default_like_values): return to_default_value + if parse_rules and isinstance(value, str): + parse_rules = {k.lower(): v for k, v in parse_rules.items()} + value_lower = value.strip().lower() + if value_lower in parse_rules: + logging.info( + f'Parsed {name}="{value}" as {parse_rules[value_lower]} ' + ) + return parse_rules[value_lower] try: if to_type == "float": parsed = float(value) @@ -769,7 +785,7 @@ def _is_default_like(value, default_like_values=None): parsed = int(value) except (TypeError, ValueError): raise ValueError( - f'{name} must be a positive {to_type}' + f'{name} must be a number of type {to_type}' f'{" or one of default values. " if default_like_values else ". "}' f'Received "{value}".' ) @@ -788,6 +804,15 @@ def _is_default_like(value, default_like_values=None): ) return parsed + # Parse max_groups with "unlimited" handling and validation. + max_groups = _parse_optional_number( + "max_groups", + kwargs.pop('max_groups', 20), + to_type="int", + to_default_value=20, + parse_rules={"unlimited": float('inf')}, + ) + # Parse facet layout hints so they never leak to seaborn. 
facet_ncol = _parse_optional_number( "facet_ncol", @@ -953,6 +978,14 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): if n_groups == 0: raise ValueError("There must be at least one group to create a" " histogram.") + elif n_groups > max_groups: + raise ValueError( + "The number of groups in `group_by` exceeds `max_groups`: " + f"found {n_groups}, threshold {max_groups}.\n" + "Please reduce/bucket groups or use another grouping column, or " + "pass a larger `max_groups`.\n" + "See `kwargs` documentation for more details." + ) if together: if ax is None: diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index b2f2cba3..22db36a7 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -42,6 +42,16 @@ def tearDown(self): # Closes all figures to prevent memory issues plt.close('all') + def _make_many_groups_adata(self, n_groups=25): + """Create a compact AnnData fixture with one row per unique group.""" + X = np.arange(1, n_groups + 1, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + {'many_groups': [f'g{i}' for i in range(n_groups)]}, + index=[f'cell_{i}' for i in range(n_groups)], + ) + var = pd.DataFrame(index=['marker1']) + return anndata.AnnData(X, obs=obs, var=var) + def test_both_feature_and_annotation(self): err_msg = ("Cannot pass both feature and annotation," " choose one.") @@ -302,6 +312,71 @@ def test_group_by_separate_with_y_log_scale(self): self.assertEqual(ax.get_yscale(), 'log') self.assertEqual(ax.get_ylabel(), 'log(Count)') + def test_group_by_max_groups_default_guardrail_rejects_excess_groups(self): + """Default threshold should reject excessive grouped facet plotting.""" + adata = self._make_many_groups_adata(n_groups=25) + with self.assertRaisesRegex(ValueError, "exceeds `max_groups`"): + histogram( + adata, + feature='marker1', + group_by='many_groups', + facet=True, + ) + + def 
test_group_by_max_groups_override_allows_grouped_plot(self): + """Custom positive max_groups should allow larger grouped plots.""" + n_groups = 25 + adata = self._make_many_groups_adata(n_groups=n_groups) + + fig, axs, _ = histogram( + adata, + feature='marker1', + group_by='many_groups', + facet=True, + max_groups=30, + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + self.assertEqual(len(axs), n_groups) + + def test_group_by_max_groups_unlimited_disables_guardrail(self): + """max_groups='unlimited' should disable grouped guardrail validation.""" + adata = self._make_many_groups_adata(n_groups=25) + + fig, ax, _ = histogram( + adata, + feature='marker1', + group_by='many_groups', + together=True, + max_groups='unlimited', + ).values() + self.assertIsNotNone(fig) + self.assertIsInstance(ax, mpl.axes.Axes) + + def test_group_by_max_groups_none_uses_default_threshold(self): + """Explicit None should resolve to default threshold behavior.""" + adata = self._make_many_groups_adata(n_groups=25) + with self.assertRaisesRegex(ValueError, "exceeds `max_groups`"): + histogram( + adata, + feature='marker1', + group_by='many_groups', + together=True, + max_groups=None, + ) + + def test_group_by_invalid_max_groups_raises_value_error(self): + """Invalid max_groups values should fail fast.""" + for value in [0, "bad", True]: + with self.subTest(max_groups=value): + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=True, + max_groups=value, + ) + def test_overlay_options(self): fig, ax, df = histogram( self.adata, From 406240210d2f20a4e38b959202b023d26d410b16 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Mon, 20 Apr 2026 23:26:14 -0400 Subject: [PATCH 44/57] fix(histogram-template): fix figure label overlapping by scaling facet layout - replace unconditional facet tight_layout with row-scaled spacing --- src/spac/templates/histogram_template.py | 17 ++++++++++++++++- 1 file changed, 16 
insertions(+), 1 deletion(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 2f30f6d0..f44270ab 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -403,7 +403,22 @@ def run_from_json( for ax in axes: ax.set_title(f'Count plot of "{x_var}"') - plt.tight_layout() + # Adjust layout to prevent title overlap + if facet: + rows = len({round(ax.get_position().y0, 3) for ax in axes}) + fig.tight_layout( + rect=[ + max(0.02, 0.038 - 0.004 * rows), + max(0.022, 0.036 - 0.003 * rows), + min(0.992, 0.98 + 0.0025 * rows), + max(0.974, 0.98 - 0.001 * rows), + ], + pad=max(0.35, 0.6 - 0.05 * rows), + h_pad=max(0.2, 0.43 - 0.04 * rows), + w_pad=max(0.2, 0.43 - 0.04 * rows), + ) + else: + fig.tight_layout() logger.info("Displaying top 10 rows of histogram dataframe:") print(df_counts.head(10)) From ef31f5035d719e389b5fd44f2eba42f3d188c077 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Mon, 20 Apr 2026 23:53:10 -0400 Subject: [PATCH 45/57] refactor(histogram): simplify optional hint parsing - narrow `_parse_optional_number` to shared numeric parsing mechanics --- src/spac/visualization.py | 91 ++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 49 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 26eb6482..f567c03c 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -756,80 +756,73 @@ def cal_bin_num( def _parse_optional_number( name, value, - default_like_values=None, - to_type="float", - to_range=None, - to_default_value=None, - parse_rules : Optional[Dict[str, Union[int, float]]] = None, - ): - """Parse an optional numeric value with default-like string handling.""" - def _is_default_like(value, default_like_values=None): - if isinstance(value, str) and default_like_values is not None: - return value.strip().lower() in default_like_values - return False - - if value is None or 
_is_default_like(value, default_like_values): - return to_default_value - if parse_rules and isinstance(value, str): - parse_rules = {k.lower(): v for k, v in parse_rules.items()} - value_lower = value.strip().lower() - if value_lower in parse_rules: - logging.info( - f'Parsed {name}="{value}" as {parse_rules[value_lower]} ' - ) - return parse_rules[value_lower] + *, + kind=float, + default=None, + positive=False, + tokens=None, + ): + """Parse an optional numeric hint. + + Returns ``default`` for ``None``, resolves recognized string tokens + before numeric coercion, and optionally enforces finite and positive + values on the parsed result. + """ + if value is None: + return default + if isinstance(value, str): + value = value.strip() + if tokens and value.lower() in tokens: + return tokens[value.lower()] + expected = ( + f'{"positive " if positive else ""}{kind.__name__}' + f'{" or a supported keyword" if tokens else ""}' + ) + if isinstance(value, bool): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') try: - if to_type == "float": - parsed = float(value) - elif to_type == "int": - parsed = int(value) + parsed = kind(value) except (TypeError, ValueError): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + if not math.isfinite(parsed): raise ValueError( - f'{name} must be a number of type {to_type}' - f'{" or one of default values. " if default_like_values else ". "}' + f'{name} must be a finite {kind.__name__}. ' f'Received "{value}".' ) - if not math.isfinite(parsed): + if positive and parsed <= 0: raise ValueError( - f'{name} must be a finite {to_type}. Received "{value}".' + f'{name} must be a positive {kind.__name__}. ' + f'Received "{value}".' ) - if to_range == "positive": - to_range = [float('1e-10'), float('inf')] - if isinstance(to_range, list): - min_val, max_val = to_range - if parsed < min_val or parsed > max_val: - raise ValueError( - f'{name} must be a {to_type} in the range [{min_val}, {max_val}].' 
- f' Received "{value}".' - ) return parsed # Parse max_groups with "unlimited" handling and validation. max_groups = _parse_optional_number( "max_groups", - kwargs.pop('max_groups', 20), - to_type="int", - to_default_value=20, - parse_rules={"unlimited": float('inf')}, + kwargs.pop('max_groups', None), + kind=int, + default=20, + positive=True, + tokens={"unlimited": float('inf')}, ) # Parse facet layout hints so they never leak to seaborn. facet_ncol = _parse_optional_number( "facet_ncol", kwargs.pop('facet_ncol', None), - default_like_values={"", "auto", "none"}, - to_type="int", - to_range="positive", + kind=int, + positive=True, + tokens={"": None, "auto": None, "none": None}, ) facet_fig_width = _parse_optional_number( "facet_fig_width", kwargs.pop('facet_fig_width', None), - to_range="positive", + positive=True, ) facet_fig_height = _parse_optional_number( "facet_fig_height", kwargs.pop('facet_fig_height', None), - to_range="positive", + positive=True, ) if (facet_fig_width is None) != (facet_fig_height is None): raise ValueError( @@ -839,7 +832,7 @@ def _is_default_like(value, default_like_values=None): facet_tick_rotation = _parse_optional_number( "facet_tick_rotation", kwargs.pop('facet_tick_rotation', None), - to_default_value=0.0, + default=0.0, ) % 360.0 # Function to calculate histogram data From 44ce09f373581f491e8ba8cf5e052c0861a0767a Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 21 Apr 2026 02:21:47 -0400 Subject: [PATCH 46/57] fix(histogram): ignore multiple outside overlays - drop template-side allow-list validation for seaborn passthroughs - pass `multiple` only for grouped same-axis overlays - ignore irrelevant `multiple` in grouped non-overlay histogram paths - add regression coverage for grouped separate mode --- src/spac/templates/histogram_template.py | 76 ++++++++-------------- src/spac/visualization.py | 27 ++++---- tests/test_visualization/test_histogram.py | 16 +++++ 3 files changed, 58 insertions(+), 61 deletions(-) diff --git 
a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index f44270ab..0cf2bd82 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -155,7 +155,9 @@ def run_from_json( 'No features available in adata.var_names to plot.' ) - # Validate and set bins + # Bins use a strict template contract in feature mode: + # "auto" or a positive integer. Loose null-like values are intentionally + # not treated as aliases here. if feature is not None: bins = text_to_value( bins, @@ -184,44 +186,13 @@ def run_from_json( "Setting bin number calculation to auto." ) - # Validate multiple parameter based on together - if together is False and multiple: - multiple = "dodge" - logger.warning( - "Multiple should not be used when Together is False. " - "Setting Multiple to 'dodge'." - ) - - # Validate enum-like plotting controls after bins validation. - allowed_multiple = {"layer", "dodge", "stack", "fill"} - allowed_element = {"bars", "step", "poly"} - allowed_stat = { - "count", "frequency", "density", "probability", - "proportion", "percent" - } multiple = str(multiple).strip().lower() element = str(element).strip().lower() stat = str(stat).strip().lower() - if multiple not in allowed_multiple: - raise ValueError( - f'Multiple must be one of {sorted(allowed_multiple)}. ' - f'Received "{multiple}".' - ) - if element not in allowed_element: - raise ValueError( - f'Element must be one of {sorted(allowed_element)}. ' - f'Received "{element}".' - ) - if stat not in allowed_stat: - raise ValueError( - f'Stat must be one of {sorted(allowed_stat)}. ' - f'Received "{stat}".' - ) - # validate figure size parameters - # If "auto" is specified, in facet mode it will be passed as None, - # allowing it to be computed based on facet layout hints automatically. - # In non-facet mode, it will default to 8x6 inches. + # Figure size hints use the explicit template token "auto". 
In facet mode + # it is forwarded as None so core geometry can derive the final figure + # size; in non-facet mode it falls back to 8x6 inches. fig_width = text_to_value( fig_width, default_none_text="auto", @@ -250,9 +221,8 @@ def run_from_json( f'Received "{x_rotate}".' ) - # Validate max_groups parameter if group_by is specified - # max_groups can be a positive integer or "unlimited" (case-insensitive). - # If missing or None, it defaults to 20. + # max_groups uses a strict template token contract: positive integer or + # the exact keyword "unlimited". Missing values keep the default of 20. if group_by: if max_groups != "unlimited": max_groups = text_to_value( @@ -274,7 +244,7 @@ def run_from_json( 'Together and Facet cannot both be True. Please set one to False.' ) - # Validate and canonicalize facet_ncol, allowing for "auto" or positive integers + # facet_ncol uses the explicit template token "auto", or a positive int. facet_ncol = text_to_value( facet_ncol, default_none_text="auto", @@ -301,6 +271,22 @@ def run_from_json( # In facet mode, Figure_Width/Height are passed as layout hints so # visualization can derive panel geometry from total figure size: # panel_width = Figure_Width / ncol, panel_height = Figure_Height / nrow. 
+ hist_kwargs = dict( + element=element, + shrink=shrink, + bins=bins, + alpha=alpha, + stat=stat, + max_groups=max_groups, + facet_ncol=facet_ncol, + facet_fig_width=fig_width, + facet_fig_height=fig_height, + facet_tick_rotation=x_rotate, + ) + # 'multiple' is only applicable when plotting multiple groups together + if group_by and together: + hist_kwargs["multiple"] = multiple + result = histogram( adata=adata, feature=feature, @@ -312,17 +298,7 @@ def run_from_json( x_log_scale=take_X_log, y_log_scale=take_Y_log, facet=facet, - multiple=multiple, - element=element, - shrink=shrink, - bins=bins, - alpha=alpha, - stat=stat, - max_groups=max_groups, - facet_ncol=facet_ncol, - facet_fig_width=fig_width, - facet_fig_height=fig_height, - facet_tick_rotation=x_rotate, + **hist_kwargs, ) fig = result["fig"] diff --git a/src/spac/visualization.py b/src/spac/visualization.py index f567c03c..f162ab79 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -564,7 +564,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, If True, and if group_by != None, create one plot combining all groups. If False, create separate histograms for each group. The appearance of combined histograms can be controlled using the - `multiple` and `element` parameters in **kwargs. + `multiple` and `element` parameters in **kwargs. Separate grouped or + faceted histograms ignore `multiple`. To control how the histograms are normalized (e.g., to divide the histogram by the number of elements in every group), use the `stat` parameter in **kwargs. For example, set `stat="probability"` to show @@ -589,7 +590,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, Additional keyword arguments passed to seaborn histplot function. Key arguments include: - `multiple`: Determines how the subsets of data are displayed - on the same axes. Options include: + on the same axes. Ignored when `group_by` is used with + `together=False`. 
Options include: * "layer": Draws each subset on top of the other without adjustments. * "dodge": Dodges bars for each subset side by side. @@ -761,15 +763,15 @@ def _parse_optional_number( default=None, positive=False, tokens=None, - ): - """Parse an optional numeric hint. - - Returns ``default`` for ``None``, resolves recognized string tokens - before numeric coercion, and optionally enforces finite and positive - values on the parsed result. - """ - if value is None: - return default + ): + """Parse an optional numeric hint. + + Returns ``default`` for ``None``, resolves recognized string tokens + before numeric coercion, and optionally enforces finite and positive + values on the parsed result. + """ + if value is None: + return default if isinstance(value, str): value = value.strip() if tokens and value.lower() in tokens: @@ -1013,6 +1015,9 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): axs.append(ax) else: + # 'multiple' parameter is not applicable + kwargs.pop('multiple', None) + if not facet: fig, ax_array = plt.subplots( n_groups, 1, figsize=(5, 5 * n_groups) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 22db36a7..85187f48 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -547,6 +547,22 @@ def test_default_like_bins_calculation(self): self.assertEqual(len(ax.patches), expected_bins) + def test_grouped_separate_ignores_multiple(self): + """Grouped separate mode should ignore irrelevant multiple settings.""" + fig, axs, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=False, + multiple="fill", + ).values() + self.assertIsInstance(fig, mpl.figure.Figure) + self.assertIsInstance(axs, list) + self.assertEqual( + len(axs), + self.adata.obs["annotation2"].dropna().nunique(), + ) + def test_facet_requires_group_by(self): """Test that facet mode requires group_by parameter""" with 
self.assertRaisesRegex( From 0198eece2cf847f9dfdf5fcca70ebd73f6bf2c6e Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 21 Apr 2026 11:34:18 -0400 Subject: [PATCH 47/57] test(histogram): improve facet unittests - add a thin facet smoke test for numeric annotations from adata.obs - extract a shared long-label AnnData builder for paired geometry tests - improve shared-bin regression setups inline for readability --- tests/test_visualization/test_histogram.py | 142 +++++++++++---------- 1 file changed, 77 insertions(+), 65 deletions(-) diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 85187f48..47b27713 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -52,6 +52,38 @@ def _make_many_groups_adata(self, n_groups=25): var = pd.DataFrame(index=['marker1']) return anndata.AnnData(X, obs=obs, var=var) + def _make_long_label_facet_adata(self, include_short=False): + """Create small categorical facet fixtures for long-label geometry tests.""" + obs = { + 'annotation2': ['g1', 'g1', 'g1', 'g1', + 'g2', 'g2', 'g2', 'g2', + 'g3', 'g3', 'g3', 'g3'], + } + if include_short: + obs['annotation_short'] = pd.Categorical( + ['A', 'B', 'C', 'D'] * 3, + categories=['A', 'B', 'C', 'D'], + ) + obs['annotation_long'] = pd.Categorical( + [ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ] * 3, + categories=[ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ], + ) + return anndata.AnnData( + np.arange(1, 13, dtype=np.float32).reshape(-1, 1), + obs=pd.DataFrame(obs, index=[f'cell_{i}' for i in range(12)]), + var=pd.DataFrame(index=['marker1']), + ) + def test_both_feature_and_annotation(self): err_msg = ("Cannot pass both feature and annotation," " choose one.") @@ -688,6 +720,34 @@ def test_facet_plot_categorical_annotation(self): self.assertEqual(fig._supxlabel.get_text(),
'annotation1') self.assertEqual(fig._supylabel.get_text(), 'Count') + def test_facet_plot_numeric_annotation(self): + """Facet mode should support numeric annotations sourced from obs.""" + adata = self.adata.copy() + adata.obs['annotation_numeric'] = np.arange( + adata.n_obs, + dtype=np.float32, + ) + + fig, axs, _ = histogram( + adata, + annotation='annotation_numeric', + group_by='annotation2', + facet=True, + ).values() + + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + expected_groups = adata.obs['annotation2'].dropna().nunique() + self.assertIsNotNone(fig) + self.assertEqual(len(axs), expected_groups) + + for axis in axs: + self.assertGreater(len(axis.patches), 0) + + self.assertIsNotNone(fig._supxlabel) + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supxlabel.get_text(), 'annotation_numeric') + self.assertEqual(fig._supylabel.get_text(), 'Count') + def test_facet_ncol_layout_hints(self): """Facet ncol supports positive int and documented auto behavior.""" # Explicit two-column layout should create two facet columns. 
@@ -799,34 +859,7 @@ def test_facet_tick_rotation_zero_matches_default_behavior(self): def test_facet_long_label_geometry_adjustment_without_size_hints(self): """Long rotated categorical labels should increase default facet geometry.""" - X = np.arange(1, 13, dtype=np.float32).reshape(-1, 1) - obs = pd.DataFrame( - { - 'annotation_short': pd.Categorical( - ['A', 'B', 'C', 'D'] * 3, - categories=['A', 'B', 'C', 'D'], - ), - 'annotation_long': pd.Categorical( - [ - 'Activated T/B Cell', - 'Cytotoxic T Cell', - 'Follicular Dendritic Cell', - 'Regulatory T Cell', - ] * 3, - categories=[ - 'Activated T/B Cell', - 'Cytotoxic T Cell', - 'Follicular Dendritic Cell', - 'Regulatory T Cell', - ], - ), - 'annotation2': ['g1', 'g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g2', - 'g3', 'g3', 'g3', 'g3'], - }, - index=[f'cell_{i}' for i in range(12)], - ) - var = pd.DataFrame(index=['marker1']) - adata = anndata.AnnData(X, obs=obs, var=var) + adata = self._make_long_label_facet_adata(include_short=True) fig_short, _, _ = histogram( adata, @@ -850,30 +883,7 @@ def test_facet_long_label_geometry_adjustment_without_size_hints(self): def test_facet_long_label_geometry_respects_explicit_size_hints(self): """Explicit facet figure-size hints should remain authoritative.""" - X = np.arange(1, 13, dtype=np.float32).reshape(-1, 1) - obs = pd.DataFrame( - { - 'annotation_long': pd.Categorical( - [ - 'Activated T/B Cell', - 'Cytotoxic T Cell', - 'Follicular Dendritic Cell', - 'Regulatory T Cell', - ] * 3, - categories=[ - 'Activated T/B Cell', - 'Cytotoxic T Cell', - 'Follicular Dendritic Cell', - 'Regulatory T Cell', - ], - ), - 'annotation2': ['g1', 'g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g2', - 'g3', 'g3', 'g3', 'g3'], - }, - index=[f'cell_{i}' for i in range(12)], - ) - var = pd.DataFrame(index=['marker1']) - adata = anndata.AnnData(X, obs=obs, var=var) + adata = self._make_long_label_facet_adata() fig, _, _ = histogram( adata, @@ -893,16 +903,15 @@ def 
test_facet_plot_shared_bins_consistency_numeric(self): """Numeric facets keep shared bins for int/default-like bins inputs.""" # Unbalanced groups: each group occupies only part of the global range. # If bins are computed per-group (bad path), centers/ticks may diverge. - X = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]], - dtype=np.float32) - obs = pd.DataFrame( - { - 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2'], - }, + adata = anndata.AnnData( + np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]], + dtype=np.float32), + obs=pd.DataFrame( + {'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2']}, index=[f'cell_{i}' for i in range(6)], + ), + var=pd.DataFrame(index=['marker1']), ) - var = pd.DataFrame(index=['marker1']) - adata = anndata.AnnData(X, obs=obs, var=var) # Test one explicit and one default-like bins path. for bins_value in [4, None]: @@ -958,19 +967,22 @@ def test_facet_plot_shared_bins_consistency_numeric(self): def test_facet_plot_shared_bins_consistency_categorical(self): """Facet categorical bins stay aligned even with missing labels.""" # Build unbalanced facet groups where some labels are missing per group. 
- X = np.arange(1, 10, dtype=np.float32).reshape(-1, 1) - obs = pd.DataFrame( + adata = anndata.AnnData( + np.arange(1, 10, dtype=np.float32).reshape(-1, 1), + obs=pd.DataFrame( { 'annotation1': pd.Categorical( ['A', 'A', 'B', 'A', 'C', 'C', 'B', 'C', 'A'], categories=['A', 'B', 'C'], ), - 'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2', 'g3', 'g3', 'g3'], + 'annotation2': ['g1', 'g1', 'g1', + 'g2', 'g2', 'g2', + 'g3', 'g3', 'g3'], }, index=[f'cell_{i}' for i in range(9)], + ), + var=pd.DataFrame(index=['marker1']), ) - var = pd.DataFrame(index=['marker1']) - adata = anndata.AnnData(X, obs=obs, var=var) fig, axs, df = histogram( adata, From 860c7a120492c5a115ac4265f9702d023d1df2d2 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 21 Apr 2026 15:30:21 -0400 Subject: [PATCH 48/57] fix(histogram): finalize grouped shared-bin handling - reuse grouped histogram-bin tables across together and facet paths - keep categorical shared-slot padding internal to grouped histogram building - restrict Rice-rule auto-bin fallback to numeric data - fold return-data checks into existing histogram tests --- src/spac/visualization.py | 123 +++++++++++++-------- tests/test_visualization/test_histogram.py | 47 ++++++++ 2 files changed, 122 insertions(+), 48 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index f162ab79..0c1a624c 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -886,25 +886,49 @@ def calculate_histogram(data, bins, bin_edges=None): 'count': counts.values }) - # Function to compute shared bin edges for grouped histograms - def compute_global_bin_edges(data_series, bins): - """Compute shared bin boundaries for grouped histogram paths. + def build_grouped_histogram_table( + plot_data, data_column, group_by, groups, bins + ): + """Build per-group histogram-bin tables for grouped histogram paths.""" + # Determine shared bins across groups for consistent plotting. 
+ data_series = plot_data[data_column] + if pd.api.types.is_numeric_dtype(data_series): + shared_bins = np.histogram_bin_edges(data_series, bins=bins) + elif isinstance(data_series.dtype, pd.CategoricalDtype): + shared_bins = data_series.cat.categories + else: + shared_bins = pd.Index(pd.unique(data_series.dropna())) - Parameters - ---------- - data_series : pandas.Series - Data column used to derive shared bins. - bins : int or sequence - Bin specification forwarded to numpy/seaborn logic. + # Compute histograms for each group using the shared bins. + histograms = [] + for group in groups: + group_data = plot_data.loc[ + plot_data[group_by] == group, data_column + ] + if pd.api.types.is_numeric_dtype(data_series): + group_hist = calculate_histogram( + group_data, bins, bin_edges=shared_bins + ) + else: + # For categorical data, pad missing categories with zero counts + # to ensure consistent plotting. + group_hist = calculate_histogram(group_data, bins) + group_hist = ( + group_hist + .set_index('bin_center') + .reindex(shared_bins) + .rename_axis('bin_center') + .reset_index() + ) + group_hist['count'] = group_hist['count'].fillna(0) + group_hist['bin_left'] = group_hist['bin_center'] + group_hist['bin_right'] = group_hist['bin_center'] + group_hist[group_by] = group + histograms.append(group_hist) - Returns - ------- - numpy.ndarray or pandas.Index - Numeric bin edges, or categorical labels for non-numeric data. - """ - if pd.api.types.is_numeric_dtype(data_series): - return np.histogram_bin_edges(data_series, bins=bins) - return data_series.unique() + # Concatenate all group histograms into a single DataFrame for plotting. 
+ hist_data = pd.concat(histograms, ignore_index=True) + return hist_data, shared_bins # Function to compute maximum tick label length for categorical data def compute_max_tick_label_length(data_series): @@ -986,29 +1010,28 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): if ax is None: fig, ax = plt.subplots() - # Compute global bin edges based on the entire dataset - global_bin_edges = compute_global_bin_edges( - plot_data[data_column], kwargs['bins'] + hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), ) - - hist_data = [] - # Compute histograms for each group separately and combine them - for group in groups: - group_data = plot_data[ - plot_data[group_by] == group - ][data_column] - group_hist = calculate_histogram(group_data, kwargs['bins'], - bin_edges=global_bin_edges) - group_hist[group_by] = group - hist_data.append(group_hist) - hist_data = pd.concat(hist_data, ignore_index=True) + if pd.api.types.is_numeric_dtype(plot_data[data_column]): + kwargs['bins'] = shared_bins.tolist() # Set default values if not provided in kwargs kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - sns.histplot(data=hist_data, x='bin_center', weights='count', - hue=group_by, ax=ax, **kwargs) + sns.histplot( + data=hist_data, + x='bin_center', + weights='count', + hue=group_by, + ax=ax, + **kwargs, + ) # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') @@ -1051,7 +1074,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) # Derive facet geometry based on group count and layout hints - # Keys include: facet_ncol, facet_height, facet_aspect + # Returned layout keys: facet_ncol, facet_height, facet_aspect facet_layout = _derive_facet_geometry( n_groups=n_groups, facet_ncol=facet_ncol, @@ -1061,14 +1084,21 @@ def 
resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): facet_tick_rotation=facet_tick_rotation, ) - # Compute global bins so all facets use consistent boundaries. - global_bin_edges = compute_global_bin_edges( - plot_data[data_column], kwargs['bins'] + # Compute histogram data and shared bins for consistent plotting. + # For non-numeric data, shared_bins will be dropped intentionally. + hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), ) + if pd.api.types.is_numeric_dtype(plot_data[data_column]): + kwargs['bins'] = shared_bins.tolist() # Create the FacetGrid for the histogram hist = sns.FacetGrid( - plot_data, + hist_data, col=group_by, col_wrap=facet_layout['facet_ncol'], height=facet_layout['facet_height'], @@ -1077,15 +1107,13 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): sharey=True, ) - hist_kwargs = kwargs.copy() - # For numeric data, pass global bin edges to ensure consistent binning across facets. 
- if pd.api.types.is_numeric_dtype(plot_data[data_column]): - hist_kwargs['bins'] = global_bin_edges.tolist() - else: - hist_kwargs.pop('bins', None) - # Map the histogram function to the grid - hist.map_dataframe(sns.histplot, x=data_column, **hist_kwargs) + hist.map_dataframe( + sns.histplot, + x='bin_center', + weights='count', + **kwargs, + ) # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column for ax_i in hist.axes.flat: @@ -1114,7 +1142,6 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): facet_fig_height or fig.get_figheight(), ) axs.extend(hist.axes.flat) - hist_data = plot_data else: if ax is None: diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 47b27713..11f72856 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -564,6 +564,11 @@ def test_default_bins_calculation(self): # No bins parameter passed fig, ax, df = histogram(self.adata, feature='marker1').values() self.assertEqual(len(ax.patches), expected_bins) + self.assertEqual(len(df), expected_bins) + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center'}, + ) def test_default_like_bins_calculation(self): """Default-like bins values should use Rice-rule fallback.""" @@ -578,6 +583,11 @@ def test_default_like_bins_calculation(self): ).values() self.assertEqual(len(ax.patches), expected_bins) + self.assertEqual(len(df), expected_bins) + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center'}, + ) def test_grouped_separate_ignores_multiple(self): """Grouped separate mode should ignore irrelevant multiple settings.""" @@ -963,6 +973,25 @@ def test_facet_plot_shared_bins_consistency_numeric(self): np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), "Facet numeric y-ticks should remain shared across panels." 
) + + # Check that the returned DataFrame has expected structure and content + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center', 'annotation2'}, + ) + self.assertNotIn('marker1', df.columns) + self.assertEqual(set(df['annotation2']), {'g1', 'g2'}) + self.assertEqual(df['count'].sum(), adata.n_obs) + grouped_edges = [ + ( + np.round(group_df['bin_left'].to_numpy(), 6), + np.round(group_df['bin_right'].to_numpy(), 6), + ) + for _, group_df in df.groupby('annotation2') + ] + self.assertEqual(len(grouped_edges), 2) + self.assertTrue(np.array_equal(grouped_edges[0][0], grouped_edges[1][0])) + self.assertTrue(np.array_equal(grouped_edges[0][1], grouped_edges[1][1])) def test_facet_plot_shared_bins_consistency_categorical(self): """Facet categorical bins stay aligned even with missing labels.""" @@ -1022,6 +1051,24 @@ def test_facet_plot_shared_bins_consistency_categorical(self): np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), "Facet categorical y-ticks should remain shared across panels." 
) + + # Check that the returned DataFrame has expected structure and content + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center', 'annotation2'}, + ) + self.assertNotIn('annotation1', df.columns) + self.assertEqual(set(df['annotation2']), {'g1', 'g2', 'g3'}) + self.assertEqual(df['count'].sum(), adata.n_obs) + self.assertEqual( + {str(value) for value in df['bin_center'].unique()}, + {'A', 'B', 'C'}, + ) + for _, group_df in df.groupby('annotation2'): + self.assertEqual( + {str(value) for value in group_df['bin_center']}, + {'A', 'B', 'C'}, + ) def test_facet_plot_categorical_annotation_ignores_bins(self): """Facet categorical annotations should ignore custom bins values.""" From 6e6c14a9332d4c2147c86195aeb34a9cfe13ac8f Mon Sep 17 00:00:00 2001 From: Boqiang Date: Wed, 22 Apr 2026 01:11:26 -0400 Subject: [PATCH 49/57] fix(histogram): ignore facet hints outside facets - update documentation for `max_groups` and `facet_ncol` - parse facet figure-size hints only when `facet=True` - add a compact non-facet unittest for ignored hints --- src/spac/visualization.py | 51 +++++++++++++--------- tests/test_visualization/test_histogram.py | 33 ++++++++++++++ 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0c1a624c..e0815274 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -627,16 +627,18 @@ def histogram(adata, feature=None, annotation=None, layer=None, Note, don't pass a numpy array, only python lists or strs/numbers. When `group_by` is provided, this optional key can be passed via `kwargs`: - - `max_groups`: positive int or "unlimited", maximum number of groups to plot. - Default is 20 when omitted. - Caution: `"unlimited"` disables this guardrail, but may lead to - performance issues or unreadable plots with many groups. + - `max_groups`: Controls the group-count guardrail for grouped plots. + Default is 20 when omitted. 
Pass `"unlimited"` to disable this + guardrail, which may lead to performance issues or unreadable plots + with many groups. When `facet=True`, these optional key can be passed via `kwargs` to customize FacetGrid layout: - - `facet_ncol`: positive int or "auto", number of facet columns. - If "auto", the function uses one column for small group counts and - switches to a compact grid for many groups. + - `facet_ncol`: Controls facet column wrapping. + If omitted or passed as `"auto"`, the function uses one column for + small group counts and switches to a compact grid for many groups. + Otherwise, the provided value is used to request the facet column + count. - `facet_fig_width`: float, intended final figure width in inches. - `facet_fig_height`: float, intended final figure height in inches. - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. @@ -816,21 +818,28 @@ def _parse_optional_number( positive=True, tokens={"": None, "auto": None, "none": None}, ) - facet_fig_width = _parse_optional_number( - "facet_fig_width", - kwargs.pop('facet_fig_width', None), - positive=True, - ) - facet_fig_height = _parse_optional_number( - "facet_fig_height", - kwargs.pop('facet_fig_height', None), - positive=True, - ) - if (facet_fig_width is None) != (facet_fig_height is None): - raise ValueError( - "Both facet_fig_width and facet_fig_height must be provided together, " - "or both must be left as None." + facet_fig_width = kwargs.pop('facet_fig_width', None) + facet_fig_height = kwargs.pop('facet_fig_height', None) + if facet: + facet_fig_width = _parse_optional_number( + "facet_fig_width", + facet_fig_width, + positive=True, + ) + facet_fig_height = _parse_optional_number( + "facet_fig_height", + facet_fig_height, + positive=True, ) + if (facet_fig_width is None) != (facet_fig_height is None): + raise ValueError( + "Both facet_fig_width and facet_fig_height must be provided together, " + "or both must be left as None." 
+ ) + else: + # If not faceting, ignore any provided figure size hints. + facet_fig_width = None + facet_fig_height = None facet_tick_rotation = _parse_optional_number( "facet_tick_rotation", kwargs.pop('facet_tick_rotation', None), diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 11f72856..7f3ad365 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -848,6 +848,39 @@ def test_facet_figure_size_hints_require_pair(self): facet_fig_height=3.5, ) + def test_non_facet_figure_size_hints_are_ignored(self): + """Non-facet calls should ignore facet-only figure-size hints.""" + baseline_fig, baseline_ax, _ = histogram( + self.adata, + feature='marker1', + facet=False, + ).values() + + for hint_kwargs in ( + {'facet_fig_width': 8}, + {'facet_fig_height': 5}, + {'facet_fig_width': 8, 'facet_fig_height': 5}, + ): + with self.subTest(hints=hint_kwargs): + fig, ax, _ = histogram( + self.adata, + feature='marker1', + facet=False, + **hint_kwargs, + ).values() + self.assertAlmostEqual( + fig.get_figwidth(), + baseline_fig.get_figwidth(), + places=6, + ) + self.assertAlmostEqual( + fig.get_figheight(), + baseline_fig.get_figheight(), + places=6, + ) + self.assertGreater(len(ax.patches), 0) + self.assertEqual(len(ax.patches), len(baseline_ax.patches)) + def test_facet_tick_rotation_zero_matches_default_behavior(self): """Explicit zero rotation should match omitted rotation behavior.""" fig_default, _, _ = histogram( From 4b6bd12af37c87fe10bc279a7088cc5b679807fb Mon Sep 17 00:00:00 2001 From: Boqiang Date: Wed, 22 Apr 2026 01:25:10 -0400 Subject: [PATCH 50/57] style(histogram): clean whitespace in histogram diff --- src/spac/templates/histogram_template.py | 4 +- src/spac/utils.py | 6 +- src/spac/visualization.py | 2030 ++++++++++---------- tests/test_visualization/test_histogram.py | 4 +- 4 files changed, 1022 insertions(+), 1022 deletions(-) diff --git 
a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 0cf2bd82..50122c78 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -233,7 +233,7 @@ def run_from_json( ) # Validate facet, group_by, and together parameters for logical consistency - if facet: + if facet: if group_by is None: raise ValueError( 'Facet is True but Group_by is not specified. ' @@ -243,7 +243,7 @@ def run_from_json( raise ValueError( 'Together and Facet cannot both be True. Please set one to False.' ) - + # facet_ncol uses the explicit template token "auto", or a positive int. facet_ncol = text_to_value( facet_ncol, diff --git a/src/spac/utils.py b/src/spac/utils.py index 08d89ee6..5c26c284 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -8,12 +8,12 @@ import warnings import numbers from scipy.stats import median_abs_deviation -from typing import Any, List, Optional +from typing import Any, List, Optional # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) def regex_search_list( @@ -1274,4 +1274,4 @@ def compute_summary_qc_stats( "upper_mad", "lower_mad", "upper_quantile", "lower_quantile" ] - ) + ) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index e0815274..abfea309 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -26,17 +26,17 @@ import time import json import re -from typing import Dict, List, Union, Optional +from typing import Dict, List, Union, Optional import matplotlib.colors as mcolors import matplotlib.patches as mpatch from functools import partial from collections import OrderedDict - + # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - + def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, @@ -46,7 +46,7 @@ def 
visualize_2D_scatter( ): """ Visualize 2D data using plt.scatter. - + Parameters ---------- x, y : array-like @@ -74,7 +74,7 @@ def visualize_2D_scatter( Description of what the colors represent. **kwargs Additional keyword arguments passed to plt.scatter. - + Returns ------- fig : matplotlib.figure.Figure @@ -82,7 +82,7 @@ def visualize_2D_scatter( ax : matplotlib.axes.Axes The axes of the plot. """ - + # Input validation if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): raise ValueError("x and y must be array-like.") @@ -90,7 +90,7 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - + # Define color themes themes = { 'fire': plt.get_cmap('inferno'), @@ -103,34 +103,34 @@ def visualize_2D_scatter( 'darkred': ListedColormap(['#8B0000']), 'darkgreen': ListedColormap(['#006400']) } - + if theme and theme not in themes: error_msg = ( f"Theme '{theme}' not recognized. Please use a valid theme." ) raise ValueError(error_msg) cmap = themes.get(theme, plt.get_cmap('viridis')) - + # Determine point size num_points = len(x) if point_size is None: point_size = 5000 / num_points - + # Get figure size and fontsize from kwargs or set defaults fig_width = kwargs.get('fig_width', 10) fig_height = kwargs.get('fig_height', 8) fontsize = kwargs.get('fontsize', 12) - + if ax is None: fig, ax = plt.subplots(figsize=(fig_width, fig_height)) else: fig = ax.figure - + # Plotting logic if labels is not None: # Check if labels are categorical if pd.api.types.is_categorical_dtype(labels): - + # Determine how to access the categories based on # the type of 'labels' if isinstance(labels, pd.Series): @@ -142,16 +142,16 @@ def visualize_2D_scatter( "Expected labels to be of type Series[Categorical] or " "Categorical." 
) - + # Combine colors from multiple colormaps cmap1 = plt.get_cmap('tab20') cmap2 = plt.get_cmap('tab20b') cmap3 = plt.get_cmap('tab20c') colors = cmap1.colors + cmap2.colors + cmap3.colors - + # Use the number of unique clusters to set the colormap length cmap = ListedColormap(colors[:len(unique_clusters)]) - + for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster ax.scatter( @@ -161,7 +161,7 @@ def visualize_2D_scatter( s=point_size ) print(f"Cluster: {cluster}, Points: {np.sum(mask)}") - + if annotate_centers: center_x = np.mean(x[mask]) center_y = np.mean(y[mask]) @@ -175,7 +175,7 @@ def visualize_2D_scatter( bbox_to_anchor=(1.25, 1), # Adjusting position title=f"Color represents: {color_representation}" ) - + else: # If labels are continuous scatter = ax.scatter( @@ -188,20 +188,20 @@ def visualize_2D_scatter( ) else: scatter = ax.scatter(x, y, c='gray', s=point_size, **kwargs) - + # Equal aspect ratio for the axes ax.set_aspect('equal', 'datalim') - + # Set axis labels ax.set_xlabel(x_axis_title) ax.set_ylabel(y_axis_title) - + # Set plot title if plot_title is not None: ax.set_title(plot_title) - + return fig, ax - + def dimensionality_reduction_plot( adata, @@ -214,7 +214,7 @@ def dimensionality_reduction_plot( **kwargs): """ Visualize scatter plot in PCA, t-SNE, UMAP, or associated table. - + Parameters ---------- adata : anndata.AnnData @@ -242,7 +242,7 @@ def dimensionality_reduction_plot( **kwargs Parameters passed to visualize_2D_scatter function, including point_size. - + Returns ------- fig : matplotlib.figure.Figure @@ -250,13 +250,13 @@ def dimensionality_reduction_plot( ax : matplotlib.axes.Axes The axes of the plot. 
""" - + # Check if both annotation and feature are specified, raise error if so if annotation and feature: raise ValueError( "Please specify either an annotation or a feature for coloring, " "not both.") - + # Use utility functions for input validation if layer: check_table(adata, tables=layer) @@ -264,21 +264,21 @@ def dimensionality_reduction_plot( check_annotation(adata, annotations=annotation) if feature: check_feature(adata, features=[feature]) - + # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: valid_methods = ['tsne', 'umap', 'pca'] if method not in valid_methods: raise ValueError("Method should be one of {'tsne', 'umap', 'pca'}" f'. Got:"{method}"') - + key = f'X_{method}' if key not in adata.obsm.keys(): raise ValueError( f"{key} coordinates not found in adata.obsm. " f"Please run {method.upper()} before calling this function." ) - + else: check_table( adata=adata, @@ -286,7 +286,7 @@ def dimensionality_reduction_plot( should_exist=True, associated_table=True ) - + associated_table_shape = adata.obsm[associated_table].shape if associated_table_shape[1] != 2: raise ValueError( @@ -294,12 +294,12 @@ def dimensionality_reduction_plot( f' two dimensions. 
It shape is:"{associated_table_shape}"' ) key = associated_table - + print(f'Running visualization using the coordinates: "{key}"') - + # Extract the 2D coordinates x, y = adata.obsm[key].T - + # Determine coloring scheme if annotation: color_values = adata.obs[annotation].astype('category').values @@ -311,7 +311,7 @@ def dimensionality_reduction_plot( else: color_values = None color_representation = None - + # Set axis titles based on method and color representation if method == 'tsne': x_axis_title = 't-SNE 1' @@ -329,13 +329,13 @@ def dimensionality_reduction_plot( x_axis_title = f'{associated_table} 1' y_axis_title = f'{associated_table} 2' plot_title = f'{associated_table}-{color_representation}' - + # Remove conflicting keys from kwargs kwargs.pop('x_axis_title', None) kwargs.pop('y_axis_title', None) kwargs.pop('plot_title', None) kwargs.pop('color_representation', None) - + fig, ax = visualize_2D_scatter( x=x, y=y, @@ -347,14 +347,14 @@ def dimensionality_reduction_plot( color_representation=color_representation, **kwargs ) - + return fig, ax - + def tsne_plot(adata, color_column=None, ax=None, **kwargs): """ Visualize scatter plot in tSNE basis. - + Parameters ---------- adata : anndata.AnnData @@ -367,7 +367,7 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): If not provided, a new figure and axes will be created. **kwargs Parameters passed to scanpy.pl.tsne function. 
- + Returns ------- fig : matplotlib.figure.Figure @@ -377,221 +377,221 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): """ if not isinstance(adata, anndata.AnnData): raise ValueError("adata must be an AnnData object.") - + if 'X_tsne' not in adata.obsm: err_msg = ("adata.obsm does not contain 'X_tsne', " "perform t-SNE transformation first.") raise ValueError(err_msg) - + # Create a new figure and axes if not provided if ax is None: fig, ax = plt.subplots() else: fig = ax.get_figure() - + if color_column and (color_column not in adata.obs.columns and color_column not in adata.var.columns): err_msg = f"'{color_column}' not found in adata.obs or adata.var." raise KeyError(err_msg) - + # Add color column to the kwargs for the scanpy plot if color_column: kwargs['color'] = color_column - + # Plot the t-SNE sc.pl.tsne(adata, ax=ax, **kwargs) - + return fig, ax - - -def _derive_facet_geometry( - n_groups, - facet_ncol=None, - facet_fig_width=None, - facet_fig_height=None, - facet_tick_max_chars=0, - facet_tick_rotation=0.0, - vertical_threshold=3, - default_height=3.2, - default_aspect=1.25, - min_panel_width=1.8, - min_panel_height=1.6, - min_aspect=0.6, - max_aspect=2.0, -): - """Derive FacetGrid geometry from pre-normalized facet layout hints. - - Parameters - ---------- - n_groups : int - Number of facet panels. Expected to be a positive integer supplied by - the grouped histogram path. - facet_ncol : int or None, optional - Requested facet column count. Positive integers are used directly. - ``None`` falls back to automatic column selection. - facet_fig_width, facet_fig_height : float, optional - Optional total figure-size hints. Geometry is derived from these hints only - when both values are present. - facet_tick_max_chars : int, optional - Maximum observed x tick-label length. Expected positive integer. - Used for adjusting default geometry heuristics when explicit figure-size - hints are absent. 
- ``0`` falls back to the original default geometry without long-label adjustments. - facet_tick_rotation : float, optional - Rotation angle in degrees for x tick labels. Used together with - ``facet_tick_max_chars`` to estimate label burden for default geometry. - vertical_threshold : int, optional - Maximum group count that still prefers a single-column automatic - layout. - default_height : float, optional - Per-facet panel height used when explicit figure-size hints are not - available. - default_aspect : float, optional - Per-facet panel aspect ratio used when explicit figure-size hints are - not available. - min_panel_width, min_panel_height : float, optional - Lower bounds applied to per-panel dimensions when figure-size hints are - converted into FacetGrid geometry. - min_aspect, max_aspect : float, optional - Bounds applied to the derived panel aspect ratio. - - Returns - ------- - dict - Dictionary with keys: - - ``facet_ncol``: positive int, normalized column count clamped to ``n_groups``; - - ``facet_height``: float, FacetGrid-ready per-panel height in inches; - - ``facet_aspect``: float, FacetGrid-ready per-panel aspect ratio. - - Automatic layout uses one column when ``n_groups <= vertical_threshold`` - and otherwise uses ``ceil(sqrt(n_groups))`` columns. When both - normalized figure-size hints are present, the helper converts total - figure size into per-panel geometry using the derived grid shape, - applies the minimum panel-size guardrails, and clips aspect into the - configured range. When figure-size hints are absent, the helper may - increase default facet height/aspect for long rotated categorical - labels to preserve usable bar area. 
- """ - - # Derive facet_ncol when not explicitly provided, and clamp to n_groups - if facet_ncol is None: - if n_groups <= vertical_threshold: - facet_ncol = 1 - else: - facet_ncol = int(np.ceil(np.sqrt(n_groups))) - logging.info( - "Automatic facet_ncol selection: %s columns for %s groups " - "(vertical_threshold=%s).", - facet_ncol, - n_groups, - vertical_threshold, - ) - facet_ncol = max(1, min(int(facet_ncol), n_groups)) - - # Use defaults if figure-size hints are not provided - facet_height = default_height - facet_aspect = default_aspect - - # Derive facet geometry from figure-size hints when both are provided - if facet_fig_width is not None and facet_fig_height is not None: - nrow = int(np.ceil(n_groups / facet_ncol)) - panel_width = max(facet_fig_width / facet_ncol, min_panel_width) - panel_height = max(facet_fig_height / nrow, min_panel_height) - facet_height = panel_height - facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) - - elif facet_tick_max_chars and facet_tick_max_chars > 0: - # For default geometry only, expand panel ratio when long rotated - # labels would otherwise dominate the available plotting area. 
- rotation = float(facet_tick_rotation or 0.0) % 360.0 - rad = np.deg2rad(min(rotation, 180.0)) - rotation_factor = 1.0 + 0.8 * np.sin(rad) - burden = float(facet_tick_max_chars) * rotation_factor - long_label_threshold = 12.0 - - if burden > long_label_threshold: - pressure = min((burden - long_label_threshold) / long_label_threshold, 2.0) - facet_height = default_height * (1.0 + 0.35 * pressure) - facet_aspect = float( - np.clip( - # default_aspect * (1.0 + 0.75 * pressure), - default_aspect * (1.0 - 0.05 * pressure), - min_aspect, - max_aspect, - ) - ) - logging.info( - "Automatic facet geometry adjustment for long x tick labels: " - "max_chars=%s, rotation=%s, facet_height=%.2f, facet_aspect=%.2f.", - facet_tick_max_chars, - facet_tick_rotation, - facet_height, - facet_aspect, - ) - - return { - "facet_ncol": facet_ncol, - "facet_height": facet_height, - "facet_aspect": facet_aspect, - } - - + + +def _derive_facet_geometry( + n_groups, + facet_ncol=None, + facet_fig_width=None, + facet_fig_height=None, + facet_tick_max_chars=0, + facet_tick_rotation=0.0, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, + min_panel_width=1.8, + min_panel_height=1.6, + min_aspect=0.6, + max_aspect=2.0, +): + """Derive FacetGrid geometry from pre-normalized facet layout hints. + + Parameters + ---------- + n_groups : int + Number of facet panels. Expected to be a positive integer supplied by + the grouped histogram path. + facet_ncol : int or None, optional + Requested facet column count. Positive integers are used directly. + ``None`` falls back to automatic column selection. + facet_fig_width, facet_fig_height : float, optional + Optional total figure-size hints. Geometry is derived from these hints only + when both values are present. + facet_tick_max_chars : int, optional + Maximum observed x tick-label length. Expected positive integer. + Used for adjusting default geometry heuristics when explicit figure-size + hints are absent. 
+ ``0`` falls back to the original default geometry without long-label adjustments. + facet_tick_rotation : float, optional + Rotation angle in degrees for x tick labels. Used together with + ``facet_tick_max_chars`` to estimate label burden for default geometry. + vertical_threshold : int, optional + Maximum group count that still prefers a single-column automatic + layout. + default_height : float, optional + Per-facet panel height used when explicit figure-size hints are not + available. + default_aspect : float, optional + Per-facet panel aspect ratio used when explicit figure-size hints are + not available. + min_panel_width, min_panel_height : float, optional + Lower bounds applied to per-panel dimensions when figure-size hints are + converted into FacetGrid geometry. + min_aspect, max_aspect : float, optional + Bounds applied to the derived panel aspect ratio. + + Returns + ------- + dict + Dictionary with keys: + - ``facet_ncol``: positive int, normalized column count clamped to ``n_groups``; + - ``facet_height``: float, FacetGrid-ready per-panel height in inches; + - ``facet_aspect``: float, FacetGrid-ready per-panel aspect ratio. + + Automatic layout uses one column when ``n_groups <= vertical_threshold`` + and otherwise uses ``ceil(sqrt(n_groups))`` columns. When both + normalized figure-size hints are present, the helper converts total + figure size into per-panel geometry using the derived grid shape, + applies the minimum panel-size guardrails, and clips aspect into the + configured range. When figure-size hints are absent, the helper may + increase default facet height/aspect for long rotated categorical + labels to preserve usable bar area. 
+ """ + + # Derive facet_ncol when not explicitly provided, and clamp to n_groups + if facet_ncol is None: + if n_groups <= vertical_threshold: + facet_ncol = 1 + else: + facet_ncol = int(np.ceil(np.sqrt(n_groups))) + logging.info( + "Automatic facet_ncol selection: %s columns for %s groups " + "(vertical_threshold=%s).", + facet_ncol, + n_groups, + vertical_threshold, + ) + facet_ncol = max(1, min(int(facet_ncol), n_groups)) + + # Use defaults if figure-size hints are not provided + facet_height = default_height + facet_aspect = default_aspect + + # Derive facet geometry from figure-size hints when both are provided + if facet_fig_width is not None and facet_fig_height is not None: + nrow = int(np.ceil(n_groups / facet_ncol)) + panel_width = max(facet_fig_width / facet_ncol, min_panel_width) + panel_height = max(facet_fig_height / nrow, min_panel_height) + facet_height = panel_height + facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) + + elif facet_tick_max_chars and facet_tick_max_chars > 0: + # For default geometry only, expand panel ratio when long rotated + # labels would otherwise dominate the available plotting area. 
+ rotation = float(facet_tick_rotation or 0.0) % 360.0 + rad = np.deg2rad(min(rotation, 180.0)) + rotation_factor = 1.0 + 0.8 * np.sin(rad) + burden = float(facet_tick_max_chars) * rotation_factor + long_label_threshold = 12.0 + + if burden > long_label_threshold: + pressure = min((burden - long_label_threshold) / long_label_threshold, 2.0) + facet_height = default_height * (1.0 + 0.35 * pressure) + facet_aspect = float( + np.clip( + # default_aspect * (1.0 + 0.75 * pressure), + default_aspect * (1.0 - 0.05 * pressure), + min_aspect, + max_aspect, + ) + ) + logging.info( + "Automatic facet geometry adjustment for long x tick labels: " + "max_chars=%s, rotation=%s, facet_height=%.2f, facet_aspect=%.2f.", + facet_tick_max_chars, + facet_tick_rotation, + facet_height, + facet_aspect, + ) + + return { + "facet_ncol": facet_ncol, + "facet_height": facet_height, + "facet_aspect": facet_aspect, + } + + def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, - x_log_scale=False, y_log_scale=False, facet=False, **kwargs): + x_log_scale=False, y_log_scale=False, facet=False, **kwargs): """ Plot the histogram of cells based on a specific feature from adata.X or annotation from adata.obs. - + Parameters ---------- adata : anndata.AnnData The AnnData object. - + feature : str, optional Name of continuous feature from adata.X to plot its histogram. - + annotation : str, optional Name of the annotation from adata.obs to plot its histogram. - + layer : str, optional Name of the layer in adata.layers to plot its histogram. - + group_by : str, default None Choose either to group the histogram by another column. - + together : bool, default False If True, and if group_by != None, create one plot combining all groups. If False, create separate histograms for each group. The appearance of combined histograms can be controlled using the - `multiple` and `element` parameters in **kwargs. 
Separate grouped or - faceted histograms ignore `multiple`. + `multiple` and `element` parameters in **kwargs. Separate grouped or + faceted histograms ignore `multiple`. To control how the histograms are normalized (e.g., to divide the histogram by the number of elements in every group), use the `stat` parameter in **kwargs. For example, set `stat="probability"` to show the relative frequencies of each group. - + ax : matplotlib.axes.Axes, optional An existing Axes object to draw the plot onto, optional. - Not supported for grouped-separate (`group_by` with `together=False`) - or facet layouts (`group_by` with `facet=True`). - + Not supported for grouped-separate (`group_by` with `together=False`) + or facet layouts (`group_by` with `facet=True`). + x_log_scale : bool, default False If True, the data will be transformed using np.log1p before plotting, and the x-axis label will be adjusted accordingly. - + y_log_scale : bool, default False If True, the y-axis will be set to log scale. - - facet : bool, default False - If True, group by function outputs facet plots - + + facet : bool, default False + If True, group by function outputs facet plots + **kwargs Additional keyword arguments passed to seaborn histplot function. Key arguments include: - `multiple`: Determines how the subsets of data are displayed - on the same axes. Ignored when `group_by` is used with - `together=False`. Options include: + on the same axes. Ignored when `group_by` is used with + `together=False`. Options include: * "layer": Draws each subset on top of the other without adjustments. * "dodge": Dodges bars for each subset side by side. @@ -603,10 +603,10 @@ def histogram(adata, feature=None, annotation=None, layer=None, * "step": Creates a step line plot without bars. * "poly": Creates a polygon where the bottom edge represents the x-axis and the top edge the histogram's bins. - - `shrink`: Scale bar width relative to each bin's width. - Useful for reducing overlap with `multiple="dodge"`. 
- - `alpha`: Transparency for histogram artists. - 0 is fully transparent and 1 is fully opaque. + - `shrink`: Scale bar width relative to each bin's width. + Useful for reducing overlap with `multiple="dodge"`. + - `alpha`: Transparency for histogram artists. + 0 is fully transparent and 1 is fully opaque. - `log_scale`: Determines if the data should be plotted on a logarithmic scale. - `stat`: Determines the statistical transformation to use on the data @@ -622,42 +622,42 @@ def histogram(adata, feature=None, annotation=None, layer=None, Can be a number (indicating the number of bins) or a list (indicating bin edges). For example, `bins=10` will create 10 bins, while `bins=[0, 1, 2, 3]` will create bins [0,1), [1,2), [2,3]. - If not provided, or if passed as `None`/`"auto"`/`"none"`, - the binning will be determined automatically using the Rice rule. + If not provided, or if passed as `None`/`"auto"`/`"none"`, + the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. - - When `group_by` is provided, this optional key can be passed via `kwargs`: - - `max_groups`: Controls the group-count guardrail for grouped plots. - Default is 20 when omitted. Pass `"unlimited"` to disable this - guardrail, which may lead to performance issues or unreadable plots - with many groups. - - When `facet=True`, these optional key can be passed via `kwargs` - to customize FacetGrid layout: - - `facet_ncol`: Controls facet column wrapping. - If omitted or passed as `"auto"`, the function uses one column for - small group counts and switches to a compact grid for many groups. - Otherwise, the provided value is used to request the facet column - count. - - `facet_fig_width`: float, intended final figure width in inches. - - `facet_fig_height`: float, intended final figure height in inches. - - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. 
- + + When `group_by` is provided, this optional key can be passed via `kwargs`: + - `max_groups`: Controls the group-count guardrail for grouped plots. + Default is 20 when omitted. Pass `"unlimited"` to disable this + guardrail, which may lead to performance issues or unreadable plots + with many groups. + + When `facet=True`, these optional key can be passed via `kwargs` + to customize FacetGrid layout: + - `facet_ncol`: Controls facet column wrapping. + If omitted or passed as `"auto"`, the function uses one column for + small group counts and switches to a compact grid for many groups. + Otherwise, the provided value is used to request the facet column + count. + - `facet_fig_width`: float, intended final figure width in inches. + - `facet_fig_height`: float, intended final figure height in inches. + - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. + Returns ------- A dictionary containing the following: fig : matplotlib.figure.Figure The created figure for the plot. - + axs : matplotlib.axes.Axes or list of Axes The Axes object(s) of the histogram plot(s). Returns a single Axes if only one plot is created, otherwise returns a list of Axes. - + df : pandas.DataFrame DataFrame containing the data used for plotting the histogram. 
- + """ - + # If no feature or annotation is specified, apply default behavior if feature is None and annotation is None: # Default to the first feature in adata.var_names @@ -668,7 +668,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, f"'{feature}'.", UserWarning ) - + # Use utility functions for input validation if layer: check_table(adata, tables=layer) @@ -678,7 +678,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, check_feature(adata, features=feature) if group_by: check_annotation(adata, annotations=group_by) - + # If layer is specified, get the data from that layer if layer: df = pd.DataFrame( @@ -689,15 +689,15 @@ def histogram(adata, feature=None, annotation=None, layer=None, adata.X, index=adata.obs.index, columns=adata.var_names ) layer = 'Original' - + df = pd.concat([df, adata.obs], axis=1) - + if feature and annotation: raise ValueError("Cannot pass both feature and annotation," " choose one.") - + data_column = feature if feature else annotation - + # Check for negative values and apply log1p transformation if # x_log_scale is True if x_log_scale: @@ -708,156 +708,156 @@ def histogram(adata, feature=None, annotation=None, layer=None, x_log_scale = False else: df[data_column] = np.log1p(df[data_column]) - - # If ax is provided, validate input and get figure from it. - # If not, the figure will be created in the plotting branch. + + # If ax is provided, validate input and get figure from it. + # If not, the figure will be created in the plotting branch. if ax is not None: - if group_by and not together: - raise ValueError( - "External ax is only supported for single-axes histogram " - "Please set together=True or remove external ax." - ) + if group_by and not together: + raise ValueError( + "External ax is only supported for single-axes histogram " + "Please set together=True or remove external ax." 
+ ) fig = ax.get_figure() - + axs = [] - + # Prepare the data for plotting plot_data = df.dropna(subset=[data_column]) - + # Bin calculation section # The default bin calculation used by sns.histo take quite # some time to compute for large number of points, # DMAP implemented the Rice rule for bin computation - + def cal_bin_num( num_rows ): bins = max(int(2*(num_rows ** (1/3))), 1) print(f'Automatically calculated number of bins is: {bins}') - - return (bins) - + + return (bins) + num_rows = plot_data.shape[0] - - # Check if bins is not being passed or set to None or "auto" in kwargs. - # If so, the in house algorithm will compute the number of bins - bins_kwarg = kwargs.get('bins', None) - if isinstance(bins_kwarg, str): - bins_kwarg = bins_kwarg.strip().lower() - if bins_kwarg in {'', 'auto', 'none'}: - bins_kwarg = None - if bins_kwarg is None: + + # Check if bins is not being passed or set to None or "auto" in kwargs. + # If so, the in house algorithm will compute the number of bins + bins_kwarg = kwargs.get('bins', None) + if isinstance(bins_kwarg, str): + bins_kwarg = bins_kwarg.strip().lower() + if bins_kwarg in {'', 'auto', 'none'}: + bins_kwarg = None + if bins_kwarg is None: kwargs['bins'] = cal_bin_num(num_rows) - - # Input validation for facet - if facet: - if group_by is None: - raise ValueError("group_by must be specified when facet=True.") - if together: - raise ValueError("Cannot use together=True with facet=True," - " choose one.") - - def _parse_optional_number( - name, - value, - *, - kind=float, - default=None, - positive=False, - tokens=None, - ): - """Parse an optional numeric hint. - - Returns ``default`` for ``None``, resolves recognized string tokens - before numeric coercion, and optionally enforces finite and positive - values on the parsed result. 
- """ - if value is None: - return default - if isinstance(value, str): - value = value.strip() - if tokens and value.lower() in tokens: - return tokens[value.lower()] - expected = ( - f'{"positive " if positive else ""}{kind.__name__}' - f'{" or a supported keyword" if tokens else ""}' - ) - if isinstance(value, bool): - raise ValueError(f'{name} must be a {expected}. Received "{value}".') - try: - parsed = kind(value) - except (TypeError, ValueError): - raise ValueError(f'{name} must be a {expected}. Received "{value}".') - if not math.isfinite(parsed): - raise ValueError( - f'{name} must be a finite {kind.__name__}. ' - f'Received "{value}".' - ) - if positive and parsed <= 0: - raise ValueError( - f'{name} must be a positive {kind.__name__}. ' - f'Received "{value}".' - ) - return parsed - - # Parse max_groups with "unlimited" handling and validation. - max_groups = _parse_optional_number( - "max_groups", - kwargs.pop('max_groups', None), - kind=int, - default=20, - positive=True, - tokens={"unlimited": float('inf')}, - ) - - # Parse facet layout hints so they never leak to seaborn. - facet_ncol = _parse_optional_number( - "facet_ncol", - kwargs.pop('facet_ncol', None), - kind=int, - positive=True, - tokens={"": None, "auto": None, "none": None}, - ) - facet_fig_width = kwargs.pop('facet_fig_width', None) - facet_fig_height = kwargs.pop('facet_fig_height', None) - if facet: - facet_fig_width = _parse_optional_number( - "facet_fig_width", - facet_fig_width, - positive=True, - ) - facet_fig_height = _parse_optional_number( - "facet_fig_height", - facet_fig_height, - positive=True, - ) - if (facet_fig_width is None) != (facet_fig_height is None): - raise ValueError( - "Both facet_fig_width and facet_fig_height must be provided together, " - "or both must be left as None." - ) - else: - # If not faceting, ignore any provided figure size hints. 
- facet_fig_width = None - facet_fig_height = None - facet_tick_rotation = _parse_optional_number( - "facet_tick_rotation", - kwargs.pop('facet_tick_rotation', None), - default=0.0, - ) % 360.0 - + + # Input validation for facet + if facet: + if group_by is None: + raise ValueError("group_by must be specified when facet=True.") + if together: + raise ValueError("Cannot use together=True with facet=True," + " choose one.") + + def _parse_optional_number( + name, + value, + *, + kind=float, + default=None, + positive=False, + tokens=None, + ): + """Parse an optional numeric hint. + + Returns ``default`` for ``None``, resolves recognized string tokens + before numeric coercion, and optionally enforces finite and positive + values on the parsed result. + """ + if value is None: + return default + if isinstance(value, str): + value = value.strip() + if tokens and value.lower() in tokens: + return tokens[value.lower()] + expected = ( + f'{"positive " if positive else ""}{kind.__name__}' + f'{" or a supported keyword" if tokens else ""}' + ) + if isinstance(value, bool): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + try: + parsed = kind(value) + except (TypeError, ValueError): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + if not math.isfinite(parsed): + raise ValueError( + f'{name} must be a finite {kind.__name__}. ' + f'Received "{value}".' + ) + if positive and parsed <= 0: + raise ValueError( + f'{name} must be a positive {kind.__name__}. ' + f'Received "{value}".' + ) + return parsed + + # Parse max_groups with "unlimited" handling and validation. + max_groups = _parse_optional_number( + "max_groups", + kwargs.pop('max_groups', None), + kind=int, + default=20, + positive=True, + tokens={"unlimited": float('inf')}, + ) + + # Parse facet layout hints so they never leak to seaborn. 
+ facet_ncol = _parse_optional_number( + "facet_ncol", + kwargs.pop('facet_ncol', None), + kind=int, + positive=True, + tokens={"": None, "auto": None, "none": None}, + ) + facet_fig_width = kwargs.pop('facet_fig_width', None) + facet_fig_height = kwargs.pop('facet_fig_height', None) + if facet: + facet_fig_width = _parse_optional_number( + "facet_fig_width", + facet_fig_width, + positive=True, + ) + facet_fig_height = _parse_optional_number( + "facet_fig_height", + facet_fig_height, + positive=True, + ) + if (facet_fig_width is None) != (facet_fig_height is None): + raise ValueError( + "Both facet_fig_width and facet_fig_height must be provided together, " + "or both must be left as None." + ) + else: + # If not faceting, ignore any provided figure size hints. + facet_fig_width = None + facet_fig_height = None + facet_tick_rotation = _parse_optional_number( + "facet_tick_rotation", + kwargs.pop('facet_tick_rotation', None), + default=0.0, + ) % 360.0 + # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): """ Compute histogram data for numeric or categorical input. - + Parameters: - data (pd.Series): The input data to be binned. - bins (int or sequence): Number of bins (if numeric) or unique categories (if categorical). - bin_edges (array-like, optional): Predefined bin edges for numeric data. If None, automatic binning is used. - + Returns: - pd.DataFrame: A DataFrame containing the following columns: - `count`: @@ -869,9 +869,9 @@ def calculate_histogram(data, bins, bin_edges=None): - `bin_center`: Center of each bin (for numeric data) or category labels (for categorical data). 
- + """ - + # Check if the data is numeric or categorical if pd.api.types.is_numeric_dtype(data): if bin_edges is None: @@ -894,274 +894,274 @@ def calculate_histogram(data, bins, bin_edges=None): 'bin_right': counts.index, 'count': counts.values }) - - def build_grouped_histogram_table( - plot_data, data_column, group_by, groups, bins - ): - """Build per-group histogram-bin tables for grouped histogram paths.""" - # Determine shared bins across groups for consistent plotting. - data_series = plot_data[data_column] - if pd.api.types.is_numeric_dtype(data_series): - shared_bins = np.histogram_bin_edges(data_series, bins=bins) - elif isinstance(data_series.dtype, pd.CategoricalDtype): - shared_bins = data_series.cat.categories - else: - shared_bins = pd.Index(pd.unique(data_series.dropna())) - - # Compute histograms for each group using the shared bins. - histograms = [] - for group in groups: - group_data = plot_data.loc[ - plot_data[group_by] == group, data_column - ] - if pd.api.types.is_numeric_dtype(data_series): - group_hist = calculate_histogram( - group_data, bins, bin_edges=shared_bins - ) - else: - # For categorical data, pad missing categories with zero counts - # to ensure consistent plotting. - group_hist = calculate_histogram(group_data, bins) - group_hist = ( - group_hist - .set_index('bin_center') - .reindex(shared_bins) - .rename_axis('bin_center') - .reset_index() - ) - group_hist['count'] = group_hist['count'].fillna(0) - group_hist['bin_left'] = group_hist['bin_center'] - group_hist['bin_right'] = group_hist['bin_center'] - group_hist[group_by] = group - histograms.append(group_hist) - - # Concatenate all group histograms into a single DataFrame for plotting. - hist_data = pd.concat(histograms, ignore_index=True) - return hist_data, shared_bins - - # Function to compute maximum tick label length for categorical data - def compute_max_tick_label_length(data_series): - """Compute maximum tick label length for a categorical data series. 
- - Parameters - ---------- - data_series : pandas.Series - Categorical data column used to compute maximum tick label length. - - Returns - ------- - int - Maximum number of characters in the tick labels derived from the - unique categories of the input series. - """ - if isinstance(data_series.dtype, pd.CategoricalDtype): - tick_labels = [ - str(label) for label in data_series.cat.categories - ] - else: - tick_labels = [ - str(label) for label in data_series.dropna().unique().tolist() - ] - return max((len(label) for label in tick_labels), default=0) - - # Function to get axis labels based on log scale and stat parameters - def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): - """Resolve x/y axis labels for histogram rendering. - - Parameters - ---------- - data_column : str - Source column used on the x axis. - x_log_scale : bool - Whether x data has log transform semantics. - y_log_scale : bool - Whether y axis is displayed on log scale. - stat : str - Histogram statistic mode (for example, count, density). - - Returns - ------- - tuple[str, str] - Resolved x-axis and y-axis labels. - """ - xlabel = f'log({data_column})' if x_log_scale else data_column - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability', - "proportion": "Proportion", - "percent": "Percent" - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - return xlabel, ylabel - + + def build_grouped_histogram_table( + plot_data, data_column, group_by, groups, bins + ): + """Build per-group histogram-bin tables for grouped histogram paths.""" + # Determine shared bins across groups for consistent plotting. 
+ data_series = plot_data[data_column] + if pd.api.types.is_numeric_dtype(data_series): + shared_bins = np.histogram_bin_edges(data_series, bins=bins) + elif isinstance(data_series.dtype, pd.CategoricalDtype): + shared_bins = data_series.cat.categories + else: + shared_bins = pd.Index(pd.unique(data_series.dropna())) + + # Compute histograms for each group using the shared bins. + histograms = [] + for group in groups: + group_data = plot_data.loc[ + plot_data[group_by] == group, data_column + ] + if pd.api.types.is_numeric_dtype(data_series): + group_hist = calculate_histogram( + group_data, bins, bin_edges=shared_bins + ) + else: + # For categorical data, pad missing categories with zero counts + # to ensure consistent plotting. + group_hist = calculate_histogram(group_data, bins) + group_hist = ( + group_hist + .set_index('bin_center') + .reindex(shared_bins) + .rename_axis('bin_center') + .reset_index() + ) + group_hist['count'] = group_hist['count'].fillna(0) + group_hist['bin_left'] = group_hist['bin_center'] + group_hist['bin_right'] = group_hist['bin_center'] + group_hist[group_by] = group + histograms.append(group_hist) + + # Concatenate all group histograms into a single DataFrame for plotting. + hist_data = pd.concat(histograms, ignore_index=True) + return hist_data, shared_bins + + # Function to compute maximum tick label length for categorical data + def compute_max_tick_label_length(data_series): + """Compute maximum tick label length for a categorical data series. + + Parameters + ---------- + data_series : pandas.Series + Categorical data column used to compute maximum tick label length. + + Returns + ------- + int + Maximum number of characters in the tick labels derived from the + unique categories of the input series. 
+ """ + if isinstance(data_series.dtype, pd.CategoricalDtype): + tick_labels = [ + str(label) for label in data_series.cat.categories + ] + else: + tick_labels = [ + str(label) for label in data_series.dropna().unique().tolist() + ] + return max((len(label) for label in tick_labels), default=0) + + # Function to get axis labels based on log scale and stat parameters + def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): + """Resolve x/y axis labels for histogram rendering. + + Parameters + ---------- + data_column : str + Source column used on the x axis. + x_log_scale : bool + Whether x data has log transform semantics. + y_log_scale : bool + Whether y axis is displayed on log scale. + stat : str + Histogram statistic mode (for example, count, density). + + Returns + ------- + tuple[str, str] + Resolved x-axis and y-axis labels. + """ + xlabel = f'log({data_column})' if x_log_scale else data_column + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability', + "proportion": "Proportion", + "percent": "Percent" + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + return xlabel, ylabel + # Plotting with or without grouping if group_by: groups = df[group_by].dropna().unique().tolist() n_groups = len(groups) - + if n_groups == 0: raise ValueError("There must be at least one group to create a" " histogram.") - elif n_groups > max_groups: - raise ValueError( - "The number of groups in `group_by` exceeds `max_groups`: " - f"found {n_groups}, threshold {max_groups}.\n" - "Please reduce/bucket groups or use another grouping column, or " - "pass a larger `max_groups`.\n" - "See `kwargs` documentation for more details." 
- ) - + elif n_groups > max_groups: + raise ValueError( + "The number of groups in `group_by` exceeds `max_groups`: " + f"found {n_groups}, threshold {max_groups}.\n" + "Please reduce/bucket groups or use another grouping column, or " + "pass a larger `max_groups`.\n" + "See `kwargs` documentation for more details." + ) + if together: - if ax is None: - fig, ax = plt.subplots() - - hist_data, shared_bins = build_grouped_histogram_table( - plot_data, - data_column, - group_by, - groups, - bins=kwargs.pop('bins'), - ) + if ax is None: + fig, ax = plt.subplots() + + hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), + ) if pd.api.types.is_numeric_dtype(plot_data[data_column]): - kwargs['bins'] = shared_bins.tolist() - + kwargs['bins'] = shared_bins.tolist() + # Set default values if not provided in kwargs kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - - sns.histplot( - data=hist_data, - x='bin_center', - weights='count', - hue=group_by, - ax=ax, - **kwargs, - ) + + sns.histplot( + data=hist_data, + x='bin_center', + weights='count', + hue=group_by, + ax=ax, + **kwargs, + ) # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') axs.append(ax) - + else: - # 'multiple' parameter is not applicable - kwargs.pop('multiple', None) - - if not facet: - fig, ax_array = plt.subplots( - n_groups, 1, figsize=(5, 5 * n_groups) - ) - - # Convert a single Axes object to a list - # Ensure ax_array is always iterable - if n_groups == 1: - ax_array = [ax_array] + # 'multiple' parameter is not applicable + kwargs.pop('multiple', None) + + if not facet: + fig, ax_array = plt.subplots( + n_groups, 1, figsize=(5, 5 * n_groups) + ) + + # Convert a single Axes object to a list + # Ensure ax_array is always iterable + if n_groups == 1: + ax_array = [ax_array] else: - ax_array = ax_array.flatten() - - for i, ax_i in enumerate(ax_array): - group_data = 
plot_data[plot_data[group_by] == - groups[i]][data_column] - hist_data = calculate_histogram(group_data, kwargs['bins']) - - sns.histplot(data=hist_data, x="bin_center", ax=ax_i, - weights='count', **kwargs) - # If plotting feature specify which layer - if feature: - ax_i.set_title(f'{groups[i]} with Layer: {layer}') - else: - ax_i.set_title(f'{groups[i]}') - axs.append(ax_i) - - else: # Facet option - # Compute max label length only when not explicitly provided. - facet_tick_max_chars = 0 - if not pd.api.types.is_numeric_dtype(plot_data[data_column]): - facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) - - # Derive facet geometry based on group count and layout hints - # Returned layout keys: facet_ncol, facet_height, facet_aspect - facet_layout = _derive_facet_geometry( - n_groups=n_groups, - facet_ncol=facet_ncol, - facet_fig_width=facet_fig_width, - facet_fig_height=facet_fig_height, - facet_tick_max_chars=facet_tick_max_chars, - facet_tick_rotation=facet_tick_rotation, - ) - - # Compute histogram data and shared bins for consistent plotting. - # For non-numeric data, shared_bins will be dropped intentionally. 
- hist_data, shared_bins = build_grouped_histogram_table( - plot_data, - data_column, - group_by, - groups, - bins=kwargs.pop('bins'), - ) - if pd.api.types.is_numeric_dtype(plot_data[data_column]): - kwargs['bins'] = shared_bins.tolist() - - # Create the FacetGrid for the histogram - hist = sns.FacetGrid( - hist_data, - col=group_by, - col_wrap=facet_layout['facet_ncol'], - height=facet_layout['facet_height'], - aspect=facet_layout['facet_aspect'], - sharex=True, - sharey=True, - ) - - # Map the histogram function to the grid - hist.map_dataframe( - sns.histplot, - x='bin_center', - weights='count', - **kwargs, - ) - - # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column - for ax_i in hist.axes.flat: - ax_i.tick_params(axis='x', labelbottom=True) - ax_i.tick_params(axis='y', labelleft=True) - - # Set background color and grid for better readability - for ax_i in hist.axes.flat: - ax_i.set_facecolor('#f2f2f2') - ax_i.grid(True, which='major', axis='both') - - # Titles for each facet - hist.set_titles("{col_name}") - - # Adjust margins for readability across layouts. 
- hist.figure.subplots_adjust(left=.1, - top=0.9, - bottom=0.12, - hspace=0.35, - wspace=0.2) - - # Pass the figure and axes to the output for further customization - fig = hist.figure - fig.set_size_inches( - facet_fig_width or fig.get_figwidth(), - facet_fig_height or fig.get_figheight(), - ) - axs.extend(hist.axes.flat) - + ax_array = ax_array.flatten() + + for i, ax_i in enumerate(ax_array): + group_data = plot_data[plot_data[group_by] == + groups[i]][data_column] + hist_data = calculate_histogram(group_data, kwargs['bins']) + + sns.histplot(data=hist_data, x="bin_center", ax=ax_i, + weights='count', **kwargs) + # If plotting feature specify which layer + if feature: + ax_i.set_title(f'{groups[i]} with Layer: {layer}') + else: + ax_i.set_title(f'{groups[i]}') + axs.append(ax_i) + + else: # Facet option + # Compute max label length only when not explicitly provided. + facet_tick_max_chars = 0 + if not pd.api.types.is_numeric_dtype(plot_data[data_column]): + facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) + + # Derive facet geometry based on group count and layout hints + # Returned layout keys: facet_ncol, facet_height, facet_aspect + facet_layout = _derive_facet_geometry( + n_groups=n_groups, + facet_ncol=facet_ncol, + facet_fig_width=facet_fig_width, + facet_fig_height=facet_fig_height, + facet_tick_max_chars=facet_tick_max_chars, + facet_tick_rotation=facet_tick_rotation, + ) + + # Compute histogram data and shared bins for consistent plotting. + # For non-numeric data, shared_bins will be dropped intentionally. 
+ hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), + ) + if pd.api.types.is_numeric_dtype(plot_data[data_column]): + kwargs['bins'] = shared_bins.tolist() + + # Create the FacetGrid for the histogram + hist = sns.FacetGrid( + hist_data, + col=group_by, + col_wrap=facet_layout['facet_ncol'], + height=facet_layout['facet_height'], + aspect=facet_layout['facet_aspect'], + sharex=True, + sharey=True, + ) + + # Map the histogram function to the grid + hist.map_dataframe( + sns.histplot, + x='bin_center', + weights='count', + **kwargs, + ) + + # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column + for ax_i in hist.axes.flat: + ax_i.tick_params(axis='x', labelbottom=True) + ax_i.tick_params(axis='y', labelleft=True) + + # Set background color and grid for better readability + for ax_i in hist.axes.flat: + ax_i.set_facecolor('#f2f2f2') + ax_i.grid(True, which='major', axis='both') + + # Titles for each facet + hist.set_titles("{col_name}") + + # Adjust margins for readability across layouts. 
+ hist.figure.subplots_adjust(left=.1, + top=0.9, + bottom=0.12, + hspace=0.35, + wspace=0.2) + + # Pass the figure and axes to the output for further customization + fig = hist.figure + fig.set_size_inches( + facet_fig_width or fig.get_figwidth(), + facet_fig_height or fig.get_figheight(), + ) + axs.extend(hist.axes.flat) + else: - if ax is None: - fig, ax = plt.subplots() - + if ax is None: + fig, ax = plt.subplots() + # Precompute histogram data for single plot hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): ax.set_xlim(hist_data['bin_left'].min(), - hist_data['bin_right'].max()) - + hist_data['bin_right'].max()) + sns.histplot( data=hist_data, x='bin_center', @@ -1169,80 +1169,80 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ax=ax, **kwargs ) - + # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') axs.append(ax) - - stat = kwargs.get('stat', 'count') - xlabel, ylabel = resolve_hist_axis_labels( - data_column=data_column, - x_log_scale=x_log_scale, - y_log_scale=y_log_scale, - stat=stat, - ) - - axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] - for ax in axes: - # Set axis scales if y_log_scale is True - if y_log_scale: - ax.set_yscale('log') - - if facet: - ax.set_xlabel('') - ax.set_ylabel('') - else: - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - - # Set a common x and y label for the entire figure if facet is True - if facet and fig is not None: - fig.supxlabel(xlabel) - fig.supylabel(ylabel) - + + stat = kwargs.get('stat', 'count') + xlabel, ylabel = resolve_hist_axis_labels( + data_column=data_column, + x_log_scale=x_log_scale, + y_log_scale=y_log_scale, + stat=stat, + ) + + axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] + for ax in axes: + # Set axis scales if y_log_scale is True + if y_log_scale: + ax.set_yscale('log') + + if facet: + ax.set_xlabel('') + ax.set_ylabel('') + else: 
+ ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + # Set a common x and y label for the entire figure if facet is True + if facet and fig is not None: + fig.supxlabel(xlabel) + fig.supylabel(ylabel) + if len(axs) == 1: return {"fig": fig, "axs": axs[0], "df": hist_data} else: return {"fig": fig, "axs": axs, "df": hist_data} - - + + def heatmap(adata, column, layer=None, **kwargs): """ Plot the heatmap of the mean feature of cells that belong to a `column`. - + Parameters ---------- adata : anndata.AnnData The AnnData object. - + column : str Name of member of adata.obs to plot the histogram. - + layer : str, default None The name of the `adata` layer to use to calculate the mean feature. - + **kwargs: Parameters passed to seaborn heatmap function. - + Returns ------- pandas.DataFrame A dataframe tha has the labels as indexes the mean feature for every marker. - + matplotlib.figure.Figure The figure of the heatmap. - + matplotlib.axes._subplots.AxesSubplot The AsxesSubplot of the heatmap. - + """ features = adata.to_df(layer=layer) labels = adata.obs[column] grouped = pd.concat([features, labels], axis=1).groupby(column) mean_feature = grouped.mean() - + n_rows = len(mean_feature) n_cols = len(mean_feature.columns) fig, ax = plt.subplots(figsize=(n_cols * 1.5, n_rows * 1.5)) @@ -1257,18 +1257,18 @@ def heatmap(adata, column, layer=None, **kwargs): linewidth=.5, annot_kws={"fontsize": 10}, **kwargs) - + ax.tick_params(axis='both', labelsize=25) ax.set_ylabel(column, size=25) - + return mean_feature, fig, ax - + def hierarchical_heatmap(adata, annotation, features=None, layer=None, cluster_feature=False, cluster_annotations=False, standard_scale=None, z_score="annotation", swap_axes=False, rotate_label=False, **kwargs): - + """ Generates a hierarchical clustering heatmap and dendrogram. 
By default, the dataset is assumed to have features as columns and @@ -1276,7 +1276,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, and for each group, the average expression intensity of each feature (e.g., protein or marker) is computed. The heatmap is plotted using seaborn's clustermap. - + Parameters ---------- adata : anndata.AnnData @@ -1336,7 +1336,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, - `metric` : str The distance metric to use for the hierarchy. Defaults to 'euclidean' in the function. - + Returns ------- mean_intensity : pandas.DataFrame @@ -1349,7 +1349,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, rows and columns. These linkage matrices can be used to generate dendrograms with tools like scipy's dendrogram function. This offers flexibility in customizing and plotting dendrograms as needed. - + Examples -------- import matplotlib.pyplot as plt @@ -1359,7 +1359,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, X = pd.DataFrame([[1, 2], [3, 4]], columns=['gene1', 'gene2']) annotation = pd.DataFrame(['type1', 'type2'], columns=['cell_type']) all_data = anndata.AnnData(X=X, obs=annotation) - + mean_intensity, clustergrid, dendrogram_data = hierarchical_heatmap( all_data, "cell_type", @@ -1369,15 +1369,15 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, cluster_feature=False, cluster_annotations=True ) - + # To display a standalone dendrogram using the returned linkage matrix: import scipy.cluster.hierarchy as sch import numpy as np import matplotlib.pyplot as plt - + # Convert the linkage data to type double dendro_col_data = np.array(dendrogram_data['col_linkage'], dtype=np.double) - + # Ensure the linkage matrix has at least two dimensions and more than one linkage if dendro_col_data.ndim == 2 and dendro_col_data.shape[0] > 1: @@ -1388,22 +1388,22 @@ def hierarchical_heatmap(adata, annotation, features=None, 
layer=None, else: print("Insufficient data to plot a dendrogram.") """ - + # Use utility functions to check inputs check_annotation(adata, annotations=annotation) if features: check_feature(adata, features=features) if layer: check_table(adata, tables=layer) - + # Raise an error if there are any NaN values in the annotation column if adata.obs[annotation].isna().any(): raise ValueError("NaN values found in annotation column.") - + # Convert the observation column to categorical if it's not already if not pd.api.types.is_categorical_dtype(adata.obs[annotation]): adata.obs[annotation] = adata.obs[annotation].astype('category') - + # Calculate mean intensity if layer: intensities = pd.DataFrame( @@ -1413,25 +1413,25 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, ) else: intensities = adata.to_df() - + labels = adata.obs[annotation] grouped = pd.concat([intensities, labels], axis=1).groupby(annotation) mean_intensity = grouped.mean() - + # If swap_axes is True, transpose the mean_intensity if swap_axes: mean_intensity = mean_intensity.T - + # Map z_score based on user's input and the state of swap_axes if z_score == "annotation": z_score = 0 if not swap_axes else 1 elif z_score == "feature": z_score = 1 if not swap_axes else 0 - + # Subset the mean_intensity DataFrame based on selected features if features is not None and len(features) > 0: mean_intensity = mean_intensity.loc[features] - + # Determine clustering behavior based on swap_axes if swap_axes: row_cluster = cluster_feature # Rows are features @@ -1439,7 +1439,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, else: row_cluster = cluster_annotations # Rows are annotations col_cluster = cluster_feature # Columns are features - + # Use seaborn's clustermap for hierarchical clustering and # heatmap visualization. 
clustergrid = sns.clustermap( @@ -1452,29 +1452,29 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, col_cluster=col_cluster, **kwargs ) - + # Rotate x-axis tick labels if rotate_label is True if rotate_label: plt.setp(clustergrid.ax_heatmap.get_xticklabels(), rotation=45) - + # Extract the dendrogram data for return dendro_row_data = None dendro_col_data = None - + if clustergrid.dendrogram_row: dendro_row_data = clustergrid.dendrogram_row.linkage - + if clustergrid.dendrogram_col: dendro_col_data = clustergrid.dendrogram_col.linkage - + # Define the dendrogram_data dictionary dendrogram_data = { 'row_linkage': dendro_row_data, 'col_linkage': dendro_col_data } - + return mean_intensity, clustergrid, dendrogram_data - + def threshold_heatmap( adata, feature_cutoffs, annotation, layer=None, swap_axes=False, **kwargs @@ -1482,7 +1482,7 @@ def threshold_heatmap( """ Creates a heatmap for each feature, categorizing intensities into low, medium, and high based on provided cutoffs. - + Parameters ---------- adata : anndata.AnnData @@ -1501,7 +1501,7 @@ def threshold_heatmap( If True, swaps the axes of the heatmap. **kwargs : keyword arguments Additional keyword arguments to pass to scanpy's heatmap function. - + Returns ------- Dictionary of :class:`~matplotlib.axes.Axes` @@ -1511,22 +1511,22 @@ def threshold_heatmap( Potential Keys includes: 'groupby_ax', 'dendrogram_ax', and 'gene_groups_ax'. """ - + # Use utility functions for input validation check_table(adata, tables=layer) check_annotation(adata, annotations=annotation) if feature_cutoffs: check_feature(adata, features=list(feature_cutoffs.keys())) - + # Assert annotation is a string if not isinstance(annotation, str): err_type = type(annotation).__name__ err_msg = (f'Annotation should be string. 
Got {err_type}.') raise TypeError(err_msg) - + if not isinstance(feature_cutoffs, dict): raise TypeError("feature_cutoffs should be a dictionary.") - + for key, value in feature_cutoffs.items(): if not (isinstance(value, tuple) and len(value) == 2): raise ValueError( @@ -1537,13 +1537,13 @@ def threshold_heatmap( raise ValueError(f"Low cutoff for {key} should not be NaN.") if math.isnan(value[1]): raise ValueError(f"High cutoff for {key} should not be NaN.") - + adata.uns['feature_cutoffs'] = feature_cutoffs - + intensity_df = pd.DataFrame( index=adata.obs_names, columns=feature_cutoffs.keys() ) - + for feature, cutoffs in feature_cutoffs.items(): low_cutoff, high_cutoff = cutoffs feature_values = ( @@ -1554,17 +1554,17 @@ def threshold_heatmap( intensity_df.loc[(feature_values > low_cutoff) & (feature_values <= high_cutoff), feature] = 1 intensity_df.loc[feature_values > high_cutoff, feature] = 2 - + intensity_df = intensity_df.astype(int) adata.layers["intensity"] = intensity_df.to_numpy() adata.obs[annotation] = adata.obs[annotation].astype('category') - + color_map = {0: (0/255, 0/255, 139/255), 1: 'green', 2: 'yellow'} colors = [color_map[i] for i in range(3)] cmap = ListedColormap(colors) - + norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N) - + heatmap_plot = sc.pl.heatmap( adata, var_names=intensity_df.columns, @@ -1577,18 +1577,18 @@ def threshold_heatmap( swap_axes=swap_axes, **kwargs ) - + # Print the keys of the heatmap_plot dictionary print("Keys of heatmap_plot:", heatmap_plot.keys()) - + # Get the main heatmap axis from the available keys heatmap_ax = heatmap_plot.get('heatmap_ax') - + # If 'heatmap_ax' key does not exist, access the first axis available if heatmap_ax is None: heatmap_ax = next(iter(heatmap_plot.values())) print("Heatmap Axes:", heatmap_ax) - + # Find the colorbar associated with the heatmap cbar = None for child in heatmap_ax.get_children(): @@ -1599,7 +1599,7 @@ def threshold_heatmap( print("No colorbar found in the plot.") return 
print("Colorbar:", cbar) - + new_ticks = [0, 1, 2] new_labels = ['Low', 'Medium', 'High'] cbar.set_ticks(new_ticks) @@ -1608,9 +1608,9 @@ def threshold_heatmap( cbar.ax.set_position( [pos_heatmap.x1 + 0.02, pos_heatmap.y0, 0.02, pos_heatmap.height] ) - + return heatmap_plot - + def spatial_plot( adata, @@ -1630,7 +1630,7 @@ def spatial_plot( ---------- adata : anndata.AnnData The AnnData object containing target feature and spatial coordinates. - + spot_size : int The size of spot on the spatial plot. alpha : float @@ -1658,7 +1658,7 @@ def spatial_plot( ------- Single or a list of class:`~matplotlib.axes.Axes`. """ - + err_msg_layer = "The 'layer' parameter must be a string, " + \ f"got {str(type(layer))}" err_msg_feature = "The 'feature' parameter must be a string, " + \ @@ -1681,86 +1681,86 @@ def spatial_plot( f"got {str(type(vmax))}" err_msg_ax = "The 'ax' parameter must be an instance " + \ f"of matplotlib.axes.Axes, got {str(type(ax))}" - + if adata is None: raise ValueError("The input dataset must not be None.") - + if not isinstance(adata, anndata.AnnData): err_msg_adata = "The 'adata' parameter must be an " + \ f"instance of anndata.AnnData, got {str(type(adata))}." 
raise ValueError(err_msg_adata) - + if layer is not None and not isinstance(layer, str): raise ValueError(err_msg_layer) - + if layer is not None and layer not in adata.layers.keys(): err_msg_layer_exist = f"Layer {layer} does not exists, " + \ f"available layers are {str(adata.layers.keys())}" raise ValueError(err_msg_layer_exist) - + if feature is not None and not isinstance(feature, str): raise ValueError(err_msg_feature) - + if annotation is not None and not isinstance(annotation, str): raise ValueError(err_msg_annotation) - + if annotation is not None and feature is not None: raise ValueError(err_msg_feat_annotation_coe) - + if annotation is None and feature is None: raise ValueError(err_msg_feat_annotation_non) - + if 'spatial' not in adata.obsm_keys(): err_msg = "Spatial coordinates not found in the 'obsm' attribute." raise ValueError(err_msg) - + # Extract annotation name annotation_names = adata.obs.columns.tolist() annotation_names_str = ", ".join(annotation_names) - + if annotation is not None and annotation not in annotation_names: error_text = f'The annotation "{annotation}"' + \ 'not found in the dataset.' + \ f" Existing annotations are: {annotation_names_str}" raise ValueError(error_text) - + # Extract feature name if layer is None: layer_process = adata.X else: layer_process = adata.layers[layer] - + feature_names = adata.var_names.tolist() - + if feature is not None and feature not in feature_names: error_text = f"Feature {feature} not found," + \ " please check the sample metadata." 
raise ValueError(error_text) - + if not isinstance(spot_size, int): raise ValueError(err_msg_spot_size) - + if not isinstance(alpha, float): raise ValueError(err_msg_alpha_type) - + if not (0 <= alpha <= 1): raise ValueError(err_msg_alpha_value) - + if vmin != -999 and not ( isinstance(vmin, float) or isinstance(vmin, int) ): raise ValueError(err_msg_vmin) - + if vmax != -999 and not ( isinstance(vmax, float) or isinstance(vmax, int) ): raise ValueError(err_msg_vmax) - + if ax is not None and not isinstance(ax, plt.Axes): raise ValueError(err_msg_ax) - + if feature is not None: - + feature_index = feature_names.index(feature) feature_annotation = feature + "spatial_plot" if vmin == -999: @@ -1773,11 +1773,11 @@ def spatial_plot( color_region = annotation vmin = None vmax = None - + if ax is None: fig = plt.figure() ax = fig.add_subplot(1, 1, 1) - + ax = sc.pl.spatial( adata=adata, layer=layer, @@ -1789,9 +1789,9 @@ def spatial_plot( ax=ax, show=False, **kwargs) - + return ax - + def boxplot(adata, annotation=None, second_annotation=None, layer=None, ax=None, features=None, log_scale=False, **kwargs): @@ -1799,32 +1799,32 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, Create a boxplot visualization of the features in the passed adata object. This function offers flexibility in how the boxplots are displayed, based on the arguments provided. - + Parameters ---------- adata : anndata.AnnData The AnnData object. - + annotation : str, optional Annotation to determine if separate plots are needed for every label. - + second_annotation : str, optional Second annotation to further divide the data. - + layer : str, optional The name of the matrix layer to use. If not provided, uses the main data matrix adata.X. - + ax : matplotlib.axes.Axes, optional An existing Axes object to draw the plot onto, optional. - + features : list, optional List of feature names to be plotted. If not provided, all features will be plotted. 
- + log_scale : bool, optional If True, the Y-axis will be in log scale. Default is False. - + **kwargs Additional arguments to pass to seaborn.boxplot. Key arguments include: @@ -1837,7 +1837,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, ------- fig, ax : matplotlib.figure.Figure, matplotlib.axes.Axes The created figure and axes for the plot. - + Examples -------- - Multiple features boxplot: boxplot(adata, features=['GeneA','GeneB']) @@ -1848,7 +1848,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, - Nested grouping by two annotations: boxplot(adata, features=['GeneA'], annotation='cell_type', second_annotation='treatment') """ - + # Use utility functions to check inputs print("Calculating Box Plot...") if layer: @@ -1859,75 +1859,75 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, check_annotation(adata, annotations=second_annotation) if features: check_feature(adata, features=features) - + if 'orient' not in kwargs: kwargs['orient'] = 'v' - + if kwargs['orient'] != 'v': v_orient = False else: v_orient = True - + # Validate ax instance if ax and not isinstance(ax, plt.Axes): raise TypeError("Input 'ax' must be a matplotlib.axes.Axes object.") - + # Use the specified layer if provided if layer: data_matrix = adata.layers[layer] else: data_matrix = adata.X - + # Create a DataFrame from the data matrix with features as columns df = pd.DataFrame(data_matrix, columns=adata.var_names) - + # Add annotations to the DataFrame if provided if annotation: df[annotation] = adata.obs[annotation].values if second_annotation: df[second_annotation] = adata.obs[second_annotation].values - + # If features is None, set it to all available features if features is None: features = adata.var_names.tolist() - + df = df[ features + ([annotation] if annotation else []) + ([second_annotation] if second_annotation else []) ] - + # Check for negative values if log_scale and (df[features] < 0).any().any(): 
print( "There are negative values in this data, disabling the log scale." ) log_scale = False - + # Apply log1p transformation if log_scale is True if log_scale: df[features] = np.log1p(df[features]) - + # Create the plot if ax: fig = ax.get_figure() else: fig, ax = plt.subplots(figsize=(10, 5)) - + # Plotting logic based on provided annotations if annotation and second_annotation: if v_orient: sns.boxplot(data=df, y=features[0], x=annotation, hue=second_annotation, ax=ax, **kwargs) - + else: sns.boxplot(data=df, y=annotation, x=features[0], hue=second_annotation, ax=ax, **kwargs) - + title_str = f"Nested Grouping by {annotation} and {second_annotation}" - + ax.set_title(title_str) - + elif annotation: if len(features) > 1: # Reshape the dataframe to long format for visualization @@ -1947,7 +1947,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, sns.boxplot(data=df, x=features[0], y=annotation, ax=ax, **kwargs) ax.set_title(f"Grouped by {annotation}") - + else: if len(features) > 1: if v_orient: @@ -1967,7 +1967,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, ax.set_yticks([0]) # Set a single tick for the single feature ax.set_yticklabels([features[0]]) # Set the label for the tick ax.set_title("Single Boxplot") - + # Set x and y-axis labels if v_orient: xlabel = annotation if annotation else 'Feature' @@ -1979,12 +1979,12 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, ylabel = annotation if annotation else 'Feature' ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) - + plt.xticks(rotation=90) plt.tight_layout() - + return fig, ax, df - + def boxplot_interactive( adata, @@ -2007,59 +2007,59 @@ def boxplot_interactive( ): """ Generate a boxplot for given features from an AnnData object. - + This function visualizes the distribution of gene expression (or other features) across different annotations in the provided data. 
It can handle various options such as log-transformation, feature selection, and handling of outliers. - + Parameters ----------- adata : AnnData An AnnData object containing the data to plot. The expression matrix is accessed via `adata.X` or `adata.layers[layer]`, and annotations are taken from `adata.obs`. - + annotation : str, optional The name of the annotation column (e.g., cell type or sample condition) from `adata.obs` used to group the features. If `None`, no grouping is applied. - + layer : str, optional The name of the layer from `adata.layers` to use. If `None`, `adata.X` is used. - + ax : plotly.graph_objects.Figure, optional The figure to plot the boxplot onto. If `None`, a new figure is created. - + features : list of str, optional The list of features (genes) to plot. If `None`, all features are included. - + showfliers : {None, "downsample", "all"}, default = None If 'all', all outliers are displayed in the boxplot. If 'downsample', when num outliers is >10k, they are downsampled to 10% of the original count. If None, outliers are hidden. - + log_scale : bool, default=False If True, the log1p transformation is applied to the features before plotting. This option is disabled if negative values are found in the features. - + orient : {"v", "h"}, default="v" The orientation of the boxplots: "v" for vertical, "h" for horizontal. - + figure_width : int, optional Width of the figure in inches. Default is 3.2. - + figure_height : int, optional Height of the figure in inches. Default is 2. - + figure_dpi : int, optional DPI (dots per inch) for the figure. Default is 200. - + defined_color_map : str, optional Key in 'adata.uns' holding a pre-computed color dictionary. Falls back to automatic generation from 'annotation' values. @@ -2072,7 +2072,7 @@ def boxplot_interactive( future enhancements. Default is None. **kwargs : dict Additional arguments for seaborn figure-level functions. 
- + Returns ------- A dictionary containing the following keys: @@ -2081,15 +2081,15 @@ def boxplot_interactive( - If `figure_type` is "static": A base64-encoded PNG image string - If `figure_type` is "interactive": A Plotly figure object - + df : pd.DataFrame A DataFrame containing the features and their corresponding values. - + metrics : pd.DataFrame A DataFrame containing the computed boxplot metrics (if `return_metrics` is True). """ - + def boxplot_from_statistics( summary_stats: pd.DataFrame, cmap: dict, @@ -2105,12 +2105,12 @@ def boxplot_from_statistics( ): """ Generate a boxplot from the provided summary statistics DataFrame. - + This function visualizes a set of summary statistics (e.g., quartiles, mean) as a boxplot. It supports grouping the data by a given annotation and allows customization of orientation, displaying outliers, and interactive plotting. - + Parameters ---------- summary_stats : pd.DataFrame @@ -2118,49 +2118,49 @@ def boxplot_from_statistics( plot. It should include columns like 'marker', 'q1', 'med', 'q3', 'whislo', 'whishi', and 'mean'. Optionally, it may also contain an annotation column used for grouping. - + cmap : dict A dictionary mapping annotation/feature values to color strings (hex, rgb/rgba, hsl/hsla, hsv/hsva, or CSS). - + annotation : str, optional The column name in `summary_stats` used to group the data by specific categories (e.g., cell type, condition). If `None`, no grouping is applied. - + ax : matplotlib.axes.Axes or plotly.graph_objects.Figure, optional A figure or axes to plot onto. If None, a new Plotly figure is created. - + showfliers : {None, "downsample", "all"}, default = None If 'all', all outliers are displayed in the boxplot. If 'downsample', when num outliers is >10k, they are downsampled to 10% of the original count. If None, outliers are hidden. - + log_scale : bool, optional, default=False If True, the log1p transformation is applied to the features before plotting. 
This option is disabled if negative values are found in the features. - + orient : {"v", "h"}, default="v" The orientation of the boxplot: 'v' for vertical and 'h' for horizontal. - + figure_width : int, optional Width of the figure in inches. Default is 3.2. - + figure_height : int, optional Height of the figure in inches. Default is 2. - + figure_dpi : int, optional DPI (dots per inch) for the figure. Default is 200. - + Returns ------- fig : plotly.graph_objects.Figure The Plotly figure containing the generated boxplot. - + Notes ----- - The function uses the `plotly` library for visualization, allowing @@ -2170,36 +2170,36 @@ def boxplot_from_statistics( - The boxplot will display whiskers, quartiles, and the mean. Outliers are controlled by the `showfliers` parameter. """ - + # Initialize the figure: if 'ax' is provided, use it, otherwise create # a new Plotly figure if ax: fig = ax else: fig = go.Figure() - + # Get unique features (markers) from the summary statistics unique_features = summary_stats["marker"].unique() - + # Create comma seperated list for features in the plot title # If there are >3 unique features, use 'Multiple Features' in the title if len(unique_features) < 4: plot_title = f"{', '.join(unique_features[0:])}" else: plot_title = 'Multiple Features' - + if annotation: unique_annotations = summary_stats[annotation].unique() - + plot_title += f" grouped by {annotation}" - + # Empty outlier lists cause issues with plotly, # so replace them with [None] if showfliers: summary_stats["fliers"] = summary_stats["fliers"].apply( lambda x: [None] if len(x) == 0 else x ) - + # Set up the orientation of the plot data & axis-labels if orient == "h": x_data = "fliers" @@ -2211,7 +2211,7 @@ def boxplot_from_statistics( y_data = "fliers" x_axis_label = annotation if annotation else "feature value" y_axis_label = "log(Intensity)" if log_scale else "Intensity" - + # If annotation is provided, group the data # and create boxplots for each group if annotation: 
@@ -2222,7 +2222,7 @@ def boxplot_from_statistics( grouped_data[annotation_value] = summary_stats[ summary_stats[annotation] == annotation_value ].to_dict(orient="list") - + # Add a boxplot trace for each annotation value for annotation_value, data in grouped_data.items(): if orient == "h": @@ -2231,7 +2231,7 @@ def boxplot_from_statistics( else: y = data[y_data] if showfliers else None x = data[x_data] - + fig.add_trace( go.Box( name=annotation_value, @@ -2259,14 +2259,14 @@ def boxplot_from_statistics( unique_annotations = unique_annotations[ unique_annotations != annotation_value ] - + # Adjust layout to group the boxplots by annotation fig.update_layout(boxmode="group") else: # If no annotation, create a boxplot # for each unique feature (marker) stats_dict = summary_stats.to_dict(orient="list") - + for i, marker_value in enumerate(stats_dict["marker"]): if orient == "h": y = [stats_dict[y_data][i]] @@ -2274,7 +2274,7 @@ def boxplot_from_statistics( else: y = [stats_dict[y_data][i], [None]] if showfliers else None x = [stats_dict[x_data][i]] - + # Note: adding None to the x or y data to ensure # the outliers are displayed correctly fig.add_trace( @@ -2298,7 +2298,7 @@ def boxplot_from_statistics( **kwargs ) ) - + # Final layout adjustments for the plot title, axis labels, and size fig.update_layout( title=plot_title, @@ -2307,13 +2307,13 @@ def boxplot_from_statistics( height=int(figure_height * figure_dpi), width=int(figure_width * figure_dpi), ) - + return fig - + ##################### # Main Code Block # ##################### - + logging.info("Calculating Box Plot...") if layer: check_table(adata, tables=layer) @@ -2321,55 +2321,55 @@ def boxplot_from_statistics( check_annotation(adata, annotations=annotation) if features: check_feature(adata, features=features) - + if ax and not isinstance(ax, plt.Figure): raise TypeError("Input 'ax' must be a plotly.Figure object.") - + if showfliers not in ("all", "downsample", None): raise ValueError( ("showfliers must 
be one of 'all', 'downsample', or None."), (f" Got {showfliers}."), ) - + if figure_type not in ("interactive", "static", "png"): raise ValueError( (f"figure_type must be one of 'interactive', 'static', or 'png'."), (f" Got {figure_type}."), ) - + # Extract data from the specified layer or the default matrix (adata.X) if layer: data_matrix = adata.layers[layer] else: data_matrix = adata.X - + # Convert the data matrix into a DataFrame with # appropriate column names (features) df = pd.DataFrame(data_matrix, columns=adata.var_names) - + # Add annotation column to the DataFrame if provided if annotation: df[annotation] = adata.obs[annotation].values - + # If no specific features are provided, use all available features if features is None: features = adata.var_names.tolist() - + # Filter the DataFrame to include only the # selected features and the annotation df = df[features + ([annotation] if annotation else [])] - + # Check for negative values if log scale is requested if log_scale and (df[features] < 0).any().any(): print( "There are negative values in this data, disabling the log scale." 
) log_scale = False - + # Apply log1p transformation if log_scale is True if log_scale: df[features] = np.log1p(df[features]) - + start_time = time.time() # Compute the summary statistics required for the boxplot metrics = compute_boxplot_metrics( @@ -2379,7 +2379,7 @@ def boxplot_from_statistics( "Time taken to compute boxplot metrics: %f seconds", time.time() - start_time ) - + # Get the colormap for the annotation if defined_color_map: cmap = get_defined_color_map(adata) @@ -2397,7 +2397,7 @@ def boxplot_from_statistics( color_map=feature_colorscale, return_dict=True, ) - + start_time = time.time() # Generate the boxplot figure from the summary statistics fig = boxplot_from_statistics( @@ -2413,7 +2413,7 @@ def boxplot_from_statistics( figure_dpi=figure_dpi, **kwargs, ) - + # Prepare the base image or figure return value if figure_type == "interactive": plot = fig @@ -2442,19 +2442,19 @@ def boxplot_from_statistics( 'legend_itemdoubleclick': False } plot = fig.update_layout(**config) - + logging.info( "Time taken to generate boxplot: %f seconds", time.time() - start_time ) - + result = {"fig": plot, "df": df} # Determine if metrics included based on return_metrics flag if return_metrics: result["metrics"] = metrics - + return result - + def interactive_spatial_plot( adata, @@ -2476,11 +2476,11 @@ def interactive_spatial_plot( cmax=None, **kwargs ): - + """ Create an interactive scatter plot for spatial data using provided annotations. - + Parameters ---------- adata : AnnData @@ -2532,55 +2532,55 @@ def interactive_spatial_plot( Default is None. **kwargs Additional keyword arguments for customization. - + Returns ------- list of dict A list of dictionaries, each containing the following keys: - "image_name": str, the name of the generated image. - "image_object": Plotly Figure object. - + Notes ----- This function is tailored for spatial single-cell data and expects the AnnData object to have spatial coordinates in its `.obsm` attribute under the 'spatial' key. 
""" - + if annotations is None and feature is None: raise ValueError( "At least one of the 'annotations' or 'feature' parameters " + \ "must be provided." ) - + if annotations is not None: if not isinstance(annotations, list): annotations = [annotations] - + for annotation in annotations: check_annotation( adata, annotations=annotation ) - + if feature is not None: check_feature( adata, features=feature ) - + if layer is not None: check_table( adata, tables=layer ) - + check_table( adata, tables='spatial', associated_table=True ) - + def prepare_spatial_dataframe( adata, annotations=None, @@ -2588,13 +2588,13 @@ def prepare_spatial_dataframe( layer=None): """ Prepare a DataFrame for spatial plotting from an AnnData object. - + If 'annotations' is provided (a string or list of strings), the returned DataFrame will contain the X,Y coordinates and one column per annotation. If 'feature' is provided (and annotations is None), a single 'color' column is created from adata.layers[layer] (if provided) or adata.X. - + Parameters ---------- adata : anndata.AnnData @@ -2605,13 +2605,13 @@ def prepare_spatial_dataframe( Continuous feature name in adata.var_names for coloring. layer : str, optional Layer to use for feature values if feature is provided. - + Returns ------- df : pandas.DataFrame DataFrame with columns 'X', 'Y' and each annotation column (or a 'color' column for continuous feature). 
- + Raises ------ ValueError @@ -2621,7 +2621,7 @@ def prepare_spatial_dataframe( xcoord = [coord[0] for coord in spatial] ycoord = [coord[1] for coord in spatial] df = pd.DataFrame({'X': xcoord, 'Y': ycoord}) - + if annotations is not None: if isinstance(annotations, str): annotations = [annotations] @@ -2635,7 +2635,7 @@ def prepare_spatial_dataframe( raise ValueError( "Either 'annotations' or 'feature' must be provided.") return df - + def main_figure_generation( spatial_df, annotations=None, @@ -2656,7 +2656,7 @@ def main_figure_generation( This function generates the main interactive plot using Plotly that contains the spatial scatter plot with annotations and image configuration. - + Parameters ---------- spatial_df : pandas.DataFrame @@ -2685,30 +2685,30 @@ def main_figure_generation( Font size for text in the plot. Default is 12. title : str, optional Title of the image. Default is "interactive_spatial_plot". - + Returns ------- plotly.graph_objs._figure.Figure The generated interactive Plotly figure. """ - + xcoord = spatial_df['X'] ycoord = spatial_df['Y'] - + min_x, max_x = min(xcoord), max(xcoord) min_y, max_y = min(ycoord), max(ycoord) dx = max_x - min_x - + dy = max_y - min_y - + min_x_range = min_x - 0.05 * dx max_x_range = max_x + 0.05 * dx min_y_range = min_y - 0.05 * dy max_y_range = max_y + 0.05 * dy - + width_px = int(figure_width * figure_dpi) height_px = int(figure_height * figure_dpi) - + # Define partial for scatter traces with common parameters scatter_partial = partial( px.scatter, @@ -2717,7 +2717,7 @@ def main_figure_generation( render_mode="webgl", **kwargs ) - + # Helper function to create a scatter trace for features # as it needs a continuous color scale. 
# in my experience, px.scatter does not work well with @@ -2741,11 +2741,11 @@ def create_scatter_trace(df, feature, colorscale): text=df[feature], **kwargs ) - + # The annotation trace creates a dummy point # so that the label of that annotion is shown in the legend def create_annotation_trace(filtered, obs): - + # add one extra point just close to the first point trace = px.scatter( x=[filtered['X'].iloc[0]-0.1], @@ -2764,14 +2764,14 @@ def create_annotation_trace(filtered, obs): name=f'{obs}' ) return trace - + main_fig = go.Figure() - + if annotations is not None: # Loop over all annotation and add annotation dummy point # and data points to the figure for obs in annotations: - + spatial_df[obs].fillna("no_label", inplace=True) filtered = spatial_df # Create and add annotation trace using the helper function @@ -2785,19 +2785,19 @@ def create_annotation_trace(filtered, obs): hover_data=[obs], color_discrete_map=color_mapping, ).data) - + elif feature is not None: - + main_fig.add_trace( create_scatter_trace(spatial_df, feature, colorscale) ) - + else: raise ValueError( "No plot is generated." " Either 'annotations' or 'feature' must be provided." 
) - + if annotations is not None: # Set the hover template to show x, y and annotation # This is needed to show the correct label when @@ -2806,7 +2806,7 @@ def create_annotation_trace(filtered, obs): elif feature is not None: # it is already set in the create_scatter_trace function hovertemplate = None - + main_fig.update_traces( mode='markers', marker=dict( @@ -2815,7 +2815,7 @@ def create_annotation_trace(filtered, obs): ), hovertemplate=hovertemplate ) - + main_fig.update_layout( width=width_px, height=height_px, @@ -2872,21 +2872,21 @@ def create_annotation_trace(filtered, obs): }, margin=dict(l=5, r=5, t=font_size*2, b=5) ) - + if reverse_y_axis: main_fig.update_layout(yaxis=dict(autorange="reversed")) - + return { "image_name": f"{spell_out_special_characters(title)}.html", "image_object": main_fig } - + ##################### # Main Code Block ## ##################### - + from functools import partial - + # Set the discrete or continuous color parameters color_dict = None colorscale = None @@ -2905,7 +2905,7 @@ def create_annotation_trace(filtered, obs): f'Colored by "{feature}", ' f'table: "{layer if layer else "Original"}"' ) - + # Create the partial function with the common keyword arguments directly plot_main = partial( main_figure_generation, @@ -2921,18 +2921,18 @@ def create_annotation_trace(filtered, obs): font_size=font_size, **kwargs ) - + results = [] - + if stratify_by is not None: # Check if the stratification column exists in the data check_annotation(adata, annotations=stratify_by) unique_stratification_values = adata.obs[stratify_by].unique() - + for strat_value in unique_stratification_values: condition = adata.obs[stratify_by] == strat_value title_str = f"Subsetting {stratify_by}: {strat_value}" - + indices = np.where(condition)[0] print(f"number of cells in the region: {len(adata.obsm['spatial'][indices])}") adata_subset = select_values( @@ -2940,7 +2940,7 @@ def create_annotation_trace(filtered, obs): annotation=stratify_by, 
values=strat_value ) - + spatial_df = prepare_spatial_dataframe( adata_subset, annotations=annotations, @@ -2963,16 +2963,16 @@ def create_annotation_trace(filtered, obs): feature=feature, layer=layer ) - + # For non-stratified case, pass extra parameters if needed result = plot_main( spatial_df, title=title_str ) results.append(result) - + return results - + def sankey_plot( adata: anndata.AnnData, @@ -2989,7 +2989,7 @@ def sankey_plot( source annotation, and tab20c for target annotation. For more information on colormaps, see: https://matplotlib.org/stable/users/explain/colors/colormaps.html - + Parameters ---------- adata : anndata.AnnData @@ -3007,13 +3007,13 @@ def sankey_plot( prefix : bool, optional Whether to prefix the target labels with the source labels. Defaults to True. - + Returns ------- plotly.graph_objs._figure.Figure The generated Sankey plot. """ - + label_relations = annotation_category_relations( adata=adata, source_annotation=source_annotation, @@ -3024,11 +3024,11 @@ def sankey_plot( source_labels = label_relations["source"].unique().tolist() target_labels = label_relations["target"].unique().tolist() all_labels = source_labels + target_labels - + source_label_colors = color_mapping(source_labels, source_color_map) target_label_colors = color_mapping(target_labels, target_color_map) label_colors = source_label_colors + target_label_colors - + # Create a dictionary to map labels to indices label_to_index = { label: index for index, label in enumerate(all_labels)} @@ -3041,7 +3041,7 @@ def sankey_plot( target_indices = [] values = [] link_colors = [] - + # For each row in label_relations, add the source index, target index, # and count to the respective lists for _, row in label_relations.iterrows(): @@ -3049,7 +3049,7 @@ def sankey_plot( target_indices.append(label_to_index[row['target']]) values.append(row['count']) link_colors.append(color_to_map[row['source']]) - + # Generate Sankey diagram # Calculate the x-coordinate for each label fig = 
go.Figure(go.Sankey( @@ -3073,7 +3073,7 @@ def sankey_plot( size=sankey_font ) )) - + fig.data[0].link.customdata = label_relations[ ['percentage_source', 'percentage_target'] ] @@ -3083,7 +3083,7 @@ def sankey_plot( 'Count: %{value}' ) fig.data[0].link.hovertemplate = hovertemplate - + # Customize the Sankey diagram layout fig.update_layout( title_text=( @@ -3096,15 +3096,15 @@ def sankey_plot( color="black" # Set the title font color ) ) - + fig.update_layout(margin=dict( l=10, r=10, t=sankey_font * 3, b=sankey_font)) - + return fig - + def relational_heatmap( adata: anndata.AnnData, @@ -3118,7 +3118,7 @@ def relational_heatmap( The color map refers to matplotlib color maps, default is mint. For more information on colormaps, see: https://matplotlib.org/stable/users/explain/colors/colormaps.html - + Parameters ---------- adata : anndata.AnnData @@ -3131,7 +3131,7 @@ def relational_heatmap( The color map to use for the relational heatmap. Default is mint. **kwargs : dict, optional Additional keyword arguments. For example, you can pass font_size=12.0. 
- + Returns ------- dict @@ -3150,41 +3150,41 @@ def relational_heatmap( # Default font size font_size = kwargs.get('font_size', 12.0) prefix = kwargs.get('prefix', True) - + # Get the relationship between source and target annotations - + label_relations = annotation_category_relations( adata=adata, source_annotation=source_annotation, target_annotation=target_annotation, prefix=prefix ) - + # Pivot the data to create a matrix for the heatmap heatmap_matrix = label_relations.pivot( index='source', columns='target', values='percentage_source' ) - + heatmap_matrix = heatmap_matrix.fillna(0) - + x = list(heatmap_matrix.columns) y = list(heatmap_matrix.index) - + # Create text labels for the heatmap label_relations['text_label'] = [ '{}%'.format(val) for val in label_relations["percentage_source"] ] - + heatmap_matrix2 = label_relations.pivot( index='source', columns='target', values='percentage_source' ) - + heatmap_matrix2 = heatmap_matrix2.fillna(0) - + hover_template = 'Source: %{z}%
Target: %{customdata}%' # Ensure alignment of the text data with the heatmap matrix z = list() @@ -3203,7 +3203,7 @@ def relational_heatmap( 0 if len(z_data_point) == 0 else z_data_point.iloc[0] ) z.append([_ for _ in iter_list]) - + # Create heatmap fig = ff.create_annotated_heatmap( z=z, @@ -3211,7 +3211,7 @@ def relational_heatmap( customdata=heatmap_matrix2.values, hovertemplate=hover_template ) - + fig.update_layout( overwrite=True, xaxis=dict( @@ -3238,15 +3238,15 @@ def relational_heatmap( b=font_size * 2 ) ) - + for i in range(len(fig.layout.annotations)): fig.layout.annotations[i].font.size = font_size - + fig.update_xaxes( side="bottom", tickangle=90 ) - + # Data output section data = fig.data[0] layout = fig.layout @@ -3256,13 +3256,13 @@ def relational_heatmap( matrix.columns=layout['xaxis']['ticktext'] matrix["total"] = matrix.sum(axis=1) matrix = matrix.fillna(0) - + # Display the DataFrame file_name = f"{source_annotation}_to_{target_annotation}" + \ "_relation_matrix.csv" - + return {"figure": fig, "file_name": file_name, "data": matrix} - + def plot_ripley_l( adata, @@ -3274,7 +3274,7 @@ def plot_ripley_l( """ Plot Ripley's L statistic for multiple bins and different regions for a given pair of phenotypes. - + Parameters ---------- adata : AnnData @@ -3290,19 +3290,19 @@ def plot_ripley_l( Whether to return the DataFrame containing the Ripley's L results. kwargs : dict, optional Additional keyword arguments to pass to `seaborn.lineplot`. - + Raises ------ ValueError If the Ripley L results are not found in `adata.uns['ripley_l']`. - + Returns ------- ax : matplotlib.axes.Axes The Axes object containing the plot, which can be further modified. df : pandas.DataFrame, optional The DataFrame containing the Ripley's L results, if `return_df` is True. - + Example ------- >>> ax = plot_ripley_l( @@ -3310,24 +3310,24 @@ def plot_ripley_l( ... phenotypes=('Phenotype1', 'Phenotype2'), ... 
regions=['region1', 'region2']) >>> plt.show() - + This returns the `Axes` object for further customization and displays the plot. """ - + # Retrieve the results from adata.uns['ripley_l'] ripley_results = adata.uns.get('ripley_l') - + if ripley_results is None: raise ValueError( "Ripley L results not found in the analsyis." ) - + # Filter the results for the specific pair of phenotypes filtered_results = ripley_results[ (ripley_results['center_phenotype'] == phenotypes[0]) & (ripley_results['neighbor_phenotype'] == phenotypes[1]) ] - + if filtered_results.empty: # Generate all unique combinations of phenotype pairs unique_pairs = ripley_results[ @@ -3338,12 +3338,12 @@ def plot_ripley_l( f'\nNeighbor Phenotype: "{phenotypes[1]}"' f"\nExisiting unique pairs: {unique_pairs}" ) - + # If specific regions are provided, filter them, otherwise plot all regions if regions is not None: filtered_results = filtered_results[ filtered_results['region'].isin(regions)] - + # Check if the results are emply after subsetting the regions if filtered_results.empty: available_regions = ripley_results['region'].unique() @@ -3351,16 +3351,16 @@ def plot_ripley_l( f"No data available for the specified regions: {regions}. " f"Available regions: {available_regions}." 
) - + # Create a figure and axes fig, ax = plt.subplots(figsize=(10, 10)) - + plot_data = [] - + # Plot Ripley's L for each region for _, row in filtered_results.iterrows(): region = row['region'] # Region label - + if row['ripley_l'] is None: message = ( f"Ripley L results not found for region: {region}" @@ -3383,18 +3383,18 @@ def plot_ripley_l( label=f'{region}: {n_cells}, {int(area)}', ax=ax, **kwargs) - + # Calculate averages for simulations if enabled if sims: sims_stat_df = row["ripley_l"]["sims_stat"] avg_stats = sims_stat_df.groupby("bins")["stats"].mean() avg_used_center_cells = \ sims_stat_df.groupby("bins")["used_center_cells"].mean() - + # Prepare plotted data to return if return_df is True l_stat_data = row['ripley_l']['L_stat'] for _, stat_row in l_stat_data.iterrows(): - + entry = { 'region': region, 'radius': stat_row['bins'], @@ -3404,15 +3404,15 @@ def plot_ripley_l( 'n_neighbor': n_neighbors, 'used_center_cells': stat_row['used_center_cells'] } - + if sims: entry['avg_sim_ripley(radius)'] = \ avg_stats.get(stat_row['bins'], None) entry['avg_sim_used_center_cells'] = \ avg_used_center_cells.get(stat_row['bins'], None) - + plot_data.append(entry) - + if sims: confidence_level = 95 errorbar = ("pi", confidence_level) @@ -3425,7 +3425,7 @@ def plot_ripley_l( label=f"Simulations({region}):{n_sims} runs", **kwargs ) - + # Set labels, title, and grid ax.set_title( "Ripley's L Statistic for phenotypes:" @@ -3433,17 +3433,17 @@ def plot_ripley_l( ) ax.legend(title='Regions:(center, neighbor), area', loc='upper left') ax.grid(True) - + # Set the horizontal axis lable ax.set_xlabel("Radii (pixels)") ax.set_ylabel("Ripley's L Statistic") - + if return_df: df = pd.DataFrame(plot_data) return fig, df - + return fig - + def _prepare_spatial_distance_data( adata, @@ -3456,7 +3456,7 @@ def _prepare_spatial_distance_data( ): """ Prepares a tidy DataFrame for nearest-neighbor (spatial distance) plotting. 
- + This function: 1) Validates required parameters (annotation, distance_from). 2) Retrieves the spatial distance matrix from @@ -3468,9 +3468,9 @@ def _prepare_spatial_distance_data( 6) Reshapes (melts) into long-form data: columns -> [cellid, group, distance]. 7) Applies optional log1p transform. - + The resulting DataFrame is suitable for plotting with tool like Seaborn. - + Parameters ---------- adata : anndata.AnnData @@ -3491,7 +3491,7 @@ def _prepare_spatial_distance_data( log : bool, optional If True, applies np.log1p transform to the 'distance' column, which is renamed to 'log_distance'. - + Returns ------- pd.DataFrame @@ -3501,14 +3501,14 @@ def _prepare_spatial_distance_data( - 'distance': the numeric distance value. - 'phenotype': the reference phenotype ('distance_from'). - 'stratify_by': optional grouping column, if provided. - + Raises ------ ValueError If required parameters are missing, if phenotypes are not found in `adata.obs`, or if the spatial distance matrix is not available in `adata.obsm`. - + Examples -------- >>> df_long = _prepare_spatial_distance_data( @@ -3522,7 +3522,7 @@ def _prepare_spatial_distance_data( ... ) >>> df_long.head() """ - + # Validate required parameters if distance_from is None: raise ValueError( @@ -3530,15 +3530,15 @@ def _prepare_spatial_distance_data( "the reference group from which distances are measured." 
) check_annotation(adata, annotations=annotation) - + # Convert distance_to to list if needed if distance_to is not None and isinstance(distance_to, str): distance_to = [distance_to] - + phenotypes_to_check = [distance_from] + ( distance_to if distance_to else [] ) - + # Ensure distance_from and distance_to exist in adata.obs[annotation] check_label( adata, @@ -3546,7 +3546,7 @@ def _prepare_spatial_distance_data( labels=phenotypes_to_check, should_exist=True ) - + # Retrieve the spatial distance matrix from adata.obsm if spatial_distance not in adata.obsm: raise ValueError( @@ -3556,7 +3556,7 @@ def _prepare_spatial_distance_data( f"Available keys: {list(adata.obsm.keys())}" ) distance_map = adata.obsm[spatial_distance].copy() - + # Verify requested phenotypes exist in the distance_map columns missing_cols = [ p for p in phenotypes_to_check if p not in distance_map.columns @@ -3567,17 +3567,17 @@ def _prepare_spatial_distance_data( f"'{spatial_distance}'. Columns present: " f"{list(distance_map.columns)}" ) - + # Validate 'stratify_by' column if provided if stratify_by is not None: check_annotation(adata, annotations=stratify_by) - + # Build a meta DataFrame with phenotype & optional stratify column meta_data = pd.DataFrame({'phenotype': adata.obs[annotation]}, index=adata.obs.index) if stratify_by: meta_data[stratify_by] = adata.obs[stratify_by] - + # Merge metadata with distance_map and filter for 'distance_from' df_merged = meta_data.join(distance_map, how='left') df_merged = df_merged[df_merged['phenotype'] == distance_from] @@ -3585,15 +3585,15 @@ def _prepare_spatial_distance_data( raise ValueError( f"No cells found with phenotype == '{distance_from}'." 
) - + # Reset index to ensure cell names are in a column called 'cellid' df_merged = df_merged.reset_index().rename(columns={'index': 'cellid'}) - + # Prepare the list of metadata columns meta_cols = ['phenotype'] if stratify_by: meta_cols.append(stratify_by) - + # Determine distance columns if distance_to: keep_cols = ['cellid'] + meta_cols + distance_to @@ -3605,41 +3605,41 @@ def _prepare_spatial_distance_data( c for c in df_merged.columns if c not in non_distance_cols ] keep_cols = ['cellid'] + meta_cols + distance_columns - + df_merged = df_merged[keep_cols] - + # Melt the DataFrame from wide to long format df_long = df_merged.melt( id_vars=['cellid'] + meta_cols, var_name='group', value_name='distance' ) - + # Convert columns to categorical for consistency for col in ['group', 'phenotype', stratify_by]: if col and col in df_long.columns: df_long[col] = df_long[col].astype(str).astype('category') - + # Reorder categories for 'group' if 'distance_to' is provided if distance_to: df_long['group'] = df_long['group'].cat.reorder_categories(distance_to) df_long.sort_values('group', inplace=True) - + # Ensure 'distance' is numeric and apply log transform if requested df_long['distance'] = pd.to_numeric(df_long['distance'], errors='coerce') if log: df_long['distance'] = np.log1p(df_long['distance']) df_long.rename(columns={'distance': 'log_distance'}, inplace=True) - + # Reorder columns dynamically based on the presence of 'log' distance_col = 'log_distance' if log else 'distance' final_cols = ['cellid', 'group', distance_col, 'phenotype'] if stratify_by is not None: final_cols.append(stratify_by) df_long = df_long[final_cols] - + return df_long - + def _plot_spatial_distance_dispatch( df_long, @@ -3655,7 +3655,7 @@ def _plot_spatial_distance_dispatch( """ Dispatch a seaborn call to visualise nearest-neighbor distances. Returns Axes object(s) for further customization. - + Layout logic ------------ 1. 
``stratify_by`` & ``facet_plot`` → Faceted plot, returns ``Axes`` @@ -3664,7 +3664,7 @@ def _plot_spatial_distance_dispatch( ``List[Axes]`` for the "ax" key. 3. ``stratify_by`` is None → Single plot, returns ``Axes`` or ``List[Axes]`` (if plot_type creates facets) for the "ax" key. - + Parameters ---------- df_long : pd.DataFrame @@ -3698,7 +3698,7 @@ def _plot_spatial_distance_dispatch( **kwargs Extra keyword args propagated to Seaborn. Legend control (e.g. `legend=False`) should be passed here if needed. - + Returns ------- dict @@ -3707,10 +3707,10 @@ def _plot_spatial_distance_dispatch( 'ax' : matplotlib.axes.Axes | list[Axes] } """ - + if method not in ("numeric", "distribution"): raise ValueError("`method` must be 'numeric' or 'distribution'.") - + # Choose plotting function if method == "numeric": _plot_base = partial( @@ -3731,29 +3731,29 @@ def _plot_spatial_distance_dispatch( kind=plot_type, palette=palette, ) - + # Single plotting wrapper to create Axes object(s) def _make_axes_object(_data, **kws_plot): g = _plot_base(data=_data, **kws_plot) - + axis_label = ( "Log(Nearest Neighbor Distance)" if "log" in distance_col else "Nearest Neighbor Distance" ) - + g.set_axis_labels(axis_label, None) - + if g.axes.size == 1: returned_ax = g.ax else: returned_ax = g.axes.flatten().tolist() - + return returned_ax - + # Build axes final_axes_object = None - + if stratify_by and facet_plot: final_axes_object = _make_axes_object( df_long, col=stratify_by, **kwargs @@ -3772,9 +3772,9 @@ def _make_axes_object(_data, **kws_plot): final_axes_object = list_of_all_axes else: final_axes_object = _make_axes_object(df_long, **kwargs) - + return {"data": df_long, "ax": final_axes_object} - + # Build a master HEX palette and cache it inside the AnnData object # ----------------------------------------------------------------------------- @@ -3794,7 +3794,7 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): """ Normalise a CSS-style color string to a hexadecimal value or a 
valid Matplotlib color name. - + Parameters ---------- col : str @@ -3803,22 +3803,22 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): * 'rgb(r,g,b)' or 'rgba(r,g,b,a)', where r, g, b are 0-255 and a is 0-1 or 0-255 * any named Matplotlib color - + keep_alpha : bool, optional If True and the input includes alpha, return an 8-digit hex; otherwise drop the alpha channel. Default is False. - + Returns ------- str * Lower-case colour name or * 6- or 8-digit lower-case hex. - + Raises ------ ValueError If the color cannot be interpreted. - + Examples -------- >>> _css_rgb_or_hex_to_hex('gold') @@ -3828,9 +3828,9 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): >>> _css_rgb_or_hex_to_hex('rgba(255,0,0,0.5)', keep_alpha=True) '#ff000080' """ - + col = col.strip().lower() - + # Compile the rgb()/rgba() matcher locally to satisfy style request. rgb_re = re.compile( r'rgba?\s*\(' @@ -3841,11 +3841,11 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): r'\s*\)', re.I, ) - + # 1. direct hex if col.startswith('#'): return mcolors.to_hex(col, keep_alpha=keep_alpha).lower() - + # 2. rgb()/rgba() match = rgb_re.fullmatch(col) if match: @@ -3862,14 +3862,14 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): a_val /= 255 rgba.append(a_val) return mcolors.to_hex(rgba, keep_alpha=keep_alpha).lower() - + # 3. named color if col in mcolors.get_named_colors_mapping(): return col # let Matplotlib handle named colors - + # 4. unsupported format raise ValueError(f'Unsupported color format: "{col}"') - + # Helper function (can be defined at module level) def _ordered_unique_figs(axes_list: list): @@ -3883,7 +3883,7 @@ def _ordered_unique_figs(axes_list: list): if fig is not None: seen.setdefault(fig, None) return list(seen) - + def visualize_nearest_neighbor( adata, @@ -3904,21 +3904,21 @@ def visualize_nearest_neighbor( """ Visualize nearest-neighbor (spatial distance) data between groups of cells with optional pin-color map via numeric or distribution plots. 
- + This landing function first constructs a tidy long-form DataFrame via function `_prepare_spatial_distance_data`, then dispatches plotting to function `_plot_spatial_distance_dispatch`. A pin-color feature guarantees consistent mapping from annotation labels to colors across figures, drawing the mapping from ``adata.uns`` (if present) or generating one automatically through `spac.utils.color_mapping`. - + Plot arrangement logic: 1) If stratify_by is not None and facet_plot=True => single figure with subplots (faceted). 2) If stratify_by is not None and facet_plot=False => multiple separate figures, one per group. 3) If stratify_by is None => a single figure with one plot. - + Parameters ---------- adata : anndata.AnnData @@ -3959,7 +3959,7 @@ def visualize_nearest_neighbor( Only works if plotting a single component. Default is None. **kwargs : dict Additional arguments for seaborn figure-level functions. - + Returns ------- dict @@ -3969,12 +3969,12 @@ def visualize_nearest_neighbor( 'ax': matplotlib.axes.Axes | list[matplotlib.axes.Axes], 'palette': dict # {label: '#rrggbb'} } - + Raises ------ ValueError If required parameters are invalid. - + Examples -------- >>> res = visualize_nearest_neighbor( @@ -3993,19 +3993,19 @@ def visualize_nearest_neighbor( >>> df = res['data'] # long-form DataFrame >>> ax_list[0].set_title('Tumour → Stroma distances') """ - + if method not in ['numeric', 'distribution']: raise ValueError( "Invalid 'method'. Please choose 'numeric' or 'distribution'." 
) - + # Determine plot_type if not provided if plot_type is None: plot_type = 'boxen' if method == 'numeric' else 'kde' - + # If log=True, the column name is 'log_distance', else 'distance' distance_col = 'log_distance' if log else 'distance' - + # Build/fetch color palette color_dict_rgb = get_defined_color_map( adata=adata, @@ -4013,14 +4013,14 @@ def visualize_nearest_neighbor( annotations=annotation, colorscale=annotation_colorscale ) - + palette_hex = { k: _css_rgb_or_hex_to_hex(v) for k, v in color_dict_rgb.items() } adata.uns.setdefault('_spac_palettes', {})[ f"{defined_color_map or annotation}_hex" ] = palette_hex - + # Reshape data df_long = _prepare_spatial_distance_data( adata=adata, @@ -4031,7 +4031,7 @@ def visualize_nearest_neighbor( distance_to=distance_to, log=log ) - + # Filter the full palette to include only the target groups present in # df_long['group']. These are the groups that will actually be used for hue # in the plot. @@ -4047,13 +4047,13 @@ def visualize_nearest_neighbor( # plot. For each label we look up its HEX code in ``palette_hex``; if a # colour exists we copy the mapping into the new dictionary. target_groups_in_plot = df_long['group'].astype(str).unique() - + plot_specific_palette = { str(group): palette_hex.get(str(group)) for group in target_groups_in_plot if palette_hex.get(str(group)) is not None } - + # Assemble kwargs & dispatch # Inject the palette into the plotting dispatcher # ----------------------------------------------------------------------------- @@ -4068,7 +4068,7 @@ def visualize_nearest_neighbor( # HOW ``dispatch_kwargs`` starts as a copy of any user‑supplied kwargs; the # call to ``update`` adds these palette‑related keys before control is # handed off to the generic plotting helper. 
- + dispatch_kwargs = dict(kwargs) dispatch_kwargs.update({ 'hue_axis': 'group', @@ -4076,11 +4076,11 @@ def visualize_nearest_neighbor( }) if method == 'numeric': dispatch_kwargs.setdefault('saturation', 1.0) - + # Set legend=False to allow for custom legend creation by the caller # The user can still override this by passing legend=True in kwargs dispatch_kwargs.setdefault('legend', False) - + disp = _plot_spatial_distance_dispatch( df_long=df_long, method=method, @@ -4090,15 +4090,15 @@ def visualize_nearest_neighbor( distance_col=distance_col, **dispatch_kwargs ) - + returned_axes = disp['ax'] fig_object = None # Initialize - + if isinstance(returned_axes, list): if returned_axes: # Unique figures, preserved in axis order unique_figs_ordered = _ordered_unique_figs(returned_axes) - + if unique_figs_ordered: # at least one valid figure if stratify_by and not facet_plot: # one figure per category → return the ordered list @@ -4120,35 +4120,35 @@ def visualize_nearest_neighbor( # single Axes → grab its figure fig_object = getattr(returned_axes, 'figure', None) # returned_axes is None → fig_object stays None - + return { 'data': disp['data'], 'fig': fig_object, 'ax': disp['ax'], 'palette': plot_specific_palette # Return the filtered palette } - + import json import plotly.graph_objects as go - + def present_summary_as_html(summary_dict: dict) -> str: """ Build an HTML string that presents the summary information intuitively. - + For each specified column, the HTML includes: - Column name and data type - Count and list of missing indices - Summary details presented in a table (for numeric: stats; categorical: unique values and counts) - + Parameters ---------- summary_dict : dict The summary dictionary returned by summarize_dataframe. - + Returns ------- str @@ -4167,7 +4167,7 @@ def present_summary_as_html(summary_dict: dict) -> str: "" "

Data Summary

" ) - + for col, info in summary_dict.items(): html += ( f"

Column: {col}

" @@ -4182,28 +4182,28 @@ def present_summary_as_html(summary_dict: dict) -> str: for key, val in info['summary'].items(): html += f"{key}{val}" html += "
" - + html += "" return html - + def present_summary_as_figure(summary_dict: dict) -> go.Figure: """ Build a static Plotly figure (using a table) to depict the summary dictionary. - + The figure includes columns: - Column name - Data type - Count of missing values - Missing indices (as a string) - Summary details (formatted as JSON for readability) - + Parameters ---------- summary_dict : dict The summary dictionary returned from summarize_dataframe. - + Returns ------- plotly.graph_objects.Figure @@ -4214,13 +4214,13 @@ def present_summary_as_figure(summary_dict: dict) -> go.Figure: missing_counts = [] missing_indices = [] summaries = [] - + for col, info in summary_dict.items(): col_names.append(col) data_types.append(info['data_type']) missing_counts.append(info['count_missing_indices']) missing_indices.append(str(info['missing_indices'])) - + # need to convert nmpy int64 and float64 to native int and float # so that I can dump them as json clean_data = {} @@ -4234,9 +4234,9 @@ def present_summary_as_figure(summary_dict: dict) -> go.Figure: else: # Keep the value as is if it's already a standard type clean_data[k] = v - + summaries.append(json.dumps(clean_data, indent=2)) - + fig = go.Figure( data=[go.Table( header=dict( diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 7f3ad365..630692eb 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -1006,7 +1006,7 @@ def test_facet_plot_shared_bins_consistency_numeric(self): np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), "Facet numeric y-ticks should remain shared across panels." 
) - + # Check that the returned DataFrame has expected structure and content self.assertEqual( set(df.columns), @@ -1084,7 +1084,7 @@ def test_facet_plot_shared_bins_consistency_categorical(self): np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), "Facet categorical y-ticks should remain shared across panels." ) - + # Check that the returned DataFrame has expected structure and content self.assertEqual( set(df.columns), From c02fa6a8b4b698e85e442e8bbf1e0e28b658f279 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Wed, 22 Apr 2026 01:37:26 -0400 Subject: [PATCH 51/57] style: whitespace fix in multiple files --- src/spac/utils.py | 18 +- src/spac/visualization.py | 2030 ++++++++++---------- tests/test_visualization/test_histogram.py | 50 +- 3 files changed, 1049 insertions(+), 1049 deletions(-) diff --git a/src/spac/utils.py b/src/spac/utils.py index 5c26c284..5d5d87b3 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -8,12 +8,12 @@ import warnings import numbers from scipy.stats import median_abs_deviation -from typing import Any, List, Optional +from typing import Any, List, Optional # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) def regex_search_list( @@ -1197,20 +1197,20 @@ def compute_metrics(data): # compute summary statistics for the specified columns def compute_summary_qc_stats( - df: pd.DataFrame, + df: pd.DataFrame, n_mad: int = 5, upper_quantile: float = 0.95, lower_quantile: float = 0.05, stat_columns_list: List[str] = ['nFeature', 'nCount', 'percent.mt'] ) -> pd.DataFrame: - + """ Compute summary quality control statistics for specified columns in a dataset. 
For each column in stat_columns_list, this function calculates: - Mean - Median - - Upper and lower thresholds based on median ± n_mad * MAD + - Upper and lower thresholds based on median ± n_mad * MAD (median absolute deviation) - Upper and lower quantiles @@ -1231,7 +1231,7 @@ def compute_summary_qc_stats( ------- pd.DataFrame DataFrame with summary statistics for each specified column. - Columns: ["metric_name", "mean", "median", "upper_mad", "lower_mad", + Columns: ["metric_name", "mean", "median", "upper_mad", "lower_mad", "upper_quantile", "lower_quantile"] Raises @@ -1270,8 +1270,8 @@ def compute_summary_qc_stats( return pd.DataFrame( stat_vals, columns=[ - "metric_name", "mean", "median", - "upper_mad", "lower_mad", + "metric_name", "mean", "median", + "upper_mad", "lower_mad", "upper_quantile", "lower_quantile" ] - ) + ) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index abfea309..8bf794f7 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -26,17 +26,17 @@ import time import json import re -from typing import Dict, List, Union, Optional +from typing import Dict, List, Union, Optional import matplotlib.colors as mcolors import matplotlib.patches as mpatch from functools import partial from collections import OrderedDict - + # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - + def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, @@ -46,7 +46,7 @@ def visualize_2D_scatter( ): """ Visualize 2D data using plt.scatter. - + Parameters ---------- x, y : array-like @@ -74,7 +74,7 @@ def visualize_2D_scatter( Description of what the colors represent. **kwargs Additional keyword arguments passed to plt.scatter. - + Returns ------- fig : matplotlib.figure.Figure @@ -82,7 +82,7 @@ def visualize_2D_scatter( ax : matplotlib.axes.Axes The axes of the plot. 
""" - + # Input validation if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"): raise ValueError("x and y must be array-like.") @@ -90,7 +90,7 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - + # Define color themes themes = { 'fire': plt.get_cmap('inferno'), @@ -103,34 +103,34 @@ def visualize_2D_scatter( 'darkred': ListedColormap(['#8B0000']), 'darkgreen': ListedColormap(['#006400']) } - + if theme and theme not in themes: error_msg = ( f"Theme '{theme}' not recognized. Please use a valid theme." ) raise ValueError(error_msg) cmap = themes.get(theme, plt.get_cmap('viridis')) - + # Determine point size num_points = len(x) if point_size is None: point_size = 5000 / num_points - + # Get figure size and fontsize from kwargs or set defaults fig_width = kwargs.get('fig_width', 10) fig_height = kwargs.get('fig_height', 8) fontsize = kwargs.get('fontsize', 12) - + if ax is None: fig, ax = plt.subplots(figsize=(fig_width, fig_height)) else: fig = ax.figure - + # Plotting logic if labels is not None: # Check if labels are categorical if pd.api.types.is_categorical_dtype(labels): - + # Determine how to access the categories based on # the type of 'labels' if isinstance(labels, pd.Series): @@ -142,16 +142,16 @@ def visualize_2D_scatter( "Expected labels to be of type Series[Categorical] or " "Categorical." 
) - + # Combine colors from multiple colormaps cmap1 = plt.get_cmap('tab20') cmap2 = plt.get_cmap('tab20b') cmap3 = plt.get_cmap('tab20c') colors = cmap1.colors + cmap2.colors + cmap3.colors - + # Use the number of unique clusters to set the colormap length cmap = ListedColormap(colors[:len(unique_clusters)]) - + for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster ax.scatter( @@ -161,7 +161,7 @@ def visualize_2D_scatter( s=point_size ) print(f"Cluster: {cluster}, Points: {np.sum(mask)}") - + if annotate_centers: center_x = np.mean(x[mask]) center_y = np.mean(y[mask]) @@ -175,7 +175,7 @@ def visualize_2D_scatter( bbox_to_anchor=(1.25, 1), # Adjusting position title=f"Color represents: {color_representation}" ) - + else: # If labels are continuous scatter = ax.scatter( @@ -188,20 +188,20 @@ def visualize_2D_scatter( ) else: scatter = ax.scatter(x, y, c='gray', s=point_size, **kwargs) - + # Equal aspect ratio for the axes ax.set_aspect('equal', 'datalim') - + # Set axis labels ax.set_xlabel(x_axis_title) ax.set_ylabel(y_axis_title) - + # Set plot title if plot_title is not None: ax.set_title(plot_title) - + return fig, ax - + def dimensionality_reduction_plot( adata, @@ -214,7 +214,7 @@ def dimensionality_reduction_plot( **kwargs): """ Visualize scatter plot in PCA, t-SNE, UMAP, or associated table. - + Parameters ---------- adata : anndata.AnnData @@ -242,7 +242,7 @@ def dimensionality_reduction_plot( **kwargs Parameters passed to visualize_2D_scatter function, including point_size. - + Returns ------- fig : matplotlib.figure.Figure @@ -250,13 +250,13 @@ def dimensionality_reduction_plot( ax : matplotlib.axes.Axes The axes of the plot. 
""" - + # Check if both annotation and feature are specified, raise error if so if annotation and feature: raise ValueError( "Please specify either an annotation or a feature for coloring, " "not both.") - + # Use utility functions for input validation if layer: check_table(adata, tables=layer) @@ -264,21 +264,21 @@ def dimensionality_reduction_plot( check_annotation(adata, annotations=annotation) if feature: check_feature(adata, features=[feature]) - + # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: valid_methods = ['tsne', 'umap', 'pca'] if method not in valid_methods: raise ValueError("Method should be one of {'tsne', 'umap', 'pca'}" f'. Got:"{method}"') - + key = f'X_{method}' if key not in adata.obsm.keys(): raise ValueError( f"{key} coordinates not found in adata.obsm. " f"Please run {method.upper()} before calling this function." ) - + else: check_table( adata=adata, @@ -286,7 +286,7 @@ def dimensionality_reduction_plot( should_exist=True, associated_table=True ) - + associated_table_shape = adata.obsm[associated_table].shape if associated_table_shape[1] != 2: raise ValueError( @@ -294,12 +294,12 @@ def dimensionality_reduction_plot( f' two dimensions. 
It shape is:"{associated_table_shape}"' ) key = associated_table - + print(f'Running visualization using the coordinates: "{key}"') - + # Extract the 2D coordinates x, y = adata.obsm[key].T - + # Determine coloring scheme if annotation: color_values = adata.obs[annotation].astype('category').values @@ -311,7 +311,7 @@ def dimensionality_reduction_plot( else: color_values = None color_representation = None - + # Set axis titles based on method and color representation if method == 'tsne': x_axis_title = 't-SNE 1' @@ -329,13 +329,13 @@ def dimensionality_reduction_plot( x_axis_title = f'{associated_table} 1' y_axis_title = f'{associated_table} 2' plot_title = f'{associated_table}-{color_representation}' - + # Remove conflicting keys from kwargs kwargs.pop('x_axis_title', None) kwargs.pop('y_axis_title', None) kwargs.pop('plot_title', None) kwargs.pop('color_representation', None) - + fig, ax = visualize_2D_scatter( x=x, y=y, @@ -347,14 +347,14 @@ def dimensionality_reduction_plot( color_representation=color_representation, **kwargs ) - + return fig, ax - + def tsne_plot(adata, color_column=None, ax=None, **kwargs): """ Visualize scatter plot in tSNE basis. - + Parameters ---------- adata : anndata.AnnData @@ -367,7 +367,7 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): If not provided, a new figure and axes will be created. **kwargs Parameters passed to scanpy.pl.tsne function. 
- + Returns ------- fig : matplotlib.figure.Figure @@ -377,221 +377,221 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): """ if not isinstance(adata, anndata.AnnData): raise ValueError("adata must be an AnnData object.") - + if 'X_tsne' not in adata.obsm: err_msg = ("adata.obsm does not contain 'X_tsne', " "perform t-SNE transformation first.") raise ValueError(err_msg) - + # Create a new figure and axes if not provided if ax is None: fig, ax = plt.subplots() else: fig = ax.get_figure() - + if color_column and (color_column not in adata.obs.columns and color_column not in adata.var.columns): err_msg = f"'{color_column}' not found in adata.obs or adata.var." raise KeyError(err_msg) - + # Add color column to the kwargs for the scanpy plot if color_column: kwargs['color'] = color_column - + # Plot the t-SNE sc.pl.tsne(adata, ax=ax, **kwargs) - + return fig, ax - - -def _derive_facet_geometry( - n_groups, - facet_ncol=None, - facet_fig_width=None, - facet_fig_height=None, - facet_tick_max_chars=0, - facet_tick_rotation=0.0, - vertical_threshold=3, - default_height=3.2, - default_aspect=1.25, - min_panel_width=1.8, - min_panel_height=1.6, - min_aspect=0.6, - max_aspect=2.0, -): - """Derive FacetGrid geometry from pre-normalized facet layout hints. - - Parameters - ---------- - n_groups : int - Number of facet panels. Expected to be a positive integer supplied by - the grouped histogram path. - facet_ncol : int or None, optional - Requested facet column count. Positive integers are used directly. - ``None`` falls back to automatic column selection. - facet_fig_width, facet_fig_height : float, optional - Optional total figure-size hints. Geometry is derived from these hints only - when both values are present. - facet_tick_max_chars : int, optional - Maximum observed x tick-label length. Expected positive integer. - Used for adjusting default geometry heuristics when explicit figure-size - hints are absent. 
- ``0`` falls back to the original default geometry without long-label adjustments. - facet_tick_rotation : float, optional - Rotation angle in degrees for x tick labels. Used together with - ``facet_tick_max_chars`` to estimate label burden for default geometry. - vertical_threshold : int, optional - Maximum group count that still prefers a single-column automatic - layout. - default_height : float, optional - Per-facet panel height used when explicit figure-size hints are not - available. - default_aspect : float, optional - Per-facet panel aspect ratio used when explicit figure-size hints are - not available. - min_panel_width, min_panel_height : float, optional - Lower bounds applied to per-panel dimensions when figure-size hints are - converted into FacetGrid geometry. - min_aspect, max_aspect : float, optional - Bounds applied to the derived panel aspect ratio. - - Returns - ------- - dict - Dictionary with keys: - - ``facet_ncol``: positive int, normalized column count clamped to ``n_groups``; - - ``facet_height``: float, FacetGrid-ready per-panel height in inches; - - ``facet_aspect``: float, FacetGrid-ready per-panel aspect ratio. - - Automatic layout uses one column when ``n_groups <= vertical_threshold`` - and otherwise uses ``ceil(sqrt(n_groups))`` columns. When both - normalized figure-size hints are present, the helper converts total - figure size into per-panel geometry using the derived grid shape, - applies the minimum panel-size guardrails, and clips aspect into the - configured range. When figure-size hints are absent, the helper may - increase default facet height/aspect for long rotated categorical - labels to preserve usable bar area. 
- """ - - # Derive facet_ncol when not explicitly provided, and clamp to n_groups - if facet_ncol is None: - if n_groups <= vertical_threshold: - facet_ncol = 1 - else: - facet_ncol = int(np.ceil(np.sqrt(n_groups))) - logging.info( - "Automatic facet_ncol selection: %s columns for %s groups " - "(vertical_threshold=%s).", - facet_ncol, - n_groups, - vertical_threshold, - ) - facet_ncol = max(1, min(int(facet_ncol), n_groups)) - - # Use defaults if figure-size hints are not provided - facet_height = default_height - facet_aspect = default_aspect - - # Derive facet geometry from figure-size hints when both are provided - if facet_fig_width is not None and facet_fig_height is not None: - nrow = int(np.ceil(n_groups / facet_ncol)) - panel_width = max(facet_fig_width / facet_ncol, min_panel_width) - panel_height = max(facet_fig_height / nrow, min_panel_height) - facet_height = panel_height - facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) - - elif facet_tick_max_chars and facet_tick_max_chars > 0: - # For default geometry only, expand panel ratio when long rotated - # labels would otherwise dominate the available plotting area. 
- rotation = float(facet_tick_rotation or 0.0) % 360.0 - rad = np.deg2rad(min(rotation, 180.0)) - rotation_factor = 1.0 + 0.8 * np.sin(rad) - burden = float(facet_tick_max_chars) * rotation_factor - long_label_threshold = 12.0 - - if burden > long_label_threshold: - pressure = min((burden - long_label_threshold) / long_label_threshold, 2.0) - facet_height = default_height * (1.0 + 0.35 * pressure) - facet_aspect = float( - np.clip( - # default_aspect * (1.0 + 0.75 * pressure), - default_aspect * (1.0 - 0.05 * pressure), - min_aspect, - max_aspect, - ) - ) - logging.info( - "Automatic facet geometry adjustment for long x tick labels: " - "max_chars=%s, rotation=%s, facet_height=%.2f, facet_aspect=%.2f.", - facet_tick_max_chars, - facet_tick_rotation, - facet_height, - facet_aspect, - ) - - return { - "facet_ncol": facet_ncol, - "facet_height": facet_height, - "facet_aspect": facet_aspect, - } - - + + +def _derive_facet_geometry( + n_groups, + facet_ncol=None, + facet_fig_width=None, + facet_fig_height=None, + facet_tick_max_chars=0, + facet_tick_rotation=0.0, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, + min_panel_width=1.8, + min_panel_height=1.6, + min_aspect=0.6, + max_aspect=2.0, +): + """Derive FacetGrid geometry from pre-normalized facet layout hints. + + Parameters + ---------- + n_groups : int + Number of facet panels. Expected to be a positive integer supplied by + the grouped histogram path. + facet_ncol : int or None, optional + Requested facet column count. Positive integers are used directly. + ``None`` falls back to automatic column selection. + facet_fig_width, facet_fig_height : float, optional + Optional total figure-size hints. Geometry is derived from these hints only + when both values are present. + facet_tick_max_chars : int, optional + Maximum observed x tick-label length. Expected positive integer. + Used for adjusting default geometry heuristics when explicit figure-size + hints are absent. 
+ ``0`` falls back to the original default geometry without long-label adjustments. + facet_tick_rotation : float, optional + Rotation angle in degrees for x tick labels. Used together with + ``facet_tick_max_chars`` to estimate label burden for default geometry. + vertical_threshold : int, optional + Maximum group count that still prefers a single-column automatic + layout. + default_height : float, optional + Per-facet panel height used when explicit figure-size hints are not + available. + default_aspect : float, optional + Per-facet panel aspect ratio used when explicit figure-size hints are + not available. + min_panel_width, min_panel_height : float, optional + Lower bounds applied to per-panel dimensions when figure-size hints are + converted into FacetGrid geometry. + min_aspect, max_aspect : float, optional + Bounds applied to the derived panel aspect ratio. + + Returns + ------- + dict + Dictionary with keys: + - ``facet_ncol``: positive int, normalized column count clamped to ``n_groups``; + - ``facet_height``: float, FacetGrid-ready per-panel height in inches; + - ``facet_aspect``: float, FacetGrid-ready per-panel aspect ratio. + + Automatic layout uses one column when ``n_groups <= vertical_threshold`` + and otherwise uses ``ceil(sqrt(n_groups))`` columns. When both + normalized figure-size hints are present, the helper converts total + figure size into per-panel geometry using the derived grid shape, + applies the minimum panel-size guardrails, and clips aspect into the + configured range. When figure-size hints are absent, the helper may + increase default facet height/aspect for long rotated categorical + labels to preserve usable bar area. 
+ """ + + # Derive facet_ncol when not explicitly provided, and clamp to n_groups + if facet_ncol is None: + if n_groups <= vertical_threshold: + facet_ncol = 1 + else: + facet_ncol = int(np.ceil(np.sqrt(n_groups))) + logging.info( + "Automatic facet_ncol selection: %s columns for %s groups " + "(vertical_threshold=%s).", + facet_ncol, + n_groups, + vertical_threshold, + ) + facet_ncol = max(1, min(int(facet_ncol), n_groups)) + + # Use defaults if figure-size hints are not provided + facet_height = default_height + facet_aspect = default_aspect + + # Derive facet geometry from figure-size hints when both are provided + if facet_fig_width is not None and facet_fig_height is not None: + nrow = int(np.ceil(n_groups / facet_ncol)) + panel_width = max(facet_fig_width / facet_ncol, min_panel_width) + panel_height = max(facet_fig_height / nrow, min_panel_height) + facet_height = panel_height + facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) + + elif facet_tick_max_chars and facet_tick_max_chars > 0: + # For default geometry only, expand panel ratio when long rotated + # labels would otherwise dominate the available plotting area. 
+ rotation = float(facet_tick_rotation or 0.0) % 360.0 + rad = np.deg2rad(min(rotation, 180.0)) + rotation_factor = 1.0 + 0.8 * np.sin(rad) + burden = float(facet_tick_max_chars) * rotation_factor + long_label_threshold = 12.0 + + if burden > long_label_threshold: + pressure = min((burden - long_label_threshold) / long_label_threshold, 2.0) + facet_height = default_height * (1.0 + 0.35 * pressure) + facet_aspect = float( + np.clip( + # default_aspect * (1.0 + 0.75 * pressure), + default_aspect * (1.0 - 0.05 * pressure), + min_aspect, + max_aspect, + ) + ) + logging.info( + "Automatic facet geometry adjustment for long x tick labels: " + "max_chars=%s, rotation=%s, facet_height=%.2f, facet_aspect=%.2f.", + facet_tick_max_chars, + facet_tick_rotation, + facet_height, + facet_aspect, + ) + + return { + "facet_ncol": facet_ncol, + "facet_height": facet_height, + "facet_aspect": facet_aspect, + } + + def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, - x_log_scale=False, y_log_scale=False, facet=False, **kwargs): + x_log_scale=False, y_log_scale=False, facet=False, **kwargs): """ Plot the histogram of cells based on a specific feature from adata.X or annotation from adata.obs. - + Parameters ---------- adata : anndata.AnnData The AnnData object. - + feature : str, optional Name of continuous feature from adata.X to plot its histogram. - + annotation : str, optional Name of the annotation from adata.obs to plot its histogram. - + layer : str, optional Name of the layer in adata.layers to plot its histogram. - + group_by : str, default None Choose either to group the histogram by another column. - + together : bool, default False If True, and if group_by != None, create one plot combining all groups. If False, create separate histograms for each group. The appearance of combined histograms can be controlled using the - `multiple` and `element` parameters in **kwargs. 
Separate grouped or - faceted histograms ignore `multiple`. + `multiple` and `element` parameters in **kwargs. Separate grouped or + faceted histograms ignore `multiple`. To control how the histograms are normalized (e.g., to divide the histogram by the number of elements in every group), use the `stat` parameter in **kwargs. For example, set `stat="probability"` to show the relative frequencies of each group. - + ax : matplotlib.axes.Axes, optional An existing Axes object to draw the plot onto, optional. - Not supported for grouped-separate (`group_by` with `together=False`) - or facet layouts (`group_by` with `facet=True`). - + Not supported for grouped-separate (`group_by` with `together=False`) + or facet layouts (`group_by` with `facet=True`). + x_log_scale : bool, default False If True, the data will be transformed using np.log1p before plotting, and the x-axis label will be adjusted accordingly. - + y_log_scale : bool, default False If True, the y-axis will be set to log scale. - - facet : bool, default False - If True, group by function outputs facet plots - + + facet : bool, default False + If True, group by function outputs facet plots + **kwargs Additional keyword arguments passed to seaborn histplot function. Key arguments include: - `multiple`: Determines how the subsets of data are displayed - on the same axes. Ignored when `group_by` is used with - `together=False`. Options include: + on the same axes. Ignored when `group_by` is used with + `together=False`. Options include: * "layer": Draws each subset on top of the other without adjustments. * "dodge": Dodges bars for each subset side by side. @@ -603,10 +603,10 @@ def histogram(adata, feature=None, annotation=None, layer=None, * "step": Creates a step line plot without bars. * "poly": Creates a polygon where the bottom edge represents the x-axis and the top edge the histogram's bins. - - `shrink`: Scale bar width relative to each bin's width. - Useful for reducing overlap with `multiple="dodge"`. 
- - `alpha`: Transparency for histogram artists. - 0 is fully transparent and 1 is fully opaque. + - `shrink`: Scale bar width relative to each bin's width. + Useful for reducing overlap with `multiple="dodge"`. + - `alpha`: Transparency for histogram artists. + 0 is fully transparent and 1 is fully opaque. - `log_scale`: Determines if the data should be plotted on a logarithmic scale. - `stat`: Determines the statistical transformation to use on the data @@ -622,42 +622,42 @@ def histogram(adata, feature=None, annotation=None, layer=None, Can be a number (indicating the number of bins) or a list (indicating bin edges). For example, `bins=10` will create 10 bins, while `bins=[0, 1, 2, 3]` will create bins [0,1), [1,2), [2,3]. - If not provided, or if passed as `None`/`"auto"`/`"none"`, - the binning will be determined automatically using the Rice rule. + If not provided, or if passed as `None`/`"auto"`/`"none"`, + the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. - - When `group_by` is provided, this optional key can be passed via `kwargs`: - - `max_groups`: Controls the group-count guardrail for grouped plots. - Default is 20 when omitted. Pass `"unlimited"` to disable this - guardrail, which may lead to performance issues or unreadable plots - with many groups. - - When `facet=True`, these optional key can be passed via `kwargs` - to customize FacetGrid layout: - - `facet_ncol`: Controls facet column wrapping. - If omitted or passed as `"auto"`, the function uses one column for - small group counts and switches to a compact grid for many groups. - Otherwise, the provided value is used to request the facet column - count. - - `facet_fig_width`: float, intended final figure width in inches. - - `facet_fig_height`: float, intended final figure height in inches. - - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. 
- + + When `group_by` is provided, this optional key can be passed via `kwargs`: + - `max_groups`: Controls the group-count guardrail for grouped plots. + Default is 20 when omitted. Pass `"unlimited"` to disable this + guardrail, which may lead to performance issues or unreadable plots + with many groups. + + When `facet=True`, these optional keys can be passed via `kwargs` + to customize FacetGrid layout: + - `facet_ncol`: Controls facet column wrapping. + If omitted or passed as `"auto"`, the function uses one column for + small group counts and switches to a compact grid for many groups. + Otherwise, the provided value is used to request the facet column + count. + - `facet_fig_width`: float, intended final figure width in inches. + - `facet_fig_height`: float, intended final figure height in inches. + - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. + Returns ------- A dictionary containing the following: fig : matplotlib.figure.Figure The created figure for the plot. - + axs : matplotlib.axes.Axes or list of Axes The Axes object(s) of the histogram plot(s). Returns a single Axes if only one plot is created, otherwise returns a list of Axes. - + df : pandas.DataFrame DataFrame containing the data used for plotting the histogram. 
- + """ - + # If no feature or annotation is specified, apply default behavior if feature is None and annotation is None: # Default to the first feature in adata.var_names @@ -668,7 +668,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, f"'{feature}'.", UserWarning ) - + # Use utility functions for input validation if layer: check_table(adata, tables=layer) @@ -678,7 +678,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, check_feature(adata, features=feature) if group_by: check_annotation(adata, annotations=group_by) - + # If layer is specified, get the data from that layer if layer: df = pd.DataFrame( @@ -689,15 +689,15 @@ def histogram(adata, feature=None, annotation=None, layer=None, adata.X, index=adata.obs.index, columns=adata.var_names ) layer = 'Original' - + df = pd.concat([df, adata.obs], axis=1) - + if feature and annotation: raise ValueError("Cannot pass both feature and annotation," " choose one.") - + data_column = feature if feature else annotation - + # Check for negative values and apply log1p transformation if # x_log_scale is True if x_log_scale: @@ -708,156 +708,156 @@ def histogram(adata, feature=None, annotation=None, layer=None, x_log_scale = False else: df[data_column] = np.log1p(df[data_column]) - - # If ax is provided, validate input and get figure from it. - # If not, the figure will be created in the plotting branch. + + # If ax is provided, validate input and get figure from it. + # If not, the figure will be created in the plotting branch. if ax is not None: - if group_by and not together: - raise ValueError( - "External ax is only supported for single-axes histogram " - "Please set together=True or remove external ax." - ) + if group_by and not together: + raise ValueError( + "External ax is only supported for single-axes histogram " + "Please set together=True or remove external ax." 
+ ) fig = ax.get_figure() - + axs = [] - + # Prepare the data for plotting plot_data = df.dropna(subset=[data_column]) - + # Bin calculation section # The default bin calculation used by sns.histo take quite # some time to compute for large number of points, # DMAP implemented the Rice rule for bin computation - + def cal_bin_num( num_rows ): bins = max(int(2*(num_rows ** (1/3))), 1) print(f'Automatically calculated number of bins is: {bins}') - - return (bins) - + + return (bins) + num_rows = plot_data.shape[0] - - # Check if bins is not being passed or set to None or "auto" in kwargs. - # If so, the in house algorithm will compute the number of bins - bins_kwarg = kwargs.get('bins', None) - if isinstance(bins_kwarg, str): - bins_kwarg = bins_kwarg.strip().lower() - if bins_kwarg in {'', 'auto', 'none'}: - bins_kwarg = None - if bins_kwarg is None: + + # Check if bins is not being passed or set to None or "auto" in kwargs. + # If so, the in house algorithm will compute the number of bins + bins_kwarg = kwargs.get('bins', None) + if isinstance(bins_kwarg, str): + bins_kwarg = bins_kwarg.strip().lower() + if bins_kwarg in {'', 'auto', 'none'}: + bins_kwarg = None + if bins_kwarg is None: kwargs['bins'] = cal_bin_num(num_rows) - - # Input validation for facet - if facet: - if group_by is None: - raise ValueError("group_by must be specified when facet=True.") - if together: - raise ValueError("Cannot use together=True with facet=True," - " choose one.") - - def _parse_optional_number( - name, - value, - *, - kind=float, - default=None, - positive=False, - tokens=None, - ): - """Parse an optional numeric hint. - - Returns ``default`` for ``None``, resolves recognized string tokens - before numeric coercion, and optionally enforces finite and positive - values on the parsed result. 
- """ - if value is None: - return default - if isinstance(value, str): - value = value.strip() - if tokens and value.lower() in tokens: - return tokens[value.lower()] - expected = ( - f'{"positive " if positive else ""}{kind.__name__}' - f'{" or a supported keyword" if tokens else ""}' - ) - if isinstance(value, bool): - raise ValueError(f'{name} must be a {expected}. Received "{value}".') - try: - parsed = kind(value) - except (TypeError, ValueError): - raise ValueError(f'{name} must be a {expected}. Received "{value}".') - if not math.isfinite(parsed): - raise ValueError( - f'{name} must be a finite {kind.__name__}. ' - f'Received "{value}".' - ) - if positive and parsed <= 0: - raise ValueError( - f'{name} must be a positive {kind.__name__}. ' - f'Received "{value}".' - ) - return parsed - - # Parse max_groups with "unlimited" handling and validation. - max_groups = _parse_optional_number( - "max_groups", - kwargs.pop('max_groups', None), - kind=int, - default=20, - positive=True, - tokens={"unlimited": float('inf')}, - ) - - # Parse facet layout hints so they never leak to seaborn. - facet_ncol = _parse_optional_number( - "facet_ncol", - kwargs.pop('facet_ncol', None), - kind=int, - positive=True, - tokens={"": None, "auto": None, "none": None}, - ) - facet_fig_width = kwargs.pop('facet_fig_width', None) - facet_fig_height = kwargs.pop('facet_fig_height', None) - if facet: - facet_fig_width = _parse_optional_number( - "facet_fig_width", - facet_fig_width, - positive=True, - ) - facet_fig_height = _parse_optional_number( - "facet_fig_height", - facet_fig_height, - positive=True, - ) - if (facet_fig_width is None) != (facet_fig_height is None): - raise ValueError( - "Both facet_fig_width and facet_fig_height must be provided together, " - "or both must be left as None." - ) - else: - # If not faceting, ignore any provided figure size hints. 
- facet_fig_width = None - facet_fig_height = None - facet_tick_rotation = _parse_optional_number( - "facet_tick_rotation", - kwargs.pop('facet_tick_rotation', None), - default=0.0, - ) % 360.0 - + + # Input validation for facet + if facet: + if group_by is None: + raise ValueError("group_by must be specified when facet=True.") + if together: + raise ValueError("Cannot use together=True with facet=True," + " choose one.") + + def _parse_optional_number( + name, + value, + *, + kind=float, + default=None, + positive=False, + tokens=None, + ): + """Parse an optional numeric hint. + + Returns ``default`` for ``None``, resolves recognized string tokens + before numeric coercion, and optionally enforces finite and positive + values on the parsed result. + """ + if value is None: + return default + if isinstance(value, str): + value = value.strip() + if tokens and value.lower() in tokens: + return tokens[value.lower()] + expected = ( + f'{"positive " if positive else ""}{kind.__name__}' + f'{" or a supported keyword" if tokens else ""}' + ) + if isinstance(value, bool): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + try: + parsed = kind(value) + except (TypeError, ValueError): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + if not math.isfinite(parsed): + raise ValueError( + f'{name} must be a finite {kind.__name__}. ' + f'Received "{value}".' + ) + if positive and parsed <= 0: + raise ValueError( + f'{name} must be a positive {kind.__name__}. ' + f'Received "{value}".' + ) + return parsed + + # Parse max_groups with "unlimited" handling and validation. + max_groups = _parse_optional_number( + "max_groups", + kwargs.pop('max_groups', None), + kind=int, + default=20, + positive=True, + tokens={"unlimited": float('inf')}, + ) + + # Parse facet layout hints so they never leak to seaborn. 
+ facet_ncol = _parse_optional_number( + "facet_ncol", + kwargs.pop('facet_ncol', None), + kind=int, + positive=True, + tokens={"": None, "auto": None, "none": None}, + ) + facet_fig_width = kwargs.pop('facet_fig_width', None) + facet_fig_height = kwargs.pop('facet_fig_height', None) + if facet: + facet_fig_width = _parse_optional_number( + "facet_fig_width", + facet_fig_width, + positive=True, + ) + facet_fig_height = _parse_optional_number( + "facet_fig_height", + facet_fig_height, + positive=True, + ) + if (facet_fig_width is None) != (facet_fig_height is None): + raise ValueError( + "Both facet_fig_width and facet_fig_height must be provided together, " + "or both must be left as None." + ) + else: + # If not faceting, ignore any provided figure size hints. + facet_fig_width = None + facet_fig_height = None + facet_tick_rotation = _parse_optional_number( + "facet_tick_rotation", + kwargs.pop('facet_tick_rotation', None), + default=0.0, + ) % 360.0 + # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): """ Compute histogram data for numeric or categorical input. - + Parameters: - data (pd.Series): The input data to be binned. - bins (int or sequence): Number of bins (if numeric) or unique categories (if categorical). - bin_edges (array-like, optional): Predefined bin edges for numeric data. If None, automatic binning is used. - + Returns: - pd.DataFrame: A DataFrame containing the following columns: - `count`: @@ -869,9 +869,9 @@ def calculate_histogram(data, bins, bin_edges=None): - `bin_center`: Center of each bin (for numeric data) or category labels (for categorical data). 
- + """ - + # Check if the data is numeric or categorical if pd.api.types.is_numeric_dtype(data): if bin_edges is None: @@ -894,274 +894,274 @@ def calculate_histogram(data, bins, bin_edges=None): 'bin_right': counts.index, 'count': counts.values }) - - def build_grouped_histogram_table( - plot_data, data_column, group_by, groups, bins - ): - """Build per-group histogram-bin tables for grouped histogram paths.""" - # Determine shared bins across groups for consistent plotting. - data_series = plot_data[data_column] - if pd.api.types.is_numeric_dtype(data_series): - shared_bins = np.histogram_bin_edges(data_series, bins=bins) - elif isinstance(data_series.dtype, pd.CategoricalDtype): - shared_bins = data_series.cat.categories - else: - shared_bins = pd.Index(pd.unique(data_series.dropna())) - - # Compute histograms for each group using the shared bins. - histograms = [] - for group in groups: - group_data = plot_data.loc[ - plot_data[group_by] == group, data_column - ] - if pd.api.types.is_numeric_dtype(data_series): - group_hist = calculate_histogram( - group_data, bins, bin_edges=shared_bins - ) - else: - # For categorical data, pad missing categories with zero counts - # to ensure consistent plotting. - group_hist = calculate_histogram(group_data, bins) - group_hist = ( - group_hist - .set_index('bin_center') - .reindex(shared_bins) - .rename_axis('bin_center') - .reset_index() - ) - group_hist['count'] = group_hist['count'].fillna(0) - group_hist['bin_left'] = group_hist['bin_center'] - group_hist['bin_right'] = group_hist['bin_center'] - group_hist[group_by] = group - histograms.append(group_hist) - - # Concatenate all group histograms into a single DataFrame for plotting. - hist_data = pd.concat(histograms, ignore_index=True) - return hist_data, shared_bins - - # Function to compute maximum tick label length for categorical data - def compute_max_tick_label_length(data_series): - """Compute maximum tick label length for a categorical data series. 
- - Parameters - ---------- - data_series : pandas.Series - Categorical data column used to compute maximum tick label length. - - Returns - ------- - int - Maximum number of characters in the tick labels derived from the - unique categories of the input series. - """ - if isinstance(data_series.dtype, pd.CategoricalDtype): - tick_labels = [ - str(label) for label in data_series.cat.categories - ] - else: - tick_labels = [ - str(label) for label in data_series.dropna().unique().tolist() - ] - return max((len(label) for label in tick_labels), default=0) - - # Function to get axis labels based on log scale and stat parameters - def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): - """Resolve x/y axis labels for histogram rendering. - - Parameters - ---------- - data_column : str - Source column used on the x axis. - x_log_scale : bool - Whether x data has log transform semantics. - y_log_scale : bool - Whether y axis is displayed on log scale. - stat : str - Histogram statistic mode (for example, count, density). - - Returns - ------- - tuple[str, str] - Resolved x-axis and y-axis labels. - """ - xlabel = f'log({data_column})' if x_log_scale else data_column - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability', - "proportion": "Proportion", - "percent": "Percent" - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - return xlabel, ylabel - + + def build_grouped_histogram_table( + plot_data, data_column, group_by, groups, bins + ): + """Build per-group histogram-bin tables for grouped histogram paths.""" + # Determine shared bins across groups for consistent plotting. 
+ data_series = plot_data[data_column] + if pd.api.types.is_numeric_dtype(data_series): + shared_bins = np.histogram_bin_edges(data_series, bins=bins) + elif isinstance(data_series.dtype, pd.CategoricalDtype): + shared_bins = data_series.cat.categories + else: + shared_bins = pd.Index(pd.unique(data_series.dropna())) + + # Compute histograms for each group using the shared bins. + histograms = [] + for group in groups: + group_data = plot_data.loc[ + plot_data[group_by] == group, data_column + ] + if pd.api.types.is_numeric_dtype(data_series): + group_hist = calculate_histogram( + group_data, bins, bin_edges=shared_bins + ) + else: + # For categorical data, pad missing categories with zero counts + # to ensure consistent plotting. + group_hist = calculate_histogram(group_data, bins) + group_hist = ( + group_hist + .set_index('bin_center') + .reindex(shared_bins) + .rename_axis('bin_center') + .reset_index() + ) + group_hist['count'] = group_hist['count'].fillna(0) + group_hist['bin_left'] = group_hist['bin_center'] + group_hist['bin_right'] = group_hist['bin_center'] + group_hist[group_by] = group + histograms.append(group_hist) + + # Concatenate all group histograms into a single DataFrame for plotting. + hist_data = pd.concat(histograms, ignore_index=True) + return hist_data, shared_bins + + # Function to compute maximum tick label length for categorical data + def compute_max_tick_label_length(data_series): + """Compute maximum tick label length for a categorical data series. + + Parameters + ---------- + data_series : pandas.Series + Categorical data column used to compute maximum tick label length. + + Returns + ------- + int + Maximum number of characters in the tick labels derived from the + unique categories of the input series. 
+ """ + if isinstance(data_series.dtype, pd.CategoricalDtype): + tick_labels = [ + str(label) for label in data_series.cat.categories + ] + else: + tick_labels = [ + str(label) for label in data_series.dropna().unique().tolist() + ] + return max((len(label) for label in tick_labels), default=0) + + # Function to get axis labels based on log scale and stat parameters + def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): + """Resolve x/y axis labels for histogram rendering. + + Parameters + ---------- + data_column : str + Source column used on the x axis. + x_log_scale : bool + Whether x data has log transform semantics. + y_log_scale : bool + Whether y axis is displayed on log scale. + stat : str + Histogram statistic mode (for example, count, density). + + Returns + ------- + tuple[str, str] + Resolved x-axis and y-axis labels. + """ + xlabel = f'log({data_column})' if x_log_scale else data_column + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability', + "proportion": "Proportion", + "percent": "Percent" + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + return xlabel, ylabel + # Plotting with or without grouping if group_by: groups = df[group_by].dropna().unique().tolist() n_groups = len(groups) - + if n_groups == 0: raise ValueError("There must be at least one group to create a" " histogram.") - elif n_groups > max_groups: - raise ValueError( - "The number of groups in `group_by` exceeds `max_groups`: " - f"found {n_groups}, threshold {max_groups}.\n" - "Please reduce/bucket groups or use another grouping column, or " - "pass a larger `max_groups`.\n" - "See `kwargs` documentation for more details." 
- ) - + elif n_groups > max_groups: + raise ValueError( + "The number of groups in `group_by` exceeds `max_groups`: " + f"found {n_groups}, threshold {max_groups}.\n" + "Please reduce/bucket groups or use another grouping column, or " + "pass a larger `max_groups`.\n" + "See `kwargs` documentation for more details." + ) + if together: - if ax is None: - fig, ax = plt.subplots() - - hist_data, shared_bins = build_grouped_histogram_table( - plot_data, - data_column, - group_by, - groups, - bins=kwargs.pop('bins'), - ) + if ax is None: + fig, ax = plt.subplots() + + hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), + ) if pd.api.types.is_numeric_dtype(plot_data[data_column]): - kwargs['bins'] = shared_bins.tolist() - + kwargs['bins'] = shared_bins.tolist() + # Set default values if not provided in kwargs kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - - sns.histplot( - data=hist_data, - x='bin_center', - weights='count', - hue=group_by, - ax=ax, - **kwargs, - ) + + sns.histplot( + data=hist_data, + x='bin_center', + weights='count', + hue=group_by, + ax=ax, + **kwargs, + ) # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') axs.append(ax) - + else: - # 'multiple' parameter is not applicable - kwargs.pop('multiple', None) - - if not facet: - fig, ax_array = plt.subplots( - n_groups, 1, figsize=(5, 5 * n_groups) - ) - - # Convert a single Axes object to a list - # Ensure ax_array is always iterable - if n_groups == 1: - ax_array = [ax_array] + # 'multiple' parameter is not applicable + kwargs.pop('multiple', None) + + if not facet: + fig, ax_array = plt.subplots( + n_groups, 1, figsize=(5, 5 * n_groups) + ) + + # Convert a single Axes object to a list + # Ensure ax_array is always iterable + if n_groups == 1: + ax_array = [ax_array] else: - ax_array = ax_array.flatten() - - for i, ax_i in enumerate(ax_array): - group_data = 
plot_data[plot_data[group_by] == - groups[i]][data_column] - hist_data = calculate_histogram(group_data, kwargs['bins']) - - sns.histplot(data=hist_data, x="bin_center", ax=ax_i, - weights='count', **kwargs) - # If plotting feature specify which layer - if feature: - ax_i.set_title(f'{groups[i]} with Layer: {layer}') - else: - ax_i.set_title(f'{groups[i]}') - axs.append(ax_i) - - else: # Facet option - # Compute max label length only when not explicitly provided. - facet_tick_max_chars = 0 - if not pd.api.types.is_numeric_dtype(plot_data[data_column]): - facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) - - # Derive facet geometry based on group count and layout hints - # Returned layout keys: facet_ncol, facet_height, facet_aspect - facet_layout = _derive_facet_geometry( - n_groups=n_groups, - facet_ncol=facet_ncol, - facet_fig_width=facet_fig_width, - facet_fig_height=facet_fig_height, - facet_tick_max_chars=facet_tick_max_chars, - facet_tick_rotation=facet_tick_rotation, - ) - - # Compute histogram data and shared bins for consistent plotting. - # For non-numeric data, shared_bins will be dropped intentionally. 
- hist_data, shared_bins = build_grouped_histogram_table( - plot_data, - data_column, - group_by, - groups, - bins=kwargs.pop('bins'), - ) - if pd.api.types.is_numeric_dtype(plot_data[data_column]): - kwargs['bins'] = shared_bins.tolist() - - # Create the FacetGrid for the histogram - hist = sns.FacetGrid( - hist_data, - col=group_by, - col_wrap=facet_layout['facet_ncol'], - height=facet_layout['facet_height'], - aspect=facet_layout['facet_aspect'], - sharex=True, - sharey=True, - ) - - # Map the histogram function to the grid - hist.map_dataframe( - sns.histplot, - x='bin_center', - weights='count', - **kwargs, - ) - - # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column - for ax_i in hist.axes.flat: - ax_i.tick_params(axis='x', labelbottom=True) - ax_i.tick_params(axis='y', labelleft=True) - - # Set background color and grid for better readability - for ax_i in hist.axes.flat: - ax_i.set_facecolor('#f2f2f2') - ax_i.grid(True, which='major', axis='both') - - # Titles for each facet - hist.set_titles("{col_name}") - - # Adjust margins for readability across layouts. 
- hist.figure.subplots_adjust(left=.1, - top=0.9, - bottom=0.12, - hspace=0.35, - wspace=0.2) - - # Pass the figure and axes to the output for further customization - fig = hist.figure - fig.set_size_inches( - facet_fig_width or fig.get_figwidth(), - facet_fig_height or fig.get_figheight(), - ) - axs.extend(hist.axes.flat) - + ax_array = ax_array.flatten() + + for i, ax_i in enumerate(ax_array): + group_data = plot_data[plot_data[group_by] == + groups[i]][data_column] + hist_data = calculate_histogram(group_data, kwargs['bins']) + + sns.histplot(data=hist_data, x="bin_center", ax=ax_i, + weights='count', **kwargs) + # If plotting feature specify which layer + if feature: + ax_i.set_title(f'{groups[i]} with Layer: {layer}') + else: + ax_i.set_title(f'{groups[i]}') + axs.append(ax_i) + + else: # Facet option + # Compute max label length only when not explicitly provided. + facet_tick_max_chars = 0 + if not pd.api.types.is_numeric_dtype(plot_data[data_column]): + facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) + + # Derive facet geometry based on group count and layout hints + # Returned layout keys: facet_ncol, facet_height, facet_aspect + facet_layout = _derive_facet_geometry( + n_groups=n_groups, + facet_ncol=facet_ncol, + facet_fig_width=facet_fig_width, + facet_fig_height=facet_fig_height, + facet_tick_max_chars=facet_tick_max_chars, + facet_tick_rotation=facet_tick_rotation, + ) + + # Compute histogram data and shared bins for consistent plotting. + # For non-numeric data, shared_bins will be dropped intentionally. 
+ hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), + ) + if pd.api.types.is_numeric_dtype(plot_data[data_column]): + kwargs['bins'] = shared_bins.tolist() + + # Create the FacetGrid for the histogram + hist = sns.FacetGrid( + hist_data, + col=group_by, + col_wrap=facet_layout['facet_ncol'], + height=facet_layout['facet_height'], + aspect=facet_layout['facet_aspect'], + sharex=True, + sharey=True, + ) + + # Map the histogram function to the grid + hist.map_dataframe( + sns.histplot, + x='bin_center', + weights='count', + **kwargs, + ) + + # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column + for ax_i in hist.axes.flat: + ax_i.tick_params(axis='x', labelbottom=True) + ax_i.tick_params(axis='y', labelleft=True) + + # Set background color and grid for better readability + for ax_i in hist.axes.flat: + ax_i.set_facecolor('#f2f2f2') + ax_i.grid(True, which='major', axis='both') + + # Titles for each facet + hist.set_titles("{col_name}") + + # Adjust margins for readability across layouts. 
+ hist.figure.subplots_adjust(left=.1, + top=0.9, + bottom=0.12, + hspace=0.35, + wspace=0.2) + + # Pass the figure and axes to the output for further customization + fig = hist.figure + fig.set_size_inches( + facet_fig_width or fig.get_figwidth(), + facet_fig_height or fig.get_figheight(), + ) + axs.extend(hist.axes.flat) + else: - if ax is None: - fig, ax = plt.subplots() - + if ax is None: + fig, ax = plt.subplots() + # Precompute histogram data for single plot hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): ax.set_xlim(hist_data['bin_left'].min(), - hist_data['bin_right'].max()) - + hist_data['bin_right'].max()) + sns.histplot( data=hist_data, x='bin_center', @@ -1169,80 +1169,80 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ax=ax, **kwargs ) - + # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') axs.append(ax) - - stat = kwargs.get('stat', 'count') - xlabel, ylabel = resolve_hist_axis_labels( - data_column=data_column, - x_log_scale=x_log_scale, - y_log_scale=y_log_scale, - stat=stat, - ) - - axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] - for ax in axes: - # Set axis scales if y_log_scale is True - if y_log_scale: - ax.set_yscale('log') - - if facet: - ax.set_xlabel('') - ax.set_ylabel('') - else: - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - - # Set a common x and y label for the entire figure if facet is True - if facet and fig is not None: - fig.supxlabel(xlabel) - fig.supylabel(ylabel) - + + stat = kwargs.get('stat', 'count') + xlabel, ylabel = resolve_hist_axis_labels( + data_column=data_column, + x_log_scale=x_log_scale, + y_log_scale=y_log_scale, + stat=stat, + ) + + axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] + for ax in axes: + # Set axis scales if y_log_scale is True + if y_log_scale: + ax.set_yscale('log') + + if facet: + ax.set_xlabel('') + ax.set_ylabel('') + else: 
+ ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + # Set a common x and y label for the entire figure if facet is True + if facet and fig is not None: + fig.supxlabel(xlabel) + fig.supylabel(ylabel) + if len(axs) == 1: return {"fig": fig, "axs": axs[0], "df": hist_data} else: return {"fig": fig, "axs": axs, "df": hist_data} - - + + def heatmap(adata, column, layer=None, **kwargs): """ Plot the heatmap of the mean feature of cells that belong to a `column`. - + Parameters ---------- adata : anndata.AnnData The AnnData object. - + column : str Name of member of adata.obs to plot the histogram. - + layer : str, default None The name of the `adata` layer to use to calculate the mean feature. - + **kwargs: Parameters passed to seaborn heatmap function. - + Returns ------- pandas.DataFrame A dataframe tha has the labels as indexes the mean feature for every marker. - + matplotlib.figure.Figure The figure of the heatmap. - + matplotlib.axes._subplots.AxesSubplot The AsxesSubplot of the heatmap. - + """ features = adata.to_df(layer=layer) labels = adata.obs[column] grouped = pd.concat([features, labels], axis=1).groupby(column) mean_feature = grouped.mean() - + n_rows = len(mean_feature) n_cols = len(mean_feature.columns) fig, ax = plt.subplots(figsize=(n_cols * 1.5, n_rows * 1.5)) @@ -1257,18 +1257,18 @@ def heatmap(adata, column, layer=None, **kwargs): linewidth=.5, annot_kws={"fontsize": 10}, **kwargs) - + ax.tick_params(axis='both', labelsize=25) ax.set_ylabel(column, size=25) - + return mean_feature, fig, ax - + def hierarchical_heatmap(adata, annotation, features=None, layer=None, cluster_feature=False, cluster_annotations=False, standard_scale=None, z_score="annotation", swap_axes=False, rotate_label=False, **kwargs): - + """ Generates a hierarchical clustering heatmap and dendrogram. 
By default, the dataset is assumed to have features as columns and @@ -1276,7 +1276,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, and for each group, the average expression intensity of each feature (e.g., protein or marker) is computed. The heatmap is plotted using seaborn's clustermap. - + Parameters ---------- adata : anndata.AnnData @@ -1336,7 +1336,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, - `metric` : str The distance metric to use for the hierarchy. Defaults to 'euclidean' in the function. - + Returns ------- mean_intensity : pandas.DataFrame @@ -1349,7 +1349,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, rows and columns. These linkage matrices can be used to generate dendrograms with tools like scipy's dendrogram function. This offers flexibility in customizing and plotting dendrograms as needed. - + Examples -------- import matplotlib.pyplot as plt @@ -1359,7 +1359,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, X = pd.DataFrame([[1, 2], [3, 4]], columns=['gene1', 'gene2']) annotation = pd.DataFrame(['type1', 'type2'], columns=['cell_type']) all_data = anndata.AnnData(X=X, obs=annotation) - + mean_intensity, clustergrid, dendrogram_data = hierarchical_heatmap( all_data, "cell_type", @@ -1369,15 +1369,15 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, cluster_feature=False, cluster_annotations=True ) - + # To display a standalone dendrogram using the returned linkage matrix: import scipy.cluster.hierarchy as sch import numpy as np import matplotlib.pyplot as plt - + # Convert the linkage data to type double dendro_col_data = np.array(dendrogram_data['col_linkage'], dtype=np.double) - + # Ensure the linkage matrix has at least two dimensions and more than one linkage if dendro_col_data.ndim == 2 and dendro_col_data.shape[0] > 1: @@ -1388,22 +1388,22 @@ def hierarchical_heatmap(adata, annotation, features=None, 
layer=None, else: print("Insufficient data to plot a dendrogram.") """ - + # Use utility functions to check inputs check_annotation(adata, annotations=annotation) if features: check_feature(adata, features=features) if layer: check_table(adata, tables=layer) - + # Raise an error if there are any NaN values in the annotation column if adata.obs[annotation].isna().any(): raise ValueError("NaN values found in annotation column.") - + # Convert the observation column to categorical if it's not already if not pd.api.types.is_categorical_dtype(adata.obs[annotation]): adata.obs[annotation] = adata.obs[annotation].astype('category') - + # Calculate mean intensity if layer: intensities = pd.DataFrame( @@ -1413,25 +1413,25 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, ) else: intensities = adata.to_df() - + labels = adata.obs[annotation] grouped = pd.concat([intensities, labels], axis=1).groupby(annotation) mean_intensity = grouped.mean() - + # If swap_axes is True, transpose the mean_intensity if swap_axes: mean_intensity = mean_intensity.T - + # Map z_score based on user's input and the state of swap_axes if z_score == "annotation": z_score = 0 if not swap_axes else 1 elif z_score == "feature": z_score = 1 if not swap_axes else 0 - + # Subset the mean_intensity DataFrame based on selected features if features is not None and len(features) > 0: mean_intensity = mean_intensity.loc[features] - + # Determine clustering behavior based on swap_axes if swap_axes: row_cluster = cluster_feature # Rows are features @@ -1439,7 +1439,7 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, else: row_cluster = cluster_annotations # Rows are annotations col_cluster = cluster_feature # Columns are features - + # Use seaborn's clustermap for hierarchical clustering and # heatmap visualization. 
clustergrid = sns.clustermap( @@ -1452,29 +1452,29 @@ def hierarchical_heatmap(adata, annotation, features=None, layer=None, col_cluster=col_cluster, **kwargs ) - + # Rotate x-axis tick labels if rotate_label is True if rotate_label: plt.setp(clustergrid.ax_heatmap.get_xticklabels(), rotation=45) - + # Extract the dendrogram data for return dendro_row_data = None dendro_col_data = None - + if clustergrid.dendrogram_row: dendro_row_data = clustergrid.dendrogram_row.linkage - + if clustergrid.dendrogram_col: dendro_col_data = clustergrid.dendrogram_col.linkage - + # Define the dendrogram_data dictionary dendrogram_data = { 'row_linkage': dendro_row_data, 'col_linkage': dendro_col_data } - + return mean_intensity, clustergrid, dendrogram_data - + def threshold_heatmap( adata, feature_cutoffs, annotation, layer=None, swap_axes=False, **kwargs @@ -1482,7 +1482,7 @@ def threshold_heatmap( """ Creates a heatmap for each feature, categorizing intensities into low, medium, and high based on provided cutoffs. - + Parameters ---------- adata : anndata.AnnData @@ -1501,7 +1501,7 @@ def threshold_heatmap( If True, swaps the axes of the heatmap. **kwargs : keyword arguments Additional keyword arguments to pass to scanpy's heatmap function. - + Returns ------- Dictionary of :class:`~matplotlib.axes.Axes` @@ -1511,22 +1511,22 @@ def threshold_heatmap( Potential Keys includes: 'groupby_ax', 'dendrogram_ax', and 'gene_groups_ax'. """ - + # Use utility functions for input validation check_table(adata, tables=layer) check_annotation(adata, annotations=annotation) if feature_cutoffs: check_feature(adata, features=list(feature_cutoffs.keys())) - + # Assert annotation is a string if not isinstance(annotation, str): err_type = type(annotation).__name__ err_msg = (f'Annotation should be string. 
Got {err_type}.') raise TypeError(err_msg) - + if not isinstance(feature_cutoffs, dict): raise TypeError("feature_cutoffs should be a dictionary.") - + for key, value in feature_cutoffs.items(): if not (isinstance(value, tuple) and len(value) == 2): raise ValueError( @@ -1537,13 +1537,13 @@ def threshold_heatmap( raise ValueError(f"Low cutoff for {key} should not be NaN.") if math.isnan(value[1]): raise ValueError(f"High cutoff for {key} should not be NaN.") - + adata.uns['feature_cutoffs'] = feature_cutoffs - + intensity_df = pd.DataFrame( index=adata.obs_names, columns=feature_cutoffs.keys() ) - + for feature, cutoffs in feature_cutoffs.items(): low_cutoff, high_cutoff = cutoffs feature_values = ( @@ -1554,17 +1554,17 @@ def threshold_heatmap( intensity_df.loc[(feature_values > low_cutoff) & (feature_values <= high_cutoff), feature] = 1 intensity_df.loc[feature_values > high_cutoff, feature] = 2 - + intensity_df = intensity_df.astype(int) adata.layers["intensity"] = intensity_df.to_numpy() adata.obs[annotation] = adata.obs[annotation].astype('category') - + color_map = {0: (0/255, 0/255, 139/255), 1: 'green', 2: 'yellow'} colors = [color_map[i] for i in range(3)] cmap = ListedColormap(colors) - + norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N) - + heatmap_plot = sc.pl.heatmap( adata, var_names=intensity_df.columns, @@ -1577,18 +1577,18 @@ def threshold_heatmap( swap_axes=swap_axes, **kwargs ) - + # Print the keys of the heatmap_plot dictionary print("Keys of heatmap_plot:", heatmap_plot.keys()) - + # Get the main heatmap axis from the available keys heatmap_ax = heatmap_plot.get('heatmap_ax') - + # If 'heatmap_ax' key does not exist, access the first axis available if heatmap_ax is None: heatmap_ax = next(iter(heatmap_plot.values())) print("Heatmap Axes:", heatmap_ax) - + # Find the colorbar associated with the heatmap cbar = None for child in heatmap_ax.get_children(): @@ -1599,7 +1599,7 @@ def threshold_heatmap( print("No colorbar found in the plot.") return 
print("Colorbar:", cbar) - + new_ticks = [0, 1, 2] new_labels = ['Low', 'Medium', 'High'] cbar.set_ticks(new_ticks) @@ -1608,9 +1608,9 @@ def threshold_heatmap( cbar.ax.set_position( [pos_heatmap.x1 + 0.02, pos_heatmap.y0, 0.02, pos_heatmap.height] ) - + return heatmap_plot - + def spatial_plot( adata, @@ -1630,7 +1630,7 @@ def spatial_plot( ---------- adata : anndata.AnnData The AnnData object containing target feature and spatial coordinates. - + spot_size : int The size of spot on the spatial plot. alpha : float @@ -1658,7 +1658,7 @@ def spatial_plot( ------- Single or a list of class:`~matplotlib.axes.Axes`. """ - + err_msg_layer = "The 'layer' parameter must be a string, " + \ f"got {str(type(layer))}" err_msg_feature = "The 'feature' parameter must be a string, " + \ @@ -1681,86 +1681,86 @@ def spatial_plot( f"got {str(type(vmax))}" err_msg_ax = "The 'ax' parameter must be an instance " + \ f"of matplotlib.axes.Axes, got {str(type(ax))}" - + if adata is None: raise ValueError("The input dataset must not be None.") - + if not isinstance(adata, anndata.AnnData): err_msg_adata = "The 'adata' parameter must be an " + \ f"instance of anndata.AnnData, got {str(type(adata))}." 
raise ValueError(err_msg_adata) - + if layer is not None and not isinstance(layer, str): raise ValueError(err_msg_layer) - + if layer is not None and layer not in adata.layers.keys(): err_msg_layer_exist = f"Layer {layer} does not exists, " + \ f"available layers are {str(adata.layers.keys())}" raise ValueError(err_msg_layer_exist) - + if feature is not None and not isinstance(feature, str): raise ValueError(err_msg_feature) - + if annotation is not None and not isinstance(annotation, str): raise ValueError(err_msg_annotation) - + if annotation is not None and feature is not None: raise ValueError(err_msg_feat_annotation_coe) - + if annotation is None and feature is None: raise ValueError(err_msg_feat_annotation_non) - + if 'spatial' not in adata.obsm_keys(): err_msg = "Spatial coordinates not found in the 'obsm' attribute." raise ValueError(err_msg) - + # Extract annotation name annotation_names = adata.obs.columns.tolist() annotation_names_str = ", ".join(annotation_names) - + if annotation is not None and annotation not in annotation_names: error_text = f'The annotation "{annotation}"' + \ 'not found in the dataset.' + \ f" Existing annotations are: {annotation_names_str}" raise ValueError(error_text) - + # Extract feature name if layer is None: layer_process = adata.X else: layer_process = adata.layers[layer] - + feature_names = adata.var_names.tolist() - + if feature is not None and feature not in feature_names: error_text = f"Feature {feature} not found," + \ " please check the sample metadata." 
raise ValueError(error_text) - + if not isinstance(spot_size, int): raise ValueError(err_msg_spot_size) - + if not isinstance(alpha, float): raise ValueError(err_msg_alpha_type) - + if not (0 <= alpha <= 1): raise ValueError(err_msg_alpha_value) - + if vmin != -999 and not ( isinstance(vmin, float) or isinstance(vmin, int) ): raise ValueError(err_msg_vmin) - + if vmax != -999 and not ( isinstance(vmax, float) or isinstance(vmax, int) ): raise ValueError(err_msg_vmax) - + if ax is not None and not isinstance(ax, plt.Axes): raise ValueError(err_msg_ax) - + if feature is not None: - + feature_index = feature_names.index(feature) feature_annotation = feature + "spatial_plot" if vmin == -999: @@ -1773,11 +1773,11 @@ def spatial_plot( color_region = annotation vmin = None vmax = None - + if ax is None: fig = plt.figure() ax = fig.add_subplot(1, 1, 1) - + ax = sc.pl.spatial( adata=adata, layer=layer, @@ -1789,9 +1789,9 @@ def spatial_plot( ax=ax, show=False, **kwargs) - + return ax - + def boxplot(adata, annotation=None, second_annotation=None, layer=None, ax=None, features=None, log_scale=False, **kwargs): @@ -1799,32 +1799,32 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, Create a boxplot visualization of the features in the passed adata object. This function offers flexibility in how the boxplots are displayed, based on the arguments provided. - + Parameters ---------- adata : anndata.AnnData The AnnData object. - + annotation : str, optional Annotation to determine if separate plots are needed for every label. - + second_annotation : str, optional Second annotation to further divide the data. - + layer : str, optional The name of the matrix layer to use. If not provided, uses the main data matrix adata.X. - + ax : matplotlib.axes.Axes, optional An existing Axes object to draw the plot onto, optional. - + features : list, optional List of feature names to be plotted. If not provided, all features will be plotted. 
- + log_scale : bool, optional If True, the Y-axis will be in log scale. Default is False. - + **kwargs Additional arguments to pass to seaborn.boxplot. Key arguments include: @@ -1837,7 +1837,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, ------- fig, ax : matplotlib.figure.Figure, matplotlib.axes.Axes The created figure and axes for the plot. - + Examples -------- - Multiple features boxplot: boxplot(adata, features=['GeneA','GeneB']) @@ -1848,7 +1848,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, - Nested grouping by two annotations: boxplot(adata, features=['GeneA'], annotation='cell_type', second_annotation='treatment') """ - + # Use utility functions to check inputs print("Calculating Box Plot...") if layer: @@ -1859,75 +1859,75 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, check_annotation(adata, annotations=second_annotation) if features: check_feature(adata, features=features) - + if 'orient' not in kwargs: kwargs['orient'] = 'v' - + if kwargs['orient'] != 'v': v_orient = False else: v_orient = True - + # Validate ax instance if ax and not isinstance(ax, plt.Axes): raise TypeError("Input 'ax' must be a matplotlib.axes.Axes object.") - + # Use the specified layer if provided if layer: data_matrix = adata.layers[layer] else: data_matrix = adata.X - + # Create a DataFrame from the data matrix with features as columns df = pd.DataFrame(data_matrix, columns=adata.var_names) - + # Add annotations to the DataFrame if provided if annotation: df[annotation] = adata.obs[annotation].values if second_annotation: df[second_annotation] = adata.obs[second_annotation].values - + # If features is None, set it to all available features if features is None: features = adata.var_names.tolist() - + df = df[ features + ([annotation] if annotation else []) + ([second_annotation] if second_annotation else []) ] - + # Check for negative values if log_scale and (df[features] < 0).any().any(): 
print( "There are negative values in this data, disabling the log scale." ) log_scale = False - + # Apply log1p transformation if log_scale is True if log_scale: df[features] = np.log1p(df[features]) - + # Create the plot if ax: fig = ax.get_figure() else: fig, ax = plt.subplots(figsize=(10, 5)) - + # Plotting logic based on provided annotations if annotation and second_annotation: if v_orient: sns.boxplot(data=df, y=features[0], x=annotation, hue=second_annotation, ax=ax, **kwargs) - + else: sns.boxplot(data=df, y=annotation, x=features[0], hue=second_annotation, ax=ax, **kwargs) - + title_str = f"Nested Grouping by {annotation} and {second_annotation}" - + ax.set_title(title_str) - + elif annotation: if len(features) > 1: # Reshape the dataframe to long format for visualization @@ -1947,7 +1947,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, sns.boxplot(data=df, x=features[0], y=annotation, ax=ax, **kwargs) ax.set_title(f"Grouped by {annotation}") - + else: if len(features) > 1: if v_orient: @@ -1967,7 +1967,7 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, ax.set_yticks([0]) # Set a single tick for the single feature ax.set_yticklabels([features[0]]) # Set the label for the tick ax.set_title("Single Boxplot") - + # Set x and y-axis labels if v_orient: xlabel = annotation if annotation else 'Feature' @@ -1979,12 +1979,12 @@ def boxplot(adata, annotation=None, second_annotation=None, layer=None, ylabel = annotation if annotation else 'Feature' ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) - + plt.xticks(rotation=90) plt.tight_layout() - + return fig, ax, df - + def boxplot_interactive( adata, @@ -2007,59 +2007,59 @@ def boxplot_interactive( ): """ Generate a boxplot for given features from an AnnData object. - + This function visualizes the distribution of gene expression (or other features) across different annotations in the provided data. 
It can handle various options such as log-transformation, feature selection, and handling of outliers. - + Parameters ----------- adata : AnnData An AnnData object containing the data to plot. The expression matrix is accessed via `adata.X` or `adata.layers[layer]`, and annotations are taken from `adata.obs`. - + annotation : str, optional The name of the annotation column (e.g., cell type or sample condition) from `adata.obs` used to group the features. If `None`, no grouping is applied. - + layer : str, optional The name of the layer from `adata.layers` to use. If `None`, `adata.X` is used. - + ax : plotly.graph_objects.Figure, optional The figure to plot the boxplot onto. If `None`, a new figure is created. - + features : list of str, optional The list of features (genes) to plot. If `None`, all features are included. - + showfliers : {None, "downsample", "all"}, default = None If 'all', all outliers are displayed in the boxplot. If 'downsample', when num outliers is >10k, they are downsampled to 10% of the original count. If None, outliers are hidden. - + log_scale : bool, default=False If True, the log1p transformation is applied to the features before plotting. This option is disabled if negative values are found in the features. - + orient : {"v", "h"}, default="v" The orientation of the boxplots: "v" for vertical, "h" for horizontal. - + figure_width : int, optional Width of the figure in inches. Default is 3.2. - + figure_height : int, optional Height of the figure in inches. Default is 2. - + figure_dpi : int, optional DPI (dots per inch) for the figure. Default is 200. - + defined_color_map : str, optional Key in 'adata.uns' holding a pre-computed color dictionary. Falls back to automatic generation from 'annotation' values. @@ -2072,7 +2072,7 @@ def boxplot_interactive( future enhancements. Default is None. **kwargs : dict Additional arguments for seaborn figure-level functions. 
- + Returns ------- A dictionary containing the following keys: @@ -2081,15 +2081,15 @@ def boxplot_interactive( - If `figure_type` is "static": A base64-encoded PNG image string - If `figure_type` is "interactive": A Plotly figure object - + df : pd.DataFrame A DataFrame containing the features and their corresponding values. - + metrics : pd.DataFrame A DataFrame containing the computed boxplot metrics (if `return_metrics` is True). """ - + def boxplot_from_statistics( summary_stats: pd.DataFrame, cmap: dict, @@ -2105,12 +2105,12 @@ def boxplot_from_statistics( ): """ Generate a boxplot from the provided summary statistics DataFrame. - + This function visualizes a set of summary statistics (e.g., quartiles, mean) as a boxplot. It supports grouping the data by a given annotation and allows customization of orientation, displaying outliers, and interactive plotting. - + Parameters ---------- summary_stats : pd.DataFrame @@ -2118,49 +2118,49 @@ def boxplot_from_statistics( plot. It should include columns like 'marker', 'q1', 'med', 'q3', 'whislo', 'whishi', and 'mean'. Optionally, it may also contain an annotation column used for grouping. - + cmap : dict A dictionary mapping annotation/feature values to color strings (hex, rgb/rgba, hsl/hsla, hsv/hsva, or CSS). - + annotation : str, optional The column name in `summary_stats` used to group the data by specific categories (e.g., cell type, condition). If `None`, no grouping is applied. - + ax : matplotlib.axes.Axes or plotly.graph_objects.Figure, optional A figure or axes to plot onto. If None, a new Plotly figure is created. - + showfliers : {None, "downsample", "all"}, default = None If 'all', all outliers are displayed in the boxplot. If 'downsample', when num outliers is >10k, they are downsampled to 10% of the original count. If None, outliers are hidden. - + log_scale : bool, optional, default=False If True, the log1p transformation is applied to the features before plotting. 
This option is disabled if negative values are found in the features. - + orient : {"v", "h"}, default="v" The orientation of the boxplot: 'v' for vertical and 'h' for horizontal. - + figure_width : int, optional Width of the figure in inches. Default is 3.2. - + figure_height : int, optional Height of the figure in inches. Default is 2. - + figure_dpi : int, optional DPI (dots per inch) for the figure. Default is 200. - + Returns ------- fig : plotly.graph_objects.Figure The Plotly figure containing the generated boxplot. - + Notes ----- - The function uses the `plotly` library for visualization, allowing @@ -2170,36 +2170,36 @@ def boxplot_from_statistics( - The boxplot will display whiskers, quartiles, and the mean. Outliers are controlled by the `showfliers` parameter. """ - + # Initialize the figure: if 'ax' is provided, use it, otherwise create # a new Plotly figure if ax: fig = ax else: fig = go.Figure() - + # Get unique features (markers) from the summary statistics unique_features = summary_stats["marker"].unique() - + # Create comma seperated list for features in the plot title # If there are >3 unique features, use 'Multiple Features' in the title if len(unique_features) < 4: plot_title = f"{', '.join(unique_features[0:])}" else: plot_title = 'Multiple Features' - + if annotation: unique_annotations = summary_stats[annotation].unique() - + plot_title += f" grouped by {annotation}" - + # Empty outlier lists cause issues with plotly, # so replace them with [None] if showfliers: summary_stats["fliers"] = summary_stats["fliers"].apply( lambda x: [None] if len(x) == 0 else x ) - + # Set up the orientation of the plot data & axis-labels if orient == "h": x_data = "fliers" @@ -2211,7 +2211,7 @@ def boxplot_from_statistics( y_data = "fliers" x_axis_label = annotation if annotation else "feature value" y_axis_label = "log(Intensity)" if log_scale else "Intensity" - + # If annotation is provided, group the data # and create boxplots for each group if annotation: 
@@ -2222,7 +2222,7 @@ def boxplot_from_statistics( grouped_data[annotation_value] = summary_stats[ summary_stats[annotation] == annotation_value ].to_dict(orient="list") - + # Add a boxplot trace for each annotation value for annotation_value, data in grouped_data.items(): if orient == "h": @@ -2231,7 +2231,7 @@ def boxplot_from_statistics( else: y = data[y_data] if showfliers else None x = data[x_data] - + fig.add_trace( go.Box( name=annotation_value, @@ -2259,14 +2259,14 @@ def boxplot_from_statistics( unique_annotations = unique_annotations[ unique_annotations != annotation_value ] - + # Adjust layout to group the boxplots by annotation fig.update_layout(boxmode="group") else: # If no annotation, create a boxplot # for each unique feature (marker) stats_dict = summary_stats.to_dict(orient="list") - + for i, marker_value in enumerate(stats_dict["marker"]): if orient == "h": y = [stats_dict[y_data][i]] @@ -2274,7 +2274,7 @@ def boxplot_from_statistics( else: y = [stats_dict[y_data][i], [None]] if showfliers else None x = [stats_dict[x_data][i]] - + # Note: adding None to the x or y data to ensure # the outliers are displayed correctly fig.add_trace( @@ -2298,7 +2298,7 @@ def boxplot_from_statistics( **kwargs ) ) - + # Final layout adjustments for the plot title, axis labels, and size fig.update_layout( title=plot_title, @@ -2307,13 +2307,13 @@ def boxplot_from_statistics( height=int(figure_height * figure_dpi), width=int(figure_width * figure_dpi), ) - + return fig - + ##################### # Main Code Block # ##################### - + logging.info("Calculating Box Plot...") if layer: check_table(adata, tables=layer) @@ -2321,55 +2321,55 @@ def boxplot_from_statistics( check_annotation(adata, annotations=annotation) if features: check_feature(adata, features=features) - + if ax and not isinstance(ax, plt.Figure): raise TypeError("Input 'ax' must be a plotly.Figure object.") - + if showfliers not in ("all", "downsample", None): raise ValueError( ("showfliers must 
be one of 'all', 'downsample', or None."), (f" Got {showfliers}."), ) - + if figure_type not in ("interactive", "static", "png"): raise ValueError( (f"figure_type must be one of 'interactive', 'static', or 'png'."), (f" Got {figure_type}."), ) - + # Extract data from the specified layer or the default matrix (adata.X) if layer: data_matrix = adata.layers[layer] else: data_matrix = adata.X - + # Convert the data matrix into a DataFrame with # appropriate column names (features) df = pd.DataFrame(data_matrix, columns=adata.var_names) - + # Add annotation column to the DataFrame if provided if annotation: df[annotation] = adata.obs[annotation].values - + # If no specific features are provided, use all available features if features is None: features = adata.var_names.tolist() - + # Filter the DataFrame to include only the # selected features and the annotation df = df[features + ([annotation] if annotation else [])] - + # Check for negative values if log scale is requested if log_scale and (df[features] < 0).any().any(): print( "There are negative values in this data, disabling the log scale." 
) log_scale = False - + # Apply log1p transformation if log_scale is True if log_scale: df[features] = np.log1p(df[features]) - + start_time = time.time() # Compute the summary statistics required for the boxplot metrics = compute_boxplot_metrics( @@ -2379,7 +2379,7 @@ def boxplot_from_statistics( "Time taken to compute boxplot metrics: %f seconds", time.time() - start_time ) - + # Get the colormap for the annotation if defined_color_map: cmap = get_defined_color_map(adata) @@ -2397,7 +2397,7 @@ def boxplot_from_statistics( color_map=feature_colorscale, return_dict=True, ) - + start_time = time.time() # Generate the boxplot figure from the summary statistics fig = boxplot_from_statistics( @@ -2413,7 +2413,7 @@ def boxplot_from_statistics( figure_dpi=figure_dpi, **kwargs, ) - + # Prepare the base image or figure return value if figure_type == "interactive": plot = fig @@ -2442,19 +2442,19 @@ def boxplot_from_statistics( 'legend_itemdoubleclick': False } plot = fig.update_layout(**config) - + logging.info( "Time taken to generate boxplot: %f seconds", time.time() - start_time ) - + result = {"fig": plot, "df": df} # Determine if metrics included based on return_metrics flag if return_metrics: result["metrics"] = metrics - + return result - + def interactive_spatial_plot( adata, @@ -2476,11 +2476,11 @@ def interactive_spatial_plot( cmax=None, **kwargs ): - + """ Create an interactive scatter plot for spatial data using provided annotations. - + Parameters ---------- adata : AnnData @@ -2532,55 +2532,55 @@ def interactive_spatial_plot( Default is None. **kwargs Additional keyword arguments for customization. - + Returns ------- list of dict A list of dictionaries, each containing the following keys: - "image_name": str, the name of the generated image. - "image_object": Plotly Figure object. - + Notes ----- This function is tailored for spatial single-cell data and expects the AnnData object to have spatial coordinates in its `.obsm` attribute under the 'spatial' key. 
""" - + if annotations is None and feature is None: raise ValueError( "At least one of the 'annotations' or 'feature' parameters " + \ "must be provided." ) - + if annotations is not None: if not isinstance(annotations, list): annotations = [annotations] - + for annotation in annotations: check_annotation( adata, annotations=annotation ) - + if feature is not None: check_feature( adata, features=feature ) - + if layer is not None: check_table( adata, tables=layer ) - + check_table( adata, tables='spatial', associated_table=True ) - + def prepare_spatial_dataframe( adata, annotations=None, @@ -2588,13 +2588,13 @@ def prepare_spatial_dataframe( layer=None): """ Prepare a DataFrame for spatial plotting from an AnnData object. - + If 'annotations' is provided (a string or list of strings), the returned DataFrame will contain the X,Y coordinates and one column per annotation. If 'feature' is provided (and annotations is None), a single 'color' column is created from adata.layers[layer] (if provided) or adata.X. - + Parameters ---------- adata : anndata.AnnData @@ -2605,13 +2605,13 @@ def prepare_spatial_dataframe( Continuous feature name in adata.var_names for coloring. layer : str, optional Layer to use for feature values if feature is provided. - + Returns ------- df : pandas.DataFrame DataFrame with columns 'X', 'Y' and each annotation column (or a 'color' column for continuous feature). 
- + Raises ------ ValueError @@ -2621,7 +2621,7 @@ def prepare_spatial_dataframe( xcoord = [coord[0] for coord in spatial] ycoord = [coord[1] for coord in spatial] df = pd.DataFrame({'X': xcoord, 'Y': ycoord}) - + if annotations is not None: if isinstance(annotations, str): annotations = [annotations] @@ -2635,7 +2635,7 @@ def prepare_spatial_dataframe( raise ValueError( "Either 'annotations' or 'feature' must be provided.") return df - + def main_figure_generation( spatial_df, annotations=None, @@ -2656,7 +2656,7 @@ def main_figure_generation( This function generates the main interactive plot using Plotly that contains the spatial scatter plot with annotations and image configuration. - + Parameters ---------- spatial_df : pandas.DataFrame @@ -2685,30 +2685,30 @@ def main_figure_generation( Font size for text in the plot. Default is 12. title : str, optional Title of the image. Default is "interactive_spatial_plot". - + Returns ------- plotly.graph_objs._figure.Figure The generated interactive Plotly figure. """ - + xcoord = spatial_df['X'] ycoord = spatial_df['Y'] - + min_x, max_x = min(xcoord), max(xcoord) min_y, max_y = min(ycoord), max(ycoord) dx = max_x - min_x - + dy = max_y - min_y - + min_x_range = min_x - 0.05 * dx max_x_range = max_x + 0.05 * dx min_y_range = min_y - 0.05 * dy max_y_range = max_y + 0.05 * dy - + width_px = int(figure_width * figure_dpi) height_px = int(figure_height * figure_dpi) - + # Define partial for scatter traces with common parameters scatter_partial = partial( px.scatter, @@ -2717,7 +2717,7 @@ def main_figure_generation( render_mode="webgl", **kwargs ) - + # Helper function to create a scatter trace for features # as it needs a continuous color scale. 
# in my experience, px.scatter does not work well with @@ -2741,11 +2741,11 @@ def create_scatter_trace(df, feature, colorscale): text=df[feature], **kwargs ) - + # The annotation trace creates a dummy point # so that the label of that annotion is shown in the legend def create_annotation_trace(filtered, obs): - + # add one extra point just close to the first point trace = px.scatter( x=[filtered['X'].iloc[0]-0.1], @@ -2764,14 +2764,14 @@ def create_annotation_trace(filtered, obs): name=f'{obs}' ) return trace - + main_fig = go.Figure() - + if annotations is not None: # Loop over all annotation and add annotation dummy point # and data points to the figure for obs in annotations: - + spatial_df[obs].fillna("no_label", inplace=True) filtered = spatial_df # Create and add annotation trace using the helper function @@ -2785,19 +2785,19 @@ def create_annotation_trace(filtered, obs): hover_data=[obs], color_discrete_map=color_mapping, ).data) - + elif feature is not None: - + main_fig.add_trace( create_scatter_trace(spatial_df, feature, colorscale) ) - + else: raise ValueError( "No plot is generated." " Either 'annotations' or 'feature' must be provided." 
) - + if annotations is not None: # Set the hover template to show x, y and annotation # This is needed to show the correct label when @@ -2806,7 +2806,7 @@ def create_annotation_trace(filtered, obs): elif feature is not None: # it is already set in the create_scatter_trace function hovertemplate = None - + main_fig.update_traces( mode='markers', marker=dict( @@ -2815,7 +2815,7 @@ def create_annotation_trace(filtered, obs): ), hovertemplate=hovertemplate ) - + main_fig.update_layout( width=width_px, height=height_px, @@ -2872,21 +2872,21 @@ def create_annotation_trace(filtered, obs): }, margin=dict(l=5, r=5, t=font_size*2, b=5) ) - + if reverse_y_axis: main_fig.update_layout(yaxis=dict(autorange="reversed")) - + return { "image_name": f"{spell_out_special_characters(title)}.html", "image_object": main_fig } - + ##################### # Main Code Block ## ##################### - + from functools import partial - + # Set the discrete or continuous color parameters color_dict = None colorscale = None @@ -2905,7 +2905,7 @@ def create_annotation_trace(filtered, obs): f'Colored by "{feature}", ' f'table: "{layer if layer else "Original"}"' ) - + # Create the partial function with the common keyword arguments directly plot_main = partial( main_figure_generation, @@ -2921,18 +2921,18 @@ def create_annotation_trace(filtered, obs): font_size=font_size, **kwargs ) - + results = [] - + if stratify_by is not None: # Check if the stratification column exists in the data check_annotation(adata, annotations=stratify_by) unique_stratification_values = adata.obs[stratify_by].unique() - + for strat_value in unique_stratification_values: condition = adata.obs[stratify_by] == strat_value title_str = f"Subsetting {stratify_by}: {strat_value}" - + indices = np.where(condition)[0] print(f"number of cells in the region: {len(adata.obsm['spatial'][indices])}") adata_subset = select_values( @@ -2940,7 +2940,7 @@ def create_annotation_trace(filtered, obs): annotation=stratify_by, 
values=strat_value ) - + spatial_df = prepare_spatial_dataframe( adata_subset, annotations=annotations, @@ -2963,16 +2963,16 @@ def create_annotation_trace(filtered, obs): feature=feature, layer=layer ) - + # For non-stratified case, pass extra parameters if needed result = plot_main( spatial_df, title=title_str ) results.append(result) - + return results - + def sankey_plot( adata: anndata.AnnData, @@ -2989,7 +2989,7 @@ def sankey_plot( source annotation, and tab20c for target annotation. For more information on colormaps, see: https://matplotlib.org/stable/users/explain/colors/colormaps.html - + Parameters ---------- adata : anndata.AnnData @@ -3007,13 +3007,13 @@ def sankey_plot( prefix : bool, optional Whether to prefix the target labels with the source labels. Defaults to True. - + Returns ------- plotly.graph_objs._figure.Figure The generated Sankey plot. """ - + label_relations = annotation_category_relations( adata=adata, source_annotation=source_annotation, @@ -3024,11 +3024,11 @@ def sankey_plot( source_labels = label_relations["source"].unique().tolist() target_labels = label_relations["target"].unique().tolist() all_labels = source_labels + target_labels - + source_label_colors = color_mapping(source_labels, source_color_map) target_label_colors = color_mapping(target_labels, target_color_map) label_colors = source_label_colors + target_label_colors - + # Create a dictionary to map labels to indices label_to_index = { label: index for index, label in enumerate(all_labels)} @@ -3041,7 +3041,7 @@ def sankey_plot( target_indices = [] values = [] link_colors = [] - + # For each row in label_relations, add the source index, target index, # and count to the respective lists for _, row in label_relations.iterrows(): @@ -3049,7 +3049,7 @@ def sankey_plot( target_indices.append(label_to_index[row['target']]) values.append(row['count']) link_colors.append(color_to_map[row['source']]) - + # Generate Sankey diagram # Calculate the x-coordinate for each label fig = 
go.Figure(go.Sankey( @@ -3073,7 +3073,7 @@ def sankey_plot( size=sankey_font ) )) - + fig.data[0].link.customdata = label_relations[ ['percentage_source', 'percentage_target'] ] @@ -3083,7 +3083,7 @@ def sankey_plot( 'Count: %{value}' ) fig.data[0].link.hovertemplate = hovertemplate - + # Customize the Sankey diagram layout fig.update_layout( title_text=( @@ -3096,15 +3096,15 @@ def sankey_plot( color="black" # Set the title font color ) ) - + fig.update_layout(margin=dict( l=10, r=10, t=sankey_font * 3, b=sankey_font)) - + return fig - + def relational_heatmap( adata: anndata.AnnData, @@ -3118,7 +3118,7 @@ def relational_heatmap( The color map refers to matplotlib color maps, default is mint. For more information on colormaps, see: https://matplotlib.org/stable/users/explain/colors/colormaps.html - + Parameters ---------- adata : anndata.AnnData @@ -3131,7 +3131,7 @@ def relational_heatmap( The color map to use for the relational heatmap. Default is mint. **kwargs : dict, optional Additional keyword arguments. For example, you can pass font_size=12.0. 
- + Returns ------- dict @@ -3150,41 +3150,41 @@ def relational_heatmap( # Default font size font_size = kwargs.get('font_size', 12.0) prefix = kwargs.get('prefix', True) - + # Get the relationship between source and target annotations - + label_relations = annotation_category_relations( adata=adata, source_annotation=source_annotation, target_annotation=target_annotation, prefix=prefix ) - + # Pivot the data to create a matrix for the heatmap heatmap_matrix = label_relations.pivot( index='source', columns='target', values='percentage_source' ) - + heatmap_matrix = heatmap_matrix.fillna(0) - + x = list(heatmap_matrix.columns) y = list(heatmap_matrix.index) - + # Create text labels for the heatmap label_relations['text_label'] = [ '{}%'.format(val) for val in label_relations["percentage_source"] ] - + heatmap_matrix2 = label_relations.pivot( index='source', columns='target', values='percentage_source' ) - + heatmap_matrix2 = heatmap_matrix2.fillna(0) - + hover_template = 'Source: %{z}%
Target: %{customdata}%' # Ensure alignment of the text data with the heatmap matrix z = list() @@ -3203,7 +3203,7 @@ def relational_heatmap( 0 if len(z_data_point) == 0 else z_data_point.iloc[0] ) z.append([_ for _ in iter_list]) - + # Create heatmap fig = ff.create_annotated_heatmap( z=z, @@ -3211,7 +3211,7 @@ def relational_heatmap( customdata=heatmap_matrix2.values, hovertemplate=hover_template ) - + fig.update_layout( overwrite=True, xaxis=dict( @@ -3238,15 +3238,15 @@ def relational_heatmap( b=font_size * 2 ) ) - + for i in range(len(fig.layout.annotations)): fig.layout.annotations[i].font.size = font_size - + fig.update_xaxes( side="bottom", tickangle=90 ) - + # Data output section data = fig.data[0] layout = fig.layout @@ -3256,13 +3256,13 @@ def relational_heatmap( matrix.columns=layout['xaxis']['ticktext'] matrix["total"] = matrix.sum(axis=1) matrix = matrix.fillna(0) - + # Display the DataFrame file_name = f"{source_annotation}_to_{target_annotation}" + \ "_relation_matrix.csv" - + return {"figure": fig, "file_name": file_name, "data": matrix} - + def plot_ripley_l( adata, @@ -3274,7 +3274,7 @@ def plot_ripley_l( """ Plot Ripley's L statistic for multiple bins and different regions for a given pair of phenotypes. - + Parameters ---------- adata : AnnData @@ -3290,19 +3290,19 @@ def plot_ripley_l( Whether to return the DataFrame containing the Ripley's L results. kwargs : dict, optional Additional keyword arguments to pass to `seaborn.lineplot`. - + Raises ------ ValueError If the Ripley L results are not found in `adata.uns['ripley_l']`. - + Returns ------- ax : matplotlib.axes.Axes The Axes object containing the plot, which can be further modified. df : pandas.DataFrame, optional The DataFrame containing the Ripley's L results, if `return_df` is True. - + Example ------- >>> ax = plot_ripley_l( @@ -3310,24 +3310,24 @@ def plot_ripley_l( ... phenotypes=('Phenotype1', 'Phenotype2'), ... 
regions=['region1', 'region2']) >>> plt.show() - + This returns the `Axes` object for further customization and displays the plot. """ - + # Retrieve the results from adata.uns['ripley_l'] ripley_results = adata.uns.get('ripley_l') - + if ripley_results is None: raise ValueError( "Ripley L results not found in the analsyis." ) - + # Filter the results for the specific pair of phenotypes filtered_results = ripley_results[ (ripley_results['center_phenotype'] == phenotypes[0]) & (ripley_results['neighbor_phenotype'] == phenotypes[1]) ] - + if filtered_results.empty: # Generate all unique combinations of phenotype pairs unique_pairs = ripley_results[ @@ -3338,12 +3338,12 @@ def plot_ripley_l( f'\nNeighbor Phenotype: "{phenotypes[1]}"' f"\nExisiting unique pairs: {unique_pairs}" ) - + # If specific regions are provided, filter them, otherwise plot all regions if regions is not None: filtered_results = filtered_results[ filtered_results['region'].isin(regions)] - + # Check if the results are emply after subsetting the regions if filtered_results.empty: available_regions = ripley_results['region'].unique() @@ -3351,16 +3351,16 @@ def plot_ripley_l( f"No data available for the specified regions: {regions}. " f"Available regions: {available_regions}." 
) - + # Create a figure and axes fig, ax = plt.subplots(figsize=(10, 10)) - + plot_data = [] - + # Plot Ripley's L for each region for _, row in filtered_results.iterrows(): region = row['region'] # Region label - + if row['ripley_l'] is None: message = ( f"Ripley L results not found for region: {region}" @@ -3383,18 +3383,18 @@ def plot_ripley_l( label=f'{region}: {n_cells}, {int(area)}', ax=ax, **kwargs) - + # Calculate averages for simulations if enabled if sims: sims_stat_df = row["ripley_l"]["sims_stat"] avg_stats = sims_stat_df.groupby("bins")["stats"].mean() avg_used_center_cells = \ sims_stat_df.groupby("bins")["used_center_cells"].mean() - + # Prepare plotted data to return if return_df is True l_stat_data = row['ripley_l']['L_stat'] for _, stat_row in l_stat_data.iterrows(): - + entry = { 'region': region, 'radius': stat_row['bins'], @@ -3404,15 +3404,15 @@ def plot_ripley_l( 'n_neighbor': n_neighbors, 'used_center_cells': stat_row['used_center_cells'] } - + if sims: entry['avg_sim_ripley(radius)'] = \ avg_stats.get(stat_row['bins'], None) entry['avg_sim_used_center_cells'] = \ avg_used_center_cells.get(stat_row['bins'], None) - + plot_data.append(entry) - + if sims: confidence_level = 95 errorbar = ("pi", confidence_level) @@ -3425,7 +3425,7 @@ def plot_ripley_l( label=f"Simulations({region}):{n_sims} runs", **kwargs ) - + # Set labels, title, and grid ax.set_title( "Ripley's L Statistic for phenotypes:" @@ -3433,17 +3433,17 @@ def plot_ripley_l( ) ax.legend(title='Regions:(center, neighbor), area', loc='upper left') ax.grid(True) - + # Set the horizontal axis lable ax.set_xlabel("Radii (pixels)") ax.set_ylabel("Ripley's L Statistic") - + if return_df: df = pd.DataFrame(plot_data) return fig, df - + return fig - + def _prepare_spatial_distance_data( adata, @@ -3456,7 +3456,7 @@ def _prepare_spatial_distance_data( ): """ Prepares a tidy DataFrame for nearest-neighbor (spatial distance) plotting. 
- + This function: 1) Validates required parameters (annotation, distance_from). 2) Retrieves the spatial distance matrix from @@ -3468,9 +3468,9 @@ def _prepare_spatial_distance_data( 6) Reshapes (melts) into long-form data: columns -> [cellid, group, distance]. 7) Applies optional log1p transform. - + The resulting DataFrame is suitable for plotting with tool like Seaborn. - + Parameters ---------- adata : anndata.AnnData @@ -3491,7 +3491,7 @@ def _prepare_spatial_distance_data( log : bool, optional If True, applies np.log1p transform to the 'distance' column, which is renamed to 'log_distance'. - + Returns ------- pd.DataFrame @@ -3501,14 +3501,14 @@ def _prepare_spatial_distance_data( - 'distance': the numeric distance value. - 'phenotype': the reference phenotype ('distance_from'). - 'stratify_by': optional grouping column, if provided. - + Raises ------ ValueError If required parameters are missing, if phenotypes are not found in `adata.obs`, or if the spatial distance matrix is not available in `adata.obsm`. - + Examples -------- >>> df_long = _prepare_spatial_distance_data( @@ -3522,7 +3522,7 @@ def _prepare_spatial_distance_data( ... ) >>> df_long.head() """ - + # Validate required parameters if distance_from is None: raise ValueError( @@ -3530,15 +3530,15 @@ def _prepare_spatial_distance_data( "the reference group from which distances are measured." 
) check_annotation(adata, annotations=annotation) - + # Convert distance_to to list if needed if distance_to is not None and isinstance(distance_to, str): distance_to = [distance_to] - + phenotypes_to_check = [distance_from] + ( distance_to if distance_to else [] ) - + # Ensure distance_from and distance_to exist in adata.obs[annotation] check_label( adata, @@ -3546,7 +3546,7 @@ def _prepare_spatial_distance_data( labels=phenotypes_to_check, should_exist=True ) - + # Retrieve the spatial distance matrix from adata.obsm if spatial_distance not in adata.obsm: raise ValueError( @@ -3556,7 +3556,7 @@ def _prepare_spatial_distance_data( f"Available keys: {list(adata.obsm.keys())}" ) distance_map = adata.obsm[spatial_distance].copy() - + # Verify requested phenotypes exist in the distance_map columns missing_cols = [ p for p in phenotypes_to_check if p not in distance_map.columns @@ -3567,17 +3567,17 @@ def _prepare_spatial_distance_data( f"'{spatial_distance}'. Columns present: " f"{list(distance_map.columns)}" ) - + # Validate 'stratify_by' column if provided if stratify_by is not None: check_annotation(adata, annotations=stratify_by) - + # Build a meta DataFrame with phenotype & optional stratify column meta_data = pd.DataFrame({'phenotype': adata.obs[annotation]}, index=adata.obs.index) if stratify_by: meta_data[stratify_by] = adata.obs[stratify_by] - + # Merge metadata with distance_map and filter for 'distance_from' df_merged = meta_data.join(distance_map, how='left') df_merged = df_merged[df_merged['phenotype'] == distance_from] @@ -3585,15 +3585,15 @@ def _prepare_spatial_distance_data( raise ValueError( f"No cells found with phenotype == '{distance_from}'." 
) - + # Reset index to ensure cell names are in a column called 'cellid' df_merged = df_merged.reset_index().rename(columns={'index': 'cellid'}) - + # Prepare the list of metadata columns meta_cols = ['phenotype'] if stratify_by: meta_cols.append(stratify_by) - + # Determine distance columns if distance_to: keep_cols = ['cellid'] + meta_cols + distance_to @@ -3605,41 +3605,41 @@ def _prepare_spatial_distance_data( c for c in df_merged.columns if c not in non_distance_cols ] keep_cols = ['cellid'] + meta_cols + distance_columns - + df_merged = df_merged[keep_cols] - + # Melt the DataFrame from wide to long format df_long = df_merged.melt( id_vars=['cellid'] + meta_cols, var_name='group', value_name='distance' ) - + # Convert columns to categorical for consistency for col in ['group', 'phenotype', stratify_by]: if col and col in df_long.columns: df_long[col] = df_long[col].astype(str).astype('category') - + # Reorder categories for 'group' if 'distance_to' is provided if distance_to: df_long['group'] = df_long['group'].cat.reorder_categories(distance_to) df_long.sort_values('group', inplace=True) - + # Ensure 'distance' is numeric and apply log transform if requested df_long['distance'] = pd.to_numeric(df_long['distance'], errors='coerce') if log: df_long['distance'] = np.log1p(df_long['distance']) df_long.rename(columns={'distance': 'log_distance'}, inplace=True) - + # Reorder columns dynamically based on the presence of 'log' distance_col = 'log_distance' if log else 'distance' final_cols = ['cellid', 'group', distance_col, 'phenotype'] if stratify_by is not None: final_cols.append(stratify_by) df_long = df_long[final_cols] - + return df_long - + def _plot_spatial_distance_dispatch( df_long, @@ -3655,7 +3655,7 @@ def _plot_spatial_distance_dispatch( """ Dispatch a seaborn call to visualise nearest-neighbor distances. Returns Axes object(s) for further customization. - + Layout logic ------------ 1. 
``stratify_by`` & ``facet_plot`` → Faceted plot, returns ``Axes`` @@ -3664,7 +3664,7 @@ def _plot_spatial_distance_dispatch( ``List[Axes]`` for the "ax" key. 3. ``stratify_by`` is None → Single plot, returns ``Axes`` or ``List[Axes]`` (if plot_type creates facets) for the "ax" key. - + Parameters ---------- df_long : pd.DataFrame @@ -3698,7 +3698,7 @@ def _plot_spatial_distance_dispatch( **kwargs Extra keyword args propagated to Seaborn. Legend control (e.g. `legend=False`) should be passed here if needed. - + Returns ------- dict @@ -3707,10 +3707,10 @@ def _plot_spatial_distance_dispatch( 'ax' : matplotlib.axes.Axes | list[Axes] } """ - + if method not in ("numeric", "distribution"): raise ValueError("`method` must be 'numeric' or 'distribution'.") - + # Choose plotting function if method == "numeric": _plot_base = partial( @@ -3731,29 +3731,29 @@ def _plot_spatial_distance_dispatch( kind=plot_type, palette=palette, ) - + # Single plotting wrapper to create Axes object(s) def _make_axes_object(_data, **kws_plot): g = _plot_base(data=_data, **kws_plot) - + axis_label = ( "Log(Nearest Neighbor Distance)" if "log" in distance_col else "Nearest Neighbor Distance" ) - + g.set_axis_labels(axis_label, None) - + if g.axes.size == 1: returned_ax = g.ax else: returned_ax = g.axes.flatten().tolist() - + return returned_ax - + # Build axes final_axes_object = None - + if stratify_by and facet_plot: final_axes_object = _make_axes_object( df_long, col=stratify_by, **kwargs @@ -3772,9 +3772,9 @@ def _make_axes_object(_data, **kws_plot): final_axes_object = list_of_all_axes else: final_axes_object = _make_axes_object(df_long, **kwargs) - + return {"data": df_long, "ax": final_axes_object} - + # Build a master HEX palette and cache it inside the AnnData object # ----------------------------------------------------------------------------- @@ -3794,7 +3794,7 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): """ Normalise a CSS-style color string to a hexadecimal value or a 
valid Matplotlib color name. - + Parameters ---------- col : str @@ -3803,22 +3803,22 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): * 'rgb(r,g,b)' or 'rgba(r,g,b,a)', where r, g, b are 0-255 and a is 0-1 or 0-255 * any named Matplotlib color - + keep_alpha : bool, optional If True and the input includes alpha, return an 8-digit hex; otherwise drop the alpha channel. Default is False. - + Returns ------- str * Lower-case colour name or * 6- or 8-digit lower-case hex. - + Raises ------ ValueError If the color cannot be interpreted. - + Examples -------- >>> _css_rgb_or_hex_to_hex('gold') @@ -3828,9 +3828,9 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): >>> _css_rgb_or_hex_to_hex('rgba(255,0,0,0.5)', keep_alpha=True) '#ff000080' """ - + col = col.strip().lower() - + # Compile the rgb()/rgba() matcher locally to satisfy style request. rgb_re = re.compile( r'rgba?\s*\(' @@ -3841,11 +3841,11 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): r'\s*\)', re.I, ) - + # 1. direct hex if col.startswith('#'): return mcolors.to_hex(col, keep_alpha=keep_alpha).lower() - + # 2. rgb()/rgba() match = rgb_re.fullmatch(col) if match: @@ -3862,14 +3862,14 @@ def _css_rgb_or_hex_to_hex(col, keep_alpha=False): a_val /= 255 rgba.append(a_val) return mcolors.to_hex(rgba, keep_alpha=keep_alpha).lower() - + # 3. named color if col in mcolors.get_named_colors_mapping(): return col # let Matplotlib handle named colors - + # 4. unsupported format raise ValueError(f'Unsupported color format: "{col}"') - + # Helper function (can be defined at module level) def _ordered_unique_figs(axes_list: list): @@ -3883,7 +3883,7 @@ def _ordered_unique_figs(axes_list: list): if fig is not None: seen.setdefault(fig, None) return list(seen) - + def visualize_nearest_neighbor( adata, @@ -3904,21 +3904,21 @@ def visualize_nearest_neighbor( """ Visualize nearest-neighbor (spatial distance) data between groups of cells with optional pin-color map via numeric or distribution plots. 
- + This landing function first constructs a tidy long-form DataFrame via function `_prepare_spatial_distance_data`, then dispatches plotting to function `_plot_spatial_distance_dispatch`. A pin-color feature guarantees consistent mapping from annotation labels to colors across figures, drawing the mapping from ``adata.uns`` (if present) or generating one automatically through `spac.utils.color_mapping`. - + Plot arrangement logic: 1) If stratify_by is not None and facet_plot=True => single figure with subplots (faceted). 2) If stratify_by is not None and facet_plot=False => multiple separate figures, one per group. 3) If stratify_by is None => a single figure with one plot. - + Parameters ---------- adata : anndata.AnnData @@ -3959,7 +3959,7 @@ def visualize_nearest_neighbor( Only works if plotting a single component. Default is None. **kwargs : dict Additional arguments for seaborn figure-level functions. - + Returns ------- dict @@ -3969,12 +3969,12 @@ def visualize_nearest_neighbor( 'ax': matplotlib.axes.Axes | list[matplotlib.axes.Axes], 'palette': dict # {label: '#rrggbb'} } - + Raises ------ ValueError If required parameters are invalid. - + Examples -------- >>> res = visualize_nearest_neighbor( @@ -3993,19 +3993,19 @@ def visualize_nearest_neighbor( >>> df = res['data'] # long-form DataFrame >>> ax_list[0].set_title('Tumour → Stroma distances') """ - + if method not in ['numeric', 'distribution']: raise ValueError( "Invalid 'method'. Please choose 'numeric' or 'distribution'." 
) - + # Determine plot_type if not provided if plot_type is None: plot_type = 'boxen' if method == 'numeric' else 'kde' - + # If log=True, the column name is 'log_distance', else 'distance' distance_col = 'log_distance' if log else 'distance' - + # Build/fetch color palette color_dict_rgb = get_defined_color_map( adata=adata, @@ -4013,14 +4013,14 @@ def visualize_nearest_neighbor( annotations=annotation, colorscale=annotation_colorscale ) - + palette_hex = { k: _css_rgb_or_hex_to_hex(v) for k, v in color_dict_rgb.items() } adata.uns.setdefault('_spac_palettes', {})[ f"{defined_color_map or annotation}_hex" ] = palette_hex - + # Reshape data df_long = _prepare_spatial_distance_data( adata=adata, @@ -4031,7 +4031,7 @@ def visualize_nearest_neighbor( distance_to=distance_to, log=log ) - + # Filter the full palette to include only the target groups present in # df_long['group']. These are the groups that will actually be used for hue # in the plot. @@ -4047,13 +4047,13 @@ def visualize_nearest_neighbor( # plot. For each label we look up its HEX code in ``palette_hex``; if a # colour exists we copy the mapping into the new dictionary. target_groups_in_plot = df_long['group'].astype(str).unique() - + plot_specific_palette = { str(group): palette_hex.get(str(group)) for group in target_groups_in_plot if palette_hex.get(str(group)) is not None } - + # Assemble kwargs & dispatch # Inject the palette into the plotting dispatcher # ----------------------------------------------------------------------------- @@ -4068,7 +4068,7 @@ def visualize_nearest_neighbor( # HOW ``dispatch_kwargs`` starts as a copy of any user‑supplied kwargs; the # call to ``update`` adds these palette‑related keys before control is # handed off to the generic plotting helper. 
- + dispatch_kwargs = dict(kwargs) dispatch_kwargs.update({ 'hue_axis': 'group', @@ -4076,11 +4076,11 @@ def visualize_nearest_neighbor( }) if method == 'numeric': dispatch_kwargs.setdefault('saturation', 1.0) - + # Set legend=False to allow for custom legend creation by the caller # The user can still override this by passing legend=True in kwargs dispatch_kwargs.setdefault('legend', False) - + disp = _plot_spatial_distance_dispatch( df_long=df_long, method=method, @@ -4090,15 +4090,15 @@ def visualize_nearest_neighbor( distance_col=distance_col, **dispatch_kwargs ) - + returned_axes = disp['ax'] fig_object = None # Initialize - + if isinstance(returned_axes, list): if returned_axes: # Unique figures, preserved in axis order unique_figs_ordered = _ordered_unique_figs(returned_axes) - + if unique_figs_ordered: # at least one valid figure if stratify_by and not facet_plot: # one figure per category → return the ordered list @@ -4120,35 +4120,35 @@ def visualize_nearest_neighbor( # single Axes → grab its figure fig_object = getattr(returned_axes, 'figure', None) # returned_axes is None → fig_object stays None - + return { 'data': disp['data'], 'fig': fig_object, 'ax': disp['ax'], 'palette': plot_specific_palette # Return the filtered palette } - + import json import plotly.graph_objects as go - + def present_summary_as_html(summary_dict: dict) -> str: """ Build an HTML string that presents the summary information intuitively. - + For each specified column, the HTML includes: - Column name and data type - Count and list of missing indices - Summary details presented in a table (for numeric: stats; categorical: unique values and counts) - + Parameters ---------- summary_dict : dict The summary dictionary returned by summarize_dataframe. - + Returns ------- str @@ -4167,7 +4167,7 @@ def present_summary_as_html(summary_dict: dict) -> str: "" "

Data Summary

" ) - + for col, info in summary_dict.items(): html += ( f"

Column: {col}

" @@ -4182,28 +4182,28 @@ def present_summary_as_html(summary_dict: dict) -> str: for key, val in info['summary'].items(): html += f"{key}{val}" html += "
" - + html += "" return html - + def present_summary_as_figure(summary_dict: dict) -> go.Figure: """ Build a static Plotly figure (using a table) to depict the summary dictionary. - + The figure includes columns: - Column name - Data type - Count of missing values - Missing indices (as a string) - Summary details (formatted as JSON for readability) - + Parameters ---------- summary_dict : dict The summary dictionary returned from summarize_dataframe. - + Returns ------- plotly.graph_objects.Figure @@ -4214,13 +4214,13 @@ def present_summary_as_figure(summary_dict: dict) -> go.Figure: missing_counts = [] missing_indices = [] summaries = [] - + for col, info in summary_dict.items(): col_names.append(col) data_types.append(info['data_type']) missing_counts.append(info['count_missing_indices']) missing_indices.append(str(info['missing_indices'])) - + # need to convert nmpy int64 and float64 to native int and float # so that I can dump them as json clean_data = {} @@ -4234,9 +4234,9 @@ def present_summary_as_figure(summary_dict: dict) -> go.Figure: else: # Keep the value as is if it's already a standard type clean_data[k] = v - + summaries.append(json.dumps(clean_data, indent=2)) - + fig = go.Figure( data=[go.Table( header=dict( diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 630692eb..f3705d6e 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -99,8 +99,8 @@ def test_histogram_feature(self): bin_edges = list(np.linspace(0.5, 100.5, 101)) fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', bins=bin_edges ).values() @@ -125,7 +125,7 @@ def test_histogram_feature(self): def test_histogram_annotation(self): fig, ax, df = histogram( - self.adata, + self.adata, annotation='annotation1' ).values() total_annotation = len(self.adata.obs['annotation1']) @@ -193,8 +193,8 @@ def test_x_log_scale_transformation(self): 
].flatten() fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', x_log_scale=True ).values() @@ -220,8 +220,8 @@ def test_negative_values_x_log_scale(self, mock_print): self.adata.X[0, self.adata.var_names.get_loc('marker1')] = -1 fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', x_log_scale=True ).values() @@ -238,21 +238,21 @@ def test_negative_values_x_log_scale(self, mock_print): def test_title(self): """Test that title changes based on 'layer' information""" fig, ax, df = histogram( - self.adata, + self.adata, feature='marker1' ).values() self.assertEqual(ax.get_title(), 'Layer: Original') fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', layer='Default' ).values() self.assertEqual(ax.get_title(), f'Layer: Default') fig, ax, df = histogram( - self.adata, - annotation='annotation1', + self.adata, + annotation='annotation1', layer='Default' ).values() self.assertEqual(ax.get_title(), '') @@ -260,8 +260,8 @@ def test_title(self): def test_y_log_scale_axis(self): """Test that y_log_scale sets y-axis to log scale.""" fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', y_log_scale=True ).values() self.assertEqual(ax.get_yscale(), 'log') @@ -269,8 +269,8 @@ def test_y_log_scale_axis(self): def test_y_log_scale_label(self): """Test that y-axis label is updated when y_log_scale is True.""" fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', y_log_scale=True ).values() self.assertEqual(ax.get_ylabel(), 'log(Count)') @@ -279,31 +279,31 @@ def test_y_axis_label_based_on_stat(self): """Test that y-axis label changes based on the 'stat' parameter.""" # Test default stat ('count') fig, ax, df = histogram( - self.adata, + self.adata, feature='marker1' ).values() self.assertEqual(ax.get_ylabel(), 'Count') # Test 'frequency' stat fig, ax, df = histogram( - 
self.adata, - feature='marker1', + self.adata, + feature='marker1', stat='frequency' ).values() self.assertEqual(ax.get_ylabel(), 'Frequency') # Test 'density' stat fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', stat='density' ).values() self.assertEqual(ax.get_ylabel(), 'Density') # Test 'probability' stat fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', stat='probability' ).values() self.assertEqual(ax.get_ylabel(), 'Probability') @@ -545,8 +545,8 @@ def test_histogram_feature_integer_bins(self): custom_bins = 10 # Specify number of bins as an integer fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', bins=custom_bins ).values() From 7e16c5932e4088ab0b153bab7af3bdb939c76d77 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Wed, 22 Apr 2026 22:55:34 -0400 Subject: [PATCH 52/57] fix(histogram): gate facet hints by facet mode Addressed comment https://github.com/FNLCR-DMAP/SCSAWorkflow/pull/428/changes#r3121891276 - Ignore facet-only layout hints when facet=False in both histogram() and the histogram template - Tighten template figure-size validation to prevent zero value from bypassing - Update histogram tests for ignored layout hints in non-facet path --- src/spac/templates/histogram_template.py | 49 +++++++-------- src/spac/visualization.py | 46 +++++++------- tests/test_visualization/test_histogram.py | 70 ++++++++++++---------- 3 files changed, 88 insertions(+), 77 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 50122c78..549ff93c 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -189,10 +189,9 @@ def run_from_json( multiple = str(multiple).strip().lower() element = str(element).strip().lower() stat = str(stat).strip().lower() - - # Figure size hints use the explicit template token "auto". 
In facet mode - # it is forwarded as None so core geometry can derive the final figure - # size; in non-facet mode it falls back to 8x6 inches. + # Figure_Width and Figure_Height use "auto" for template defaults. + # In facet mode, it is forwarded as None to derive layout geometry. + # In non-facet mode, it falls back to 8x6 inches. fig_width = text_to_value( fig_width, default_none_text="auto", @@ -207,12 +206,16 @@ def run_from_json( to_float=True, param_name="Figure_Height" ) - if fig_width and fig_height: + if fig_width is not None and fig_height is not None: if fig_width <= 0 or fig_height <= 0: raise ValueError( f'Figure_Width/Height should be a positive number.' f'Received "{fig_width}"/"{fig_height}".' ) + if fig_dpi <= 0: + raise ValueError( + f'Figure_DPI should be a positive number. Received "{fig_dpi}".' + ) # Validate x-axis label rotation if (x_rotate < 0) or (x_rotate > 360): @@ -232,7 +235,8 @@ def run_from_json( param_name="Max_Groups", ) - # Validate facet, group_by, and together parameters for logical consistency + # Facet requires Group_by and forbids Together=True. + # Facet_Ncol accepts "auto" or a positive integer. if facet: if group_by is None: raise ValueError( @@ -243,24 +247,18 @@ def run_from_json( raise ValueError( 'Together and Facet cannot both be True. Please set one to False.' ) - - # facet_ncol uses the explicit template token "auto", or a positive int. - facet_ncol = text_to_value( - facet_ncol, - default_none_text="auto", - value_to_convert_to=None - ) - if facet_ncol is not None: facet_ncol = text_to_value( facet_ncol, + default_none_text="auto", to_int=True, param_name="Facet_Ncol" ) - if facet_ncol <= 0: - raise ValueError( - f'Facet_Ncol must be a positive integer or "auto". ' - f'Received "{facet_ncol}".' - ) + if facet_ncol is not None: + if facet_ncol <= 0: + raise ValueError( + f'Facet_Ncol must be a positive integer or "auto". ' + f'Received "{facet_ncol}".' 
+ ) # Initialize the x-variable before the loop if histplot_by == "Annotation": @@ -278,14 +276,15 @@ def run_from_json( alpha=alpha, stat=stat, max_groups=max_groups, - facet_ncol=facet_ncol, - facet_fig_width=fig_width, - facet_fig_height=fig_height, - facet_tick_rotation=x_rotate, ) # 'multiple' is only applicable when plotting multiple groups together if group_by and together: hist_kwargs["multiple"] = multiple + if facet: + hist_kwargs["facet_ncol"] = facet_ncol + hist_kwargs["facet_fig_width"] = fig_width + hist_kwargs["facet_fig_height"] = fig_height + hist_kwargs["facet_tick_rotation"] = x_rotate result = histogram( adata=adata, @@ -306,9 +305,11 @@ def run_from_json( df_counts = result["df"] # Set figure size and dpi - if fig_width and fig_height: + if fig_width is not None and fig_height is not None: fig.set_size_inches(fig_width, fig_height) + logger.info(f"Set figure size to {fig_width}x{fig_height} inches.") fig.set_dpi(fig_dpi) + logger.info(f"Set figure DPI to {fig_dpi}.") # Ensure axes is a list if isinstance(axs, list): diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 8bf794f7..66a7e991 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -632,8 +632,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, guardrail, which may lead to performance issues or unreadable plots with many groups. - When `facet=True`, these optional key can be passed via `kwargs` - to customize FacetGrid layout: + When `facet=True`, these optional keys can be passed via `kwargs` + to customize FacetGrid layout (they are ignored otherwise): - `facet_ncol`: Controls facet column wrapping. If omitted or passed as `"auto"`, the function uses one column for small group counts and switches to a compact grid for many groups. @@ -810,25 +810,29 @@ def _parse_optional_number( tokens={"unlimited": float('inf')}, ) - # Parse facet layout hints so they never leak to seaborn. 
- facet_ncol = _parse_optional_number( - "facet_ncol", - kwargs.pop('facet_ncol', None), - kind=int, - positive=True, - tokens={"": None, "auto": None, "none": None}, - ) - facet_fig_width = kwargs.pop('facet_fig_width', None) - facet_fig_height = kwargs.pop('facet_fig_height', None) + # Pop facet-only hints early so they never leak to seaborn. + facet_ncol_raw = kwargs.pop('facet_ncol', None) + facet_fig_width_raw = kwargs.pop('facet_fig_width', None) + facet_fig_height_raw = kwargs.pop('facet_fig_height', None) + facet_tick_rotation_raw = kwargs.pop('facet_tick_rotation', None) + + # Parse facet layout hints only in facet mode. if facet: + facet_ncol = _parse_optional_number( + "facet_ncol", + facet_ncol_raw, + kind=int, + positive=True, + tokens={"": None, "auto": None, "none": None}, + ) facet_fig_width = _parse_optional_number( "facet_fig_width", - facet_fig_width, + facet_fig_width_raw, positive=True, ) facet_fig_height = _parse_optional_number( "facet_fig_height", - facet_fig_height, + facet_fig_height_raw, positive=True, ) if (facet_fig_width is None) != (facet_fig_height is None): @@ -836,15 +840,17 @@ def _parse_optional_number( "Both facet_fig_width and facet_fig_height must be provided together, " "or both must be left as None." ) + facet_tick_rotation = _parse_optional_number( + "facet_tick_rotation", + facet_tick_rotation_raw, + default=0.0, + ) % 360.0 else: - # If not faceting, ignore any provided figure size hints. + # If not faceting, ignore all facet-only hints. 
+ facet_ncol = None facet_fig_width = None facet_fig_height = None - facet_tick_rotation = _parse_optional_number( - "facet_tick_rotation", - kwargs.pop('facet_tick_rotation', None), - default=0.0, - ) % 360.0 + facet_tick_rotation = None # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index f3705d6e..25838bad 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -848,39 +848,6 @@ def test_facet_figure_size_hints_require_pair(self): facet_fig_height=3.5, ) - def test_non_facet_figure_size_hints_are_ignored(self): - """Non-facet calls should ignore facet-only figure-size hints.""" - baseline_fig, baseline_ax, _ = histogram( - self.adata, - feature='marker1', - facet=False, - ).values() - - for hint_kwargs in ( - {'facet_fig_width': 8}, - {'facet_fig_height': 5}, - {'facet_fig_width': 8, 'facet_fig_height': 5}, - ): - with self.subTest(hints=hint_kwargs): - fig, ax, _ = histogram( - self.adata, - feature='marker1', - facet=False, - **hint_kwargs, - ).values() - self.assertAlmostEqual( - fig.get_figwidth(), - baseline_fig.get_figwidth(), - places=6, - ) - self.assertAlmostEqual( - fig.get_figheight(), - baseline_fig.get_figheight(), - places=6, - ) - self.assertGreater(len(ax.patches), 0) - self.assertEqual(len(ax.patches), len(baseline_ax.patches)) - def test_facet_tick_rotation_zero_matches_default_behavior(self): """Explicit zero rotation should match omitted rotation behavior.""" fig_default, _, _ = histogram( @@ -942,6 +909,43 @@ def test_facet_long_label_geometry_respects_explicit_size_hints(self): self.assertAlmostEqual(fig.get_figwidth(), 10.0, places=2) self.assertAlmostEqual(fig.get_figheight(), 4.0, places=2) + def test_non_facet_layout_hints_are_ignored(self): + """Non-facet calls should ignore all facet-only layout hints.""" + baseline_fig, baseline_ax, _ = 
histogram( + self.adata, + feature='marker1', + facet=False, + ).values() + + for hint_kwargs in ( + {'facet_fig_width': 8, 'facet_fig_height': 5}, + { + 'facet_ncol': 0, + 'facet_tick_rotation': 'bad', + 'facet_fig_width': 'wide', + 'facet_fig_height': 'tall', + }, + ): + with self.subTest(hints=hint_kwargs): + fig, ax, _ = histogram( + self.adata, + feature='marker1', + facet=False, + **hint_kwargs, + ).values() + self.assertAlmostEqual( + fig.get_figwidth(), + baseline_fig.get_figwidth(), + places=6, + ) + self.assertAlmostEqual( + fig.get_figheight(), + baseline_fig.get_figheight(), + places=6, + ) + self.assertGreater(len(ax.patches), 0) + self.assertEqual(len(ax.patches), len(baseline_ax.patches)) + def test_facet_plot_shared_bins_consistency_numeric(self): """Numeric facets keep shared bins for int/default-like bins inputs.""" # Unbalanced groups: each group occupies only part of the global range. From 88228b6106ebe8e45ceeabd50cc60ecb6d59f82f Mon Sep 17 00:00:00 2001 From: Boqiang Date: Wed, 22 Apr 2026 23:07:33 -0400 Subject: [PATCH 53/57] fix(histogram): ignore max_groups off grouped path Addressed comment https://github.com/FNLCR-DMAP/SCSAWorkflow/pull/428/changes#r3121891313 - Only forward and parse max_groups when group_by is active in the histogram template and histogram() core. - Add focused regression coverage showing non-grouped calls ignore grouped-only max_groups hints. --- src/spac/templates/histogram_template.py | 25 +++++++++++-------- src/spac/visualization.py | 29 +++++++++++++--------- tests/test_visualization/test_histogram.py | 25 +++++++++++++++++++ 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 549ff93c..9602ccc1 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -224,16 +224,23 @@ def run_from_json( f'Received "{x_rotate}".' 
) - # max_groups uses a strict template token contract: positive integer or - # the exact keyword "unlimited". Missing values keep the default of 20. + # Max_Groups applies only when Group_by is set. + # It accepts a positive integer or "unlimited". + # Missing values default to 20. if group_by: - if max_groups != "unlimited": - max_groups = text_to_value( - max_groups, + parsed_max_groups = max_groups + if parsed_max_groups != "unlimited": + parsed_max_groups = text_to_value( + parsed_max_groups, value_to_convert_to=20, to_int=True, param_name="Max_Groups", ) + if parsed_max_groups <= 0: + raise ValueError( + f'Max_Groups should be a positive integer or "unlimited". ' + f'Received "{parsed_max_groups}".' + ) # Facet requires Group_by and forbids Together=True. # Facet_Ncol accepts "auto" or a positive integer. @@ -266,20 +273,18 @@ def run_from_json( else: x_var = feature - # In facet mode, Figure_Width/Height are passed as layout hints so - # visualization can derive panel geometry from total figure size: - # panel_width = Figure_Width / ncol, panel_height = Figure_Height / nrow. + # Assemble validated histogram kwargs right before the plotting call. hist_kwargs = dict( element=element, shrink=shrink, bins=bins, alpha=alpha, stat=stat, - max_groups=max_groups, ) - # 'multiple' is only applicable when plotting multiple groups together if group_by and together: hist_kwargs["multiple"] = multiple + if group_by: + hist_kwargs["max_groups"] = parsed_max_groups if facet: hist_kwargs["facet_ncol"] = facet_ncol hist_kwargs["facet_fig_width"] = fig_width diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 66a7e991..e8a29e63 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -626,7 +626,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. 
- When `group_by` is provided, this optional key can be passed via `kwargs`: + When `group_by` is provided, this optional key can be passed via + `kwargs` (it is ignored otherwise): - `max_groups`: Controls the group-count guardrail for grouped plots. Default is 20 when omitted. Pass `"unlimited"` to disable this guardrail, which may lead to performance issues or unreadable plots @@ -800,22 +801,26 @@ def _parse_optional_number( ) return parsed - # Parse max_groups with "unlimited" handling and validation. - max_groups = _parse_optional_number( - "max_groups", - kwargs.pop('max_groups', None), - kind=int, - default=20, - positive=True, - tokens={"unlimited": float('inf')}, - ) - - # Pop facet-only hints early so they never leak to seaborn. + # Pop grouped/facet-only hints early so they never leak to seaborn. + max_groups_raw = kwargs.pop('max_groups', None) facet_ncol_raw = kwargs.pop('facet_ncol', None) facet_fig_width_raw = kwargs.pop('facet_fig_width', None) facet_fig_height_raw = kwargs.pop('facet_fig_height', None) facet_tick_rotation_raw = kwargs.pop('facet_tick_rotation', None) + # Parse max_groups only for grouped plots; otherwise ignore it entirely. + if group_by: + max_groups = _parse_optional_number( + "max_groups", + max_groups_raw, + kind=int, + default=20, + positive=True, + tokens={"unlimited": float('inf')}, + ) + else: + max_groups = None + # Parse facet layout hints only in facet mode. 
if facet: facet_ncol = _parse_optional_number( diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index 25838bad..3a6c36e3 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -409,6 +409,31 @@ def test_group_by_invalid_max_groups_raises_value_error(self): max_groups=value, ) + def test_non_grouped_max_groups_is_ignored(self): + """Non-grouped calls should ignore grouped-only max_groups hints.""" + baseline_fig, baseline_ax, _ = histogram( + self.adata, + feature='marker1', + ).values() + + fig, ax, _ = histogram( + self.adata, + feature='marker1', + max_groups=0, + ).values() + self.assertAlmostEqual( + fig.get_figwidth(), + baseline_fig.get_figwidth(), + places=6, + ) + self.assertAlmostEqual( + fig.get_figheight(), + baseline_fig.get_figheight(), + places=6, + ) + self.assertGreater(len(ax.patches), 0) + self.assertEqual(len(ax.patches), len(baseline_ax.patches)) + def test_overlay_options(self): fig, ax, df = histogram( self.adata, From 268ab36166e19ad8805ec912d637e85665add278 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 23 Apr 2026 00:15:43 -0400 Subject: [PATCH 54/57] docs(histogram): standardize documentations - tighten _derive_facet_geometry wording - align nested histogram helper docstrings - sync inline comments with current behavior --- src/spac/visualization.py | 174 +++++++++++++++++++------------------- 1 file changed, 85 insertions(+), 89 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index e8a29e63..02033ecd 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -430,16 +430,17 @@ def _derive_facet_geometry( Requested facet column count. Positive integers are used directly. ``None`` falls back to automatic column selection. facet_fig_width, facet_fig_height : float, optional - Optional total figure-size hints. Geometry is derived from these hints only - when both values are present. 
+ Optional total figure-size hints. Geometry is derived from these + hints only when both values are present. facet_tick_max_chars : int, optional Maximum observed x tick-label length. Expected positive integer. - Used for adjusting default geometry heuristics when explicit figure-size - hints are absent. - ``0`` falls back to the original default geometry without long-label adjustments. + Used for adjusting default geometry heuristics when explicit + figure-size hints are absent. ``0`` falls back to the original + default geometry without long-label adjustments. facet_tick_rotation : float, optional Rotation angle in degrees for x tick labels. Used together with - ``facet_tick_max_chars`` to estimate label burden for default geometry. + ``facet_tick_max_chars`` to estimate label burden for default + geometry. vertical_threshold : int, optional Maximum group count that still prefers a single-column automatic layout. @@ -458,19 +459,18 @@ def _derive_facet_geometry( Returns ------- dict - Dictionary with keys: - - ``facet_ncol``: positive int, normalized column count clamped to ``n_groups``; - - ``facet_height``: float, FacetGrid-ready per-panel height in inches; - - ``facet_aspect``: float, FacetGrid-ready per-panel aspect ratio. - - Automatic layout uses one column when ``n_groups <= vertical_threshold`` - and otherwise uses ``ceil(sqrt(n_groups))`` columns. When both - normalized figure-size hints are present, the helper converts total - figure size into per-panel geometry using the derived grid shape, - applies the minimum panel-size guardrails, and clips aspect into the - configured range. When figure-size hints are absent, the helper may - increase default facet height/aspect for long rotated categorical - labels to preserve usable bar area. + Dictionary containing ``facet_ncol``, ``facet_height``, and + ``facet_aspect`` for FacetGrid construction. 
+ + Automatic layout uses one column when + ``n_groups <= vertical_threshold`` and otherwise uses + ``ceil(sqrt(n_groups))`` columns. When both normalized + figure-size hints are present, the helper converts total figure + size into per-panel geometry, applies minimum panel-size + guardrails, and clips aspect into the configured range. When + figure-size hints are absent, the helper may increase default + facet height/aspect for long rotated categorical labels to + preserve usable bar area. """ # Derive facet_ncol when not explicitly provided, and clamp to n_groups @@ -501,8 +501,9 @@ def _derive_facet_geometry( facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) elif facet_tick_max_chars and facet_tick_max_chars > 0: - # For default geometry only, expand panel ratio when long rotated - # labels would otherwise dominate the available plotting area. + # For default geometry only, allocate more vertical space and a + # slightly tighter aspect when long rotated labels would otherwise + # dominate the available plotting area. rotation = float(facet_tick_rotation or 0.0) % 360.0 rad = np.deg2rad(min(rotation, 180.0)) rotation_factor = 1.0 + 0.8 * np.sin(rad) @@ -514,7 +515,6 @@ def _derive_facet_geometry( facet_height = default_height * (1.0 + 0.35 * pressure) facet_aspect = float( np.clip( - # default_aspect * (1.0 + 0.75 * pressure), default_aspect * (1.0 - 0.05 * pressure), min_aspect, max_aspect, @@ -584,7 +584,9 @@ def histogram(adata, feature=None, annotation=None, layer=None, If True, the y-axis will be set to log scale. facet : bool, default False - If True, group by function outputs facet plots + If True, draw grouped histograms as a faceted layout instead of + separate stacked axes. Requires `group_by` and is not supported + together with `together=True`. **kwargs Additional keyword arguments passed to seaborn histplot function. 
@@ -733,6 +735,7 @@ def histogram(adata, feature=None, annotation=None, layer=None, def cal_bin_num( num_rows ): + """Return the Rice-rule default number of histogram bins.""" bins = max(int(2*(num_rows ** (1/3))), 1) print(f'Automatically calculated number of bins is: {bins}') @@ -767,12 +770,7 @@ def _parse_optional_number( positive=False, tokens=None, ): - """Parse an optional numeric hint. - - Returns ``default`` for ``None``, resolves recognized string tokens - before numeric coercion, and optionally enforces finite and positive - values on the parsed result. - """ + """Parse an optional numeric hint with token/default handling.""" if value is None: return default if isinstance(value, str): @@ -857,30 +855,27 @@ def _parse_optional_number( facet_fig_height = None facet_tick_rotation = None - # Function to calculate histogram data def calculate_histogram(data, bins, bin_edges=None): - """ - Compute histogram data for numeric or categorical input. - - Parameters: - - data (pd.Series): The input data to be binned. - - bins (int or sequence): Number of bins (if numeric) or unique categories - (if categorical). - - bin_edges (array-like, optional): Predefined bin edges for numeric data. - If None, automatic binning is used. - - Returns: - - pd.DataFrame: A DataFrame containing the following columns: - - `count`: - Frequency of values in each bin. - - `bin_left`: - Left edge of each bin (for numeric data). - - `bin_right`: - Right edge of each bin (for numeric data). - - `bin_center`: - Center of each bin (for numeric data) or category labels - (for categorical data). + """Compute a histogram-bin table for numeric or categorical input. + Parameters + ---------- + data : pandas.Series + Values to summarize into histogram bins or categorical slots. + bins : int or sequence + Number of bins for numeric data, or categorical slots to + preserve when building grouped categorical histograms. + bin_edges : array-like, optional + Explicit numeric bin edges to reuse. 
If ``None``, numeric + data uses ``bins`` directly and categorical data ignores this + argument. + + Returns + ------- + pandas.DataFrame + Histogram summary table with ``count``, ``bin_left``, + ``bin_right``, and ``bin_center`` columns. For categorical + input, the bin edge columns repeat the category labels. """ # Check if the data is numeric or categorical @@ -909,7 +904,29 @@ def calculate_histogram(data, bins, bin_edges=None): def build_grouped_histogram_table( plot_data, data_column, group_by, groups, bins ): - """Build per-group histogram-bin tables for grouped histogram paths.""" + """Build grouped histogram-bin data with shared bin definitions. + + Parameters + ---------- + plot_data : pandas.DataFrame + Table containing the histogram source values and grouping + annotation. + data_column : str + Column name to summarize on the x axis. + group_by : str + Annotation column defining the grouping. + groups : list + Group labels to render in plotting order. + bins : int or sequence + Histogram bin specification forwarded to the per-group + histogram builder. + + Returns + ------- + tuple[pandas.DataFrame, array-like] + Combined histogram table for all groups, plus the shared bin + definition used to keep grouped plots aligned. + """ # Determine shared bins across groups for consistent plotting. data_series = plot_data[data_column] if pd.api.types.is_numeric_dtype(data_series): @@ -950,21 +967,8 @@ def build_grouped_histogram_table( hist_data = pd.concat(histograms, ignore_index=True) return hist_data, shared_bins - # Function to compute maximum tick label length for categorical data def compute_max_tick_label_length(data_series): - """Compute maximum tick label length for a categorical data series. - - Parameters - ---------- - data_series : pandas.Series - Categorical data column used to compute maximum tick label length. - - Returns - ------- - int - Maximum number of characters in the tick labels derived from the - unique categories of the input series. 
- """ + """Return the maximum character length across candidate tick labels.""" if isinstance(data_series.dtype, pd.CategoricalDtype): tick_labels = [ str(label) for label in data_series.cat.categories @@ -975,26 +979,8 @@ def compute_max_tick_label_length(data_series): ] return max((len(label) for label in tick_labels), default=0) - # Function to get axis labels based on log scale and stat parameters def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): - """Resolve x/y axis labels for histogram rendering. - - Parameters - ---------- - data_column : str - Source column used on the x axis. - x_log_scale : bool - Whether x data has log transform semantics. - y_log_scale : bool - Whether y axis is displayed on log scale. - stat : str - Histogram statistic mode (for example, count, density). - - Returns - ------- - tuple[str, str] - Resolved x-axis and y-axis labels. - """ + """Return histogram axis labels for the current scale/stat settings.""" xlabel = f'log({data_column})' if x_log_scale else data_column ylabel_map = { 'count': 'Count', @@ -1009,7 +995,8 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ylabel = f'log({ylabel})' return xlabel, ylabel - # Plotting with or without grouping + # Dispatch to grouped-together, grouped-separate, faceted, or + # ungrouped plotting. if group_by: groups = df[group_by].dropna().unique().tolist() n_groups = len(groups) @@ -1027,9 +1014,12 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ) if together: + # 1) Grouped together on the same axes if ax is None: fig, ax = plt.subplots() + # Compute histogram data and shared bins for consistent plotting. + # For non-numeric data, shared_bins will be dropped intentionally. 
hist_data, shared_bins = build_grouped_histogram_table( plot_data, data_column, @@ -1052,6 +1042,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ax=ax, **kwargs, ) + # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') @@ -1062,6 +1053,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): kwargs.pop('multiple', None) if not facet: + # 2) Grouped separately on different axes fig, ax_array = plt.subplots( n_groups, 1, figsize=(5, 5 * n_groups) ) @@ -1080,6 +1072,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): sns.histplot(data=hist_data, x="bin_center", ax=ax_i, weights='count', **kwargs) + # If plotting feature specify which layer if feature: ax_i.set_title(f'{groups[i]} with Layer: {layer}') @@ -1087,7 +1080,8 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ax_i.set_title(f'{groups[i]}') axs.append(ax_i) - else: # Facet option + else: + # 3) Faceted by group on different axes using seaborn's FacetGrid. # Compute max label length only when not explicitly provided. facet_tick_max_chars = 0 if not pd.api.types.is_numeric_dtype(plot_data[data_column]): @@ -1135,7 +1129,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): **kwargs, ) - # Keep shared scale but show x tick numbers on bottom row and y tick numbers on left column + # Show tick labels on every facet while keeping shared axes. 
for ax_i in hist.axes.flat: ax_i.tick_params(axis='x', labelbottom=True) ax_i.tick_params(axis='y', labelleft=True) @@ -1164,6 +1158,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): axs.extend(hist.axes.flat) else: + # 4) Ungrouped histogram (group_by=None) if ax is None: fig, ax = plt.subplots() @@ -1186,6 +1181,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): ax.set_title(f'Layer: {layer}') axs.append(ax) + # Determine axis labels based on scale and stat settings. stat = kwargs.get('stat', 'count') xlabel, ylabel = resolve_hist_axis_labels( data_column=data_column, @@ -1199,7 +1195,7 @@ def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): # Set axis scales if y_log_scale is True if y_log_scale: ax.set_yscale('log') - + # For faceted plots, we set axis labels at the figure level only. if facet: ax.set_xlabel('') ax.set_ylabel('') From 0c73cfe04ae753828b2478318299dd813c16eaa0 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 23 Apr 2026 00:59:21 -0400 Subject: [PATCH 55/57] chore(import): remove unused imports --- src/spac/utils.py | 3 +-- src/spac/visualization.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spac/utils.py b/src/spac/utils.py index 5d5d87b3..68395548 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -8,12 +8,11 @@ import warnings import numbers from scipy.stats import median_abs_deviation -from typing import Any, List, Optional +from typing import List, Optional # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) def regex_search_list( diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 02033ecd..c54baa93 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -26,7 +26,7 @@ import time import json import re -from typing import Dict, List, Union, Optional +from typing import Dict, List, 
Union import matplotlib.colors as mcolors import matplotlib.patches as mpatch from functools import partial @@ -501,9 +501,9 @@ def _derive_facet_geometry( facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) elif facet_tick_max_chars and facet_tick_max_chars > 0: - # For default geometry only, allocate more vertical space and a - # slightly tighter aspect when long rotated labels would otherwise - # dominate the available plotting area. + # For default geometry only, allocate more vertical space and a + # slightly tighter aspect when long rotated labels would otherwise + # dominate the available plotting area. rotation = float(facet_tick_rotation or 0.0) % 360.0 rad = np.deg2rad(min(rotation, 180.0)) rotation_factor = 1.0 + 0.8 * np.sin(rad) From a422bbec2179a4f123de653febf399c2752453e4 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Thu, 23 Apr 2026 00:59:48 -0400 Subject: [PATCH 56/57] fix(histogram): gate multiple to overlay mode - normalize `multiple` only when `group_by` and `together` are active --- src/spac/templates/histogram_template.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 9602ccc1..d2d508a3 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -186,9 +186,12 @@ def run_from_json( "Setting bin number calculation to auto." ) - multiple = str(multiple).strip().lower() + + if group_by and together: + multiple = str(multiple).strip().lower() element = str(element).strip().lower() stat = str(stat).strip().lower() + # Figure_Width and Figure_Height use "auto" for template defaults. # In facet mode, it is forwarded as None to derive layout geometry. # In non-facet mode, it falls back to 8x6 inches. 
From 7a6c048ce61c0d32b766a08c301d7b5e9a1e24f9 Mon Sep 17 00:00:00 2001 From: Boqiang Date: Tue, 28 Apr 2026 00:09:33 -0400 Subject: [PATCH 57/57] fix(histogram): widen facet spacing in template layout - This is a follow-up fix when exposed to shiny --- src/spac/templates/histogram_template.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index d2d508a3..c14bdc65 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -393,14 +393,14 @@ def run_from_json( rows = len({round(ax.get_position().y0, 3) for ax in axes}) fig.tight_layout( rect=[ - max(0.02, 0.038 - 0.004 * rows), + min(0.030, 0.02 + 0.0025 * rows), max(0.022, 0.036 - 0.003 * rows), min(0.992, 0.98 + 0.0025 * rows), - max(0.974, 0.98 - 0.001 * rows), + max(0.969, 0.975 - 0.001 * rows), ], pad=max(0.35, 0.6 - 0.05 * rows), - h_pad=max(0.2, 0.43 - 0.04 * rows), - w_pad=max(0.2, 0.43 - 0.04 * rows), + h_pad=max(0.2, 0.43 - 0.04 * rows) * 6, + w_pad=max(0.2, 0.43 - 0.04 * rows) * 6, ) else: fig.tight_layout()