diff --git a/.gitignore b/.gitignore index 8ec0749c..528a7356 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ .vscode/launch.json test/__pychache__ build/ -*.egg-info/ +*.egg-info/ \ No newline at end of file diff --git a/src/spac/templates/histogram_template.py b/src/spac/templates/histogram_template.py index 0a3924d4..c14bdc65 100644 --- a/src/spac/templates/histogram_template.py +++ b/src/spac/templates/histogram_template.py @@ -96,10 +96,11 @@ def run_from_json( feature = text_to_value(params.get("Feature", "None")) annotation = text_to_value(params.get("Annotation", "None")) layer = params.get("Table_", "Original") - group_by = params.get("Group_by", "None") + group_by = text_to_value(params.get("Group_by", "None")) + max_groups = params.get("Max_Groups", 20) together = params.get("Together", True) - fig_width = params.get("Figure_Width", 8) - fig_height = params.get("Figure_Height", 6) + fig_width = params.get("Figure_Width", "auto") + fig_height = params.get("Figure_Height", "auto") font_size = params.get("Font_Size", 12) fig_dpi = params.get("Figure_DPI", 300) legend_location = params.get("Legend_Location", "best") @@ -107,12 +108,15 @@ def run_from_json( take_X_log = params.get("Take_X_Log", False) take_Y_log = params.get("Take_Y_log", False) multiple = params.get("Multiple", "dodge") + element = params.get("Element", "bars") shrink = params.get("Shrink_Number", 1) bins = params.get("Bins", "auto") alpha = params.get("Bin_Transparency", 0.75) stat = params.get("Stat", "count") x_rotate = params.get("X_Axis_Label_Rotation", 0) histplot_by = params.get("Plot_By", "Annotation") + facet = params.get("Facet", False) + facet_ncol = params.get("Facet_Ncol", "auto") # Close all existing figures to prevent extra plots plt.close('all') @@ -151,7 +155,9 @@ def run_from_json( 'No features available in adata.var_names to plot.' ) - # Validate and set bins + # Bins use a strict template contract in feature mode: + # "auto" or a positive integer. Loose null-like values are intentionally + # not treated as aliases here. if feature is not None: bins = text_to_value( bins, @@ -180,33 +186,126 @@ def run_from_json( "Setting bin number calculation to auto." ) + + if group_by and together: + multiple = str(multiple).strip().lower() + element = str(element).strip().lower() + stat = str(stat).strip().lower() + + # Figure_Width and Figure_Height use "auto" for template defaults. + # In facet mode, it is forwarded as None to derive layout geometry. + # In non-facet mode, it falls back to 8x6 inches. + fig_width = text_to_value( + fig_width, + default_none_text="auto", + value_to_convert_to=None if facet else 8, + to_float=True, + param_name="Figure_Width" + ) + fig_height = text_to_value( + fig_height, + default_none_text="auto", + value_to_convert_to=None if facet else 6, + to_float=True, + param_name="Figure_Height" + ) + if fig_width is not None and fig_height is not None: + if fig_width <= 0 or fig_height <= 0: + raise ValueError( + f'Figure_Width/Height should be a positive number.' + f'Received "{fig_width}"/"{fig_height}".' + ) + if fig_dpi <= 0: + raise ValueError( + f'Figure_DPI should be a positive number. Received "{fig_dpi}".' + ) + + # Validate x-axis label rotation if (x_rotate < 0) or (x_rotate > 360): raise ValueError( f'The X label rotation should fall within 0 to 360 degree. ' f'Received "{x_rotate}".' ) + # Max_Groups applies only when Group_by is set. + # It accepts a positive integer or "unlimited". + # Missing values default to 20. + if group_by: + parsed_max_groups = max_groups + if parsed_max_groups != "unlimited": + parsed_max_groups = text_to_value( + parsed_max_groups, + value_to_convert_to=20, + to_int=True, + param_name="Max_Groups", + ) + if parsed_max_groups <= 0: + raise ValueError( + f'Max_Groups should be a positive integer or "unlimited". ' + f'Received "{parsed_max_groups}".' + ) + + # Facet requires Group_by and forbids Together=True. + # Facet_Ncol accepts "auto" or a positive integer. + if facet: + if group_by is None: + raise ValueError( + 'Facet is True but Group_by is not specified. ' + 'Please specify Group_by when using Facet.' + ) + if together: + raise ValueError( + 'Together and Facet cannot both be True. Please set one to False.' + ) + facet_ncol = text_to_value( + facet_ncol, + default_none_text="auto", + to_int=True, + param_name="Facet_Ncol" + ) + if facet_ncol is not None: + if facet_ncol <= 0: + raise ValueError( + f'Facet_Ncol must be a positive integer or "auto". ' + f'Received "{facet_ncol}".' + ) + # Initialize the x-variable before the loop if histplot_by == "Annotation": x_var = annotation else: x_var = feature + # Assemble validated histogram kwargs right before the plotting call. + hist_kwargs = dict( + element=element, + shrink=shrink, + bins=bins, + alpha=alpha, + stat=stat, + ) + if group_by and together: + hist_kwargs["multiple"] = multiple + if group_by: + hist_kwargs["max_groups"] = parsed_max_groups + if facet: + hist_kwargs["facet_ncol"] = facet_ncol + hist_kwargs["facet_fig_width"] = fig_width + hist_kwargs["facet_fig_height"] = fig_height + hist_kwargs["facet_tick_rotation"] = x_rotate + result = histogram( adata=adata, feature=feature, annotation=annotation, layer=text_to_value(layer, "Original"), - group_by=text_to_value(group_by), + group_by=group_by, together=together, ax=None, x_log_scale=take_X_log, y_log_scale=take_Y_log, - multiple=multiple, - shrink=shrink, - bins=bins, - alpha=alpha, - stat=stat + facet=facet, + **hist_kwargs, ) fig = result["fig"] @@ -214,8 +313,11 @@ def run_from_json( df_counts = result["df"] # Set figure size and dpi - fig.set_size_inches(fig_width, fig_height) + if fig_width is not None and fig_height is not None: + fig.set_size_inches(fig_width, fig_height) + logger.info(f"Set figure size to {fig_width}x{fig_height} inches.") fig.set_dpi(fig_dpi) + logger.info(f"Set figure DPI to {fig_dpi}.") # Ensure axes is a list if isinstance(axs, list): @@ -249,8 +351,12 @@ def run_from_json( # Rotate x labels ax.tick_params(axis='x', rotation=x_rotate) + if x_rotate: + for label in ax.get_xticklabels(): + label.set_rotation_mode('anchor') + label.set_horizontalalignment('right') - # Set titles based on group_by + # Set titles based on group_by and facet if text_to_value(group_by): if together: for ax in axes: @@ -267,15 +373,37 @@ def run_from_json( "Number of axes does not match number of " "groups. Titles may not correspond correctly." ) + if facet: + fig.suptitle( + f'Histogram of "{x_var}" faceted by "{group_by}"' + ) + ax_title_prefix = f'Group' + else: + ax_title_prefix = f'Histogram of "{x_var}" for group' for ax, grp in zip(axes, unique_groups): ax.set_title( - f'Histogram of "{x_var}" for group: "{grp}"' + f'{ax_title_prefix}: "{grp}"' ) else: for ax in axes: ax.set_title(f'Count plot of "{x_var}"') - plt.tight_layout() + # Adjust layout to prevent title overlap + if facet: + rows = len({round(ax.get_position().y0, 3) for ax in axes}) + fig.tight_layout( + rect=[ + min(0.030, 0.02 + 0.0025 * rows), + max(0.022, 0.036 - 0.003 * rows), + min(0.992, 0.98 + 0.0025 * rows), + max(0.969, 0.975 - 0.001 * rows), + ], + pad=max(0.35, 0.6 - 0.05 * rows), + h_pad=max(0.2, 0.43 - 0.04 * rows) * 6, + w_pad=max(0.2, 0.43 - 0.04 * rows) * 6, + ) + else: + fig.tight_layout() logger.info("Displaying top 10 rows of histogram dataframe:") print(df_counts.head(10)) diff --git a/src/spac/utils.py b/src/spac/utils.py index 2a616b68..68395548 100644 --- a/src/spac/utils.py +++ b/src/spac/utils.py @@ -1196,20 +1196,20 @@ def compute_metrics(data): # compute summary statistics for the specified columns def compute_summary_qc_stats( - df: pd.DataFrame, + df: pd.DataFrame, n_mad: int = 5, upper_quantile: float = 0.95, lower_quantile: float = 0.05, stat_columns_list: List[str] = ['nFeature', 'nCount', 'percent.mt'] ) -> pd.DataFrame: - + """ Compute summary quality control statistics for specified columns in a dataset. For each column in stat_columns_list, this function calculates: - Mean - Median - - Upper and lower thresholds based on median ± n_mad * MAD + - Upper and lower thresholds based on median ± n_mad * MAD (median absolute deviation) - Upper and lower quantiles @@ -1230,7 +1230,7 @@ def compute_summary_qc_stats( ------- pd.DataFrame DataFrame with summary statistics for each specified column. - Columns: ["metric_name", "mean", "median", "upper_mad", "lower_mad", + Columns: ["metric_name", "mean", "median", "upper_mad", "lower_mad", "upper_quantile", "lower_quantile"] Raises @@ -1269,8 +1269,8 @@ def compute_summary_qc_stats( return pd.DataFrame( stat_vals, columns=[ - "metric_name", "mean", "median", - "upper_mad", "lower_mad", + "metric_name", "mean", "median", + "upper_mad", "lower_mad", "upper_quantile", "lower_quantile" ] - ) \ No newline at end of file + ) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 9003c163..c54baa93 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -403,9 +403,142 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): return fig, ax + +def _derive_facet_geometry( + n_groups, + facet_ncol=None, + facet_fig_width=None, + facet_fig_height=None, + facet_tick_max_chars=0, + facet_tick_rotation=0.0, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, + min_panel_width=1.8, + min_panel_height=1.6, + min_aspect=0.6, + max_aspect=2.0, +): + """Derive FacetGrid geometry from pre-normalized facet layout hints. + + Parameters + ---------- + n_groups : int + Number of facet panels. Expected to be a positive integer supplied by + the grouped histogram path. + facet_ncol : int or None, optional + Requested facet column count. Positive integers are used directly. + ``None`` falls back to automatic column selection. + facet_fig_width, facet_fig_height : float, optional + Optional total figure-size hints. Geometry is derived from these + hints only when both values are present. + facet_tick_max_chars : int, optional + Maximum observed x tick-label length. Expected positive integer. + Used for adjusting default geometry heuristics when explicit + figure-size hints are absent. ``0`` falls back to the original + default geometry without long-label adjustments. + facet_tick_rotation : float, optional + Rotation angle in degrees for x tick labels. Used together with + ``facet_tick_max_chars`` to estimate label burden for default + geometry. + vertical_threshold : int, optional + Maximum group count that still prefers a single-column automatic + layout. + default_height : float, optional + Per-facet panel height used when explicit figure-size hints are not + available. + default_aspect : float, optional + Per-facet panel aspect ratio used when explicit figure-size hints are + not available. + min_panel_width, min_panel_height : float, optional + Lower bounds applied to per-panel dimensions when figure-size hints are + converted into FacetGrid geometry. + min_aspect, max_aspect : float, optional + Bounds applied to the derived panel aspect ratio. + + Returns + ------- + dict + Dictionary containing ``facet_ncol``, ``facet_height``, and + ``facet_aspect`` for FacetGrid construction. + + Automatic layout uses one column when + ``n_groups <= vertical_threshold`` and otherwise uses + ``ceil(sqrt(n_groups))`` columns. When both normalized + figure-size hints are present, the helper converts total figure + size into per-panel geometry, applies minimum panel-size + guardrails, and clips aspect into the configured range. When + figure-size hints are absent, the helper may increase default + facet height/aspect for long rotated categorical labels to + preserve usable bar area. + """ + + # Derive facet_ncol when not explicitly provided, and clamp to n_groups + if facet_ncol is None: + if n_groups <= vertical_threshold: + facet_ncol = 1 + else: + facet_ncol = int(np.ceil(np.sqrt(n_groups))) + logging.info( + "Automatic facet_ncol selection: %s columns for %s groups " + "(vertical_threshold=%s).", + facet_ncol, + n_groups, + vertical_threshold, + ) + facet_ncol = max(1, min(int(facet_ncol), n_groups)) + + # Use defaults if figure-size hints are not provided + facet_height = default_height + facet_aspect = default_aspect + + # Derive facet geometry from figure-size hints when both are provided + if facet_fig_width is not None and facet_fig_height is not None: + nrow = int(np.ceil(n_groups / facet_ncol)) + panel_width = max(facet_fig_width / facet_ncol, min_panel_width) + panel_height = max(facet_fig_height / nrow, min_panel_height) + facet_height = panel_height + facet_aspect = float(np.clip(panel_width / panel_height, min_aspect, max_aspect)) + + elif facet_tick_max_chars and facet_tick_max_chars > 0: + # For default geometry only, allocate more vertical space and a + # slightly tighter aspect when long rotated labels would otherwise + # dominate the available plotting area. + rotation = float(facet_tick_rotation or 0.0) % 360.0 + rad = np.deg2rad(min(rotation, 180.0)) + rotation_factor = 1.0 + 0.8 * np.sin(rad) + burden = float(facet_tick_max_chars) * rotation_factor + long_label_threshold = 12.0 + + if burden > long_label_threshold: + pressure = min((burden - long_label_threshold) / long_label_threshold, 2.0) + facet_height = default_height * (1.0 + 0.35 * pressure) + facet_aspect = float( + np.clip( + default_aspect * (1.0 - 0.05 * pressure), + min_aspect, + max_aspect, + ) + ) + logging.info( + "Automatic facet geometry adjustment for long x tick labels: " + "max_chars=%s, rotation=%s, facet_height=%.2f, facet_aspect=%.2f.", + facet_tick_max_chars, + facet_tick_rotation, + facet_height, + facet_aspect, + ) + + return { + "facet_ncol": facet_ncol, + "facet_height": facet_height, + "facet_aspect": facet_aspect, + } + + def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, - x_log_scale=False, y_log_scale=False, **kwargs): + x_log_scale=False, y_log_scale=False, facet=False, **kwargs): """ Plot the histogram of cells based on a specific feature from adata.X or annotation from adata.obs. @@ -431,7 +564,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, If True, and if group_by != None, create one plot combining all groups. If False, create separate histograms for each group. The appearance of combined histograms can be controlled using the - `multiple` and `element` parameters in **kwargs. + `multiple` and `element` parameters in **kwargs. Separate grouped or + faceted histograms ignore `multiple`. To control how the histograms are normalized (e.g., to divide the histogram by the number of elements in every group), use the `stat` parameter in **kwargs. For example, set `stat="probability"` to show @@ -439,6 +573,8 @@ def histogram(adata, feature=None, annotation=None, layer=None, ax : matplotlib.axes.Axes, optional An existing Axes object to draw the plot onto, optional. + Not supported for grouped-separate (`group_by` with `together=False`) + or facet layouts (`group_by` with `facet=True`). x_log_scale : bool, default False If True, the data will be transformed using np.log1p before plotting, @@ -447,11 +583,17 @@ def histogram(adata, feature=None, annotation=None, layer=None, y_log_scale : bool, default False If True, the y-axis will be set to log scale. + facet : bool, default False + If True, draw grouped histograms as a faceted layout instead of + separate stacked axes. Requires `group_by` and is not supported + together with `together=True`. + **kwargs Additional keyword arguments passed to seaborn histplot function. Key arguments include: - `multiple`: Determines how the subsets of data are displayed - on the same axes. Options include: + on the same axes. Ignored when `group_by` is used with + `together=False`. Options include: * "layer": Draws each subset on top of the other without adjustments. * "dodge": Dodges bars for each subset side by side. @@ -463,6 +605,10 @@ def histogram(adata, feature=None, annotation=None, layer=None, * "step": Creates a step line plot without bars. * "poly": Creates a polygon where the bottom edge represents the x-axis and the top edge the histogram's bins. + - `shrink`: Scale bar width relative to each bin's width. + Useful for reducing overlap with `multiple="dodge"`. + - `alpha`: Transparency for histogram artists. + 0 is fully transparent and 1 is fully opaque. - `log_scale`: Determines if the data should be plotted on a logarithmic scale. - `stat`: Determines the statistical transformation to use on the data @@ -478,9 +624,28 @@ def histogram(adata, feature=None, annotation=None, layer=None, Can be a number (indicating the number of bins) or a list (indicating bin edges). For example, `bins=10` will create 10 bins, while `bins=[0, 1, 2, 3]` will create bins [0,1), [1,2), [2,3]. - If not provided, the binning will be determined automatically. + If not provided, or if passed as `None`/`"auto"`/`"none"`, + the binning will be determined automatically using the Rice rule. Note, don't pass a numpy array, only python lists or strs/numbers. + When `group_by` is provided, this optional key can be passed via + `kwargs` (it is ignored otherwise): + - `max_groups`: Controls the group-count guardrail for grouped plots. + Default is 20 when omitted. Pass `"unlimited"` to disable this + guardrail, which may lead to performance issues or unreadable plots + with many groups. + + When `facet=True`, these optional keys can be passed via `kwargs` + to customize FacetGrid layout (they are ignored otherwise): + - `facet_ncol`: Controls facet column wrapping. + If omitted or passed as `"auto"`, the function uses one column for + small group counts and switches to a compact grid for many groups. + Otherwise, the provided value is used to request the facet column + count. + - `facet_fig_width`: float, intended final figure width in inches. + - `facet_fig_height`: float, intended final figure height in inches. + - `facet_tick_rotation`: float, rotation angle in degrees for x tick labels. + Returns ------- A dictionary containing the following: @@ -547,10 +712,15 @@ def histogram(adata, feature=None, annotation=None, layer=None, else: df[data_column] = np.log1p(df[data_column]) + # If ax is provided, validate input and get figure from it. + # If not, the figure will be created in the plotting branch. if ax is not None: + if group_by and not together: + raise ValueError( + "External ax is only supported for single-axes histogram " + "Please set together=True or remove external ax." + ) fig = ax.get_figure() - else: - fig, ax = plt.subplots() axs = [] @@ -565,41 +735,147 @@ def histogram(adata, feature=None, annotation=None, layer=None, def cal_bin_num( num_rows ): + """Return the Rice-rule default number of histogram bins.""" bins = max(int(2*(num_rows ** (1/3))), 1) print(f'Automatically calculated number of bins is: {bins}') - return(bins) + + return (bins) num_rows = plot_data.shape[0] - # Check if bins is being passed - # If not, the in house algorithm will compute the number of bins - if 'bins' not in kwargs: + # Check if bins is not being passed or set to None or "auto" in kwargs. + # If so, the in house algorithm will compute the number of bins + bins_kwarg = kwargs.get('bins', None) + if isinstance(bins_kwarg, str): + bins_kwarg = bins_kwarg.strip().lower() + if bins_kwarg in {'', 'auto', 'none'}: + bins_kwarg = None + if bins_kwarg is None: kwargs['bins'] = cal_bin_num(num_rows) - # Function to calculate histogram data + # Input validation for facet + if facet: + if group_by is None: + raise ValueError("group_by must be specified when facet=True.") + if together: + raise ValueError("Cannot use together=True with facet=True," + " choose one.") + + def _parse_optional_number( + name, + value, + *, + kind=float, + default=None, + positive=False, + tokens=None, + ): + """Parse an optional numeric hint with token/default handling.""" + if value is None: + return default + if isinstance(value, str): + value = value.strip() + if tokens and value.lower() in tokens: + return tokens[value.lower()] + expected = ( + f'{"positive " if positive else ""}{kind.__name__}' + f'{" or a supported keyword" if tokens else ""}' + ) + if isinstance(value, bool): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + try: + parsed = kind(value) + except (TypeError, ValueError): + raise ValueError(f'{name} must be a {expected}. Received "{value}".') + if not math.isfinite(parsed): + raise ValueError( + f'{name} must be a finite {kind.__name__}. ' + f'Received "{value}".' + ) + if positive and parsed <= 0: + raise ValueError( + f'{name} must be a positive {kind.__name__}. ' + f'Received "{value}".' + ) + return parsed + + # Pop grouped/facet-only hints early so they never leak to seaborn. + max_groups_raw = kwargs.pop('max_groups', None) + facet_ncol_raw = kwargs.pop('facet_ncol', None) + facet_fig_width_raw = kwargs.pop('facet_fig_width', None) + facet_fig_height_raw = kwargs.pop('facet_fig_height', None) + facet_tick_rotation_raw = kwargs.pop('facet_tick_rotation', None) + + # Parse max_groups only for grouped plots; otherwise ignore it entirely. + if group_by: + max_groups = _parse_optional_number( + "max_groups", + max_groups_raw, + kind=int, + default=20, + positive=True, + tokens={"unlimited": float('inf')}, + ) + else: + max_groups = None + + # Parse facet layout hints only in facet mode. + if facet: + facet_ncol = _parse_optional_number( + "facet_ncol", + facet_ncol_raw, + kind=int, + positive=True, + tokens={"": None, "auto": None, "none": None}, + ) + facet_fig_width = _parse_optional_number( + "facet_fig_width", + facet_fig_width_raw, + positive=True, + ) + facet_fig_height = _parse_optional_number( + "facet_fig_height", + facet_fig_height_raw, + positive=True, + ) + if (facet_fig_width is None) != (facet_fig_height is None): + raise ValueError( + "Both facet_fig_width and facet_fig_height must be provided together, " + "or both must be left as None." + ) + facet_tick_rotation = _parse_optional_number( + "facet_tick_rotation", + facet_tick_rotation_raw, + default=0.0, + ) % 360.0 + else: + # If not faceting, ignore all facet-only hints. + facet_ncol = None + facet_fig_width = None + facet_fig_height = None + facet_tick_rotation = None + def calculate_histogram(data, bins, bin_edges=None): - """ - Compute histogram data for numeric or categorical input. - - Parameters: - - data (pd.Series): The input data to be binned. - - bins (int or sequence): Number of bins (if numeric) or unique categories - (if categorical). - - bin_edges (array-like, optional): Predefined bin edges for numeric data. - If None, automatic binning is used. - - Returns: - - pd.DataFrame: A DataFrame containing the following columns: - - `count`: - Frequency of values in each bin. - - `bin_left`: - Left edge of each bin (for numeric data). - - `bin_right`: - Right edge of each bin (for numeric data). - - `bin_center`: - Center of each bin (for numeric data) or category labels - (for categorical data). + """Compute a histogram-bin table for numeric or categorical input. + + Parameters + ---------- + data : pandas.Series + Values to summarize into histogram bins or categorical slots. + bins : int or sequence + Number of bins for numeric data, or categorical slots to + preserve when building grouped categorical histograms. + bin_edges : array-like, optional + Explicit numeric bin edges to reuse. If ``None``, numeric + data uses ``bins`` directly and categorical data ignores this + argument. + Returns + ------- + pandas.DataFrame + Histogram summary table with ``count``, ``bin_left``, + ``bin_right``, and ``bin_center`` columns. For categorical + input, the bin edge columns repeat the category labels. """ # Check if the data is numeric or categorical @@ -625,102 +901,272 @@ def calculate_histogram(data, bins, bin_edges=None): 'count': counts.values }) - # Plotting with or without grouping + def build_grouped_histogram_table( + plot_data, data_column, group_by, groups, bins + ): + """Build grouped histogram-bin data with shared bin definitions. + + Parameters + ---------- + plot_data : pandas.DataFrame + Table containing the histogram source values and grouping + annotation. + data_column : str + Column name to summarize on the x axis. + group_by : str + Annotation column defining the grouping. + groups : list + Group labels to render in plotting order. + bins : int or sequence + Histogram bin specification forwarded to the per-group + histogram builder. + + Returns + ------- + tuple[pandas.DataFrame, array-like] + Combined histogram table for all groups, plus the shared bin + definition used to keep grouped plots aligned. + """ + # Determine shared bins across groups for consistent plotting. + data_series = plot_data[data_column] + if pd.api.types.is_numeric_dtype(data_series): + shared_bins = np.histogram_bin_edges(data_series, bins=bins) + elif isinstance(data_series.dtype, pd.CategoricalDtype): + shared_bins = data_series.cat.categories + else: + shared_bins = pd.Index(pd.unique(data_series.dropna())) + + # Compute histograms for each group using the shared bins. + histograms = [] + for group in groups: + group_data = plot_data.loc[ + plot_data[group_by] == group, data_column + ] + if pd.api.types.is_numeric_dtype(data_series): + group_hist = calculate_histogram( + group_data, bins, bin_edges=shared_bins + ) + else: + # For categorical data, pad missing categories with zero counts + # to ensure consistent plotting. + group_hist = calculate_histogram(group_data, bins) + group_hist = ( + group_hist + .set_index('bin_center') + .reindex(shared_bins) + .rename_axis('bin_center') + .reset_index() + ) + group_hist['count'] = group_hist['count'].fillna(0) + group_hist['bin_left'] = group_hist['bin_center'] + group_hist['bin_right'] = group_hist['bin_center'] + group_hist[group_by] = group + histograms.append(group_hist) + + # Concatenate all group histograms into a single DataFrame for plotting. + hist_data = pd.concat(histograms, ignore_index=True) + return hist_data, shared_bins + + def compute_max_tick_label_length(data_series): + """Return the maximum character length across candidate tick labels.""" + if isinstance(data_series.dtype, pd.CategoricalDtype): + tick_labels = [ + str(label) for label in data_series.cat.categories + ] + else: + tick_labels = [ + str(label) for label in data_series.dropna().unique().tolist() + ] + return max((len(label) for label in tick_labels), default=0) + + def resolve_hist_axis_labels(data_column, x_log_scale, y_log_scale, stat): + """Return histogram axis labels for the current scale/stat settings.""" + xlabel = f'log({data_column})' if x_log_scale else data_column + ylabel_map = { + 'count': 'Count', + 'frequency': 'Frequency', + 'density': 'Density', + 'probability': 'Probability', + "proportion": "Proportion", + "percent": "Percent" + } + ylabel = ylabel_map.get(stat, 'Count') + if y_log_scale: + ylabel = f'log({ylabel})' + return xlabel, ylabel + + # Dispatch to grouped-together, grouped-separate, faceted, or + # ungrouped plotting. if group_by: groups = df[group_by].dropna().unique().tolist() n_groups = len(groups) + if n_groups == 0: raise ValueError("There must be at least one group to create a" " histogram.") + elif n_groups > max_groups: + raise ValueError( + "The number of groups in `group_by` exceeds `max_groups`: " + f"found {n_groups}, threshold {max_groups}.\n" + "Please reduce/bucket groups or use another grouping column, or " + "pass a larger `max_groups`.\n" + "See `kwargs` documentation for more details." + ) if together: - # Compute global bin edges based on the entire dataset + # 1) Grouped together on the same axes + if ax is None: + fig, ax = plt.subplots() + + # Compute histogram data and shared bins for consistent plotting. + # For non-numeric data, shared_bins will be dropped intentionally. + hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), + ) if pd.api.types.is_numeric_dtype(plot_data[data_column]): - global_bin_edges = np.histogram_bin_edges( - plot_data[data_column], bins=kwargs['bins'] - ) - else: - global_bin_edges = plot_data[data_column].unique() - - hist_data = [] - # Compute histograms for each group separately and combine them - for group in groups: - group_data = plot_data[ - plot_data[group_by] == group - ][data_column] - group_hist = calculate_histogram(group_data, kwargs['bins'], - bin_edges=global_bin_edges) - group_hist[group_by] = group - hist_data.append(group_hist) - hist_data = pd.concat(hist_data, ignore_index=True) + kwargs['bins'] = shared_bins.tolist() # Set default values if not provided in kwargs kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") + sns.histplot( + data=hist_data, + x='bin_center', + weights='count', + hue=group_by, + ax=ax, + **kwargs, + ) - sns.histplot(data=hist_data, x='bin_center', weights='count', - hue=group_by, ax=ax, **kwargs) # If plotting feature specify which layer if feature: ax.set_title(f'Layer: {layer}') axs.append(ax) + else: - fig, ax_array = plt.subplots( - n_groups, 1, figsize=(5, 5 * n_groups) - ) + # 'multiple' parameter is not applicable + kwargs.pop('multiple', None) - # Convert a single Axes object to a list - # Ensure ax_array is always iterable - if n_groups == 1: - ax_array = [ax_array] - else: - ax_array = ax_array.flatten() - - for i, ax_i in enumerate(ax_array): - group_data = plot_data[plot_data[group_by] == - groups[i]][data_column] - hist_data = calculate_histogram(group_data, kwargs['bins']) - - sns.histplot(data=hist_data, x="bin_center", ax=ax_i, - weights='count', **kwargs) - # If plotting feature specify which layer - if feature: - ax_i.set_title(f'{groups[i]} with Layer: {layer}') + if not facet: + # 2) Grouped separately on different axes + fig, ax_array = plt.subplots( + n_groups, 1, figsize=(5, 5 * n_groups) + ) + + # Convert a single Axes object to a list + # Ensure ax_array is always iterable + if n_groups == 1: + ax_array = [ax_array] else: - ax_i.set_title(f'{groups[i]}') + ax_array = ax_array.flatten() - # Set axis scales if y_log_scale is True - if y_log_scale: - ax_i.set_yscale('log') + for i, ax_i in enumerate(ax_array): + group_data = plot_data[plot_data[group_by] == + groups[i]][data_column] + hist_data = calculate_histogram(group_data, kwargs['bins']) + + sns.histplot(data=hist_data, x="bin_center", ax=ax_i, + weights='count', **kwargs) + + # If plotting feature specify which layer + if feature: + ax_i.set_title(f'{groups[i]} with Layer: {layer}') + else: + ax_i.set_title(f'{groups[i]}') + axs.append(ax_i) + + else: + # 3) Faceted by group on different axes using seaborn's FacetGrid. + # Compute max label length only when not explicitly provided. + facet_tick_max_chars = 0 + if not pd.api.types.is_numeric_dtype(plot_data[data_column]): + facet_tick_max_chars = compute_max_tick_label_length(plot_data[data_column]) + + # Derive facet geometry based on group count and layout hints + # Returned layout keys: facet_ncol, facet_height, facet_aspect + facet_layout = _derive_facet_geometry( + n_groups=n_groups, + facet_ncol=facet_ncol, + facet_fig_width=facet_fig_width, + facet_fig_height=facet_fig_height, + facet_tick_max_chars=facet_tick_max_chars, + facet_tick_rotation=facet_tick_rotation, + ) + + # Compute histogram data and shared bins for consistent plotting. + # For non-numeric data, shared_bins will be dropped intentionally. + hist_data, shared_bins = build_grouped_histogram_table( + plot_data, + data_column, + group_by, + groups, + bins=kwargs.pop('bins'), + ) + if pd.api.types.is_numeric_dtype(plot_data[data_column]): + kwargs['bins'] = shared_bins.tolist() + + # Create the FacetGrid for the histogram + hist = sns.FacetGrid( + hist_data, + col=group_by, + col_wrap=facet_layout['facet_ncol'], + height=facet_layout['facet_height'], + aspect=facet_layout['facet_aspect'], + sharex=True, + sharey=True, + ) + + # Map the histogram function to the grid + hist.map_dataframe( + sns.histplot, + x='bin_center', + weights='count', + **kwargs, + ) + + # Show tick labels on every facet while keeping shared axes. + for ax_i in hist.axes.flat: + ax_i.tick_params(axis='x', labelbottom=True) + ax_i.tick_params(axis='y', labelleft=True) + + # Set background color and grid for better readability + for ax_i in hist.axes.flat: + ax_i.set_facecolor('#f2f2f2') + ax_i.grid(True, which='major', axis='both') + + # Titles for each facet + hist.set_titles("{col_name}") + + # Adjust margins for readability across layouts. + hist.figure.subplots_adjust(left=.1, + top=0.9, + bottom=0.12, + hspace=0.35, + wspace=0.2) + + # Pass the figure and axes to the output for further customization + fig = hist.figure + fig.set_size_inches( + facet_fig_width or fig.get_figwidth(), + facet_fig_height or fig.get_figheight(), + ) + axs.extend(hist.axes.flat) - # Adjust x-axis label if x_log_scale is True - if x_log_scale: - xlabel = f'log({data_column})' - else: - xlabel = data_column - ax_i.set_xlabel(xlabel) - - # Adjust y-axis label based on 'stat' parameter - stat = kwargs.get('stat', 'count') - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability' - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - ax_i.set_ylabel(ylabel) - - axs.append(ax_i) else: + # 4) Ungrouped histogram (group_by=None) + if ax is None: + fig, ax = plt.subplots() + # Precompute histogram data for single plot hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): ax.set_xlim(hist_data['bin_left'].min(), - hist_data['bin_right'].max()) + hist_data['bin_right'].max()) sns.histplot( data=hist_data, @@ -735,35 +1181,39 @@ def calculate_histogram(data, bins, bin_edges=None): ax.set_title(f'Layer: {layer}') axs.append(ax) - # Set axis scales if y_log_scale is True - if y_log_scale: - ax.set_yscale('log') + # Determine axis labels based on scale and stat settings. + stat = kwargs.get('stat', 'count') + xlabel, ylabel = resolve_hist_axis_labels( + data_column=data_column, + x_log_scale=x_log_scale, + y_log_scale=y_log_scale, + stat=stat, + ) - # Adjust x-axis label if x_log_scale is True - if x_log_scale: - xlabel = f'log({data_column})' - else: - xlabel = data_column - ax.set_xlabel(xlabel) + axes = axs if isinstance(axs, (list, np.ndarray)) else [axs] + for ax in axes: + # Set axis scales if y_log_scale is True + if y_log_scale: + ax.set_yscale('log') + # For faceted plots, we set axis labels at the figure level only. + if facet: + ax.set_xlabel('') + ax.set_ylabel('') + else: + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) - # Adjust y-axis label based on 'stat' parameter - stat = kwargs.get('stat', 'count') - ylabel_map = { - 'count': 'Count', - 'frequency': 'Frequency', - 'density': 'Density', - 'probability': 'Probability' - } - ylabel = ylabel_map.get(stat, 'Count') - if y_log_scale: - ylabel = f'log({ylabel})' - ax.set_ylabel(ylabel) + # Set a common x and y label for the entire figure if facet is True + if facet and fig is not None: + fig.supxlabel(xlabel) + fig.supylabel(ylabel) if len(axs) == 1: return {"fig": fig, "axs": axs[0], "df": hist_data} else: return {"fig": fig, "axs": axs, "df": hist_data} + def heatmap(adata, column, layer=None, **kwargs): """ Plot the heatmap of the mean feature of cells that belong to a `column`. diff --git a/tests/templates/test_histogram_template.py b/tests/templates/test_histogram_template.py index 5a8e49e8..426e386f 100644 --- a/tests/templates/test_histogram_template.py +++ b/tests/templates/test_histogram_template.py @@ -50,6 +50,8 @@ def setUp(self) -> None: params = { "Upstream_Analysis": self.in_file, "Annotation": "cell_type", + "Group_by": "cell_type", + "Together": False, "Table_to_Visualize": "Original", "Feature_s_to_Plot": ["All"], "Figure_Title": "Test Histogram", @@ -59,6 +61,8 @@ def setUp(self) -> None: "Figure_DPI": 72, "Font_Size": 10, "Number_of_Bins": 20, + "Facet": True, + "Facet_Ncol": 1, "Output_Directory": self.tmp_dir.name, "outputs": { "dataframe": {"type": "file", "name": "dataframe.csv"}, diff --git a/tests/test_visualization/test_derive_facet_geometry.py b/tests/test_visualization/test_derive_facet_geometry.py new file mode 100644 index 00000000..17cebfad --- /dev/null +++ b/tests/test_visualization/test_derive_facet_geometry.py @@ -0,0 +1,95 @@ +import unittest + +from spac.visualization import _derive_facet_geometry + + +class TestDeriveFacetGeometry(unittest.TestCase): + def test_minimal_single_group_defaults(self): + """Single-group input should keep one column and default geometry.""" + facet_layout = _derive_facet_geometry( + n_groups=1, + default_height=3.2, + default_aspect=1.25, + ) + + self.assertEqual(facet_layout["facet_ncol"], 1) + self.assertEqual(facet_layout["facet_height"], 3.2) + self.assertEqual(facet_layout["facet_aspect"], 1.25) + + def test_auto_layout_uses_single_column_below_threshold(self): + """Check that auto layout selects 1 column when n_groups is at or below threshold.""" + facet_layout = _derive_facet_geometry( + n_groups=5, + facet_ncol=None, + vertical_threshold=5, + ) + + self.assertEqual(facet_layout["facet_ncol"], 1) + + def test_auto_layout_uses_sqrt_rule_above_threshold(self): + """Auto layout should use sqrt rule when n_groups is above threshold.""" + facet_layout = _derive_facet_geometry( + n_groups=5, + facet_ncol=None, + vertical_threshold=3, + ) + + self.assertEqual(facet_layout["facet_ncol"], 3) + + def test_explicit_column_count_and_figure_size_hints_drive_geometry(self): + """Explicit facet_ncol should be used directly to compute geometry.""" + facet_layout = _derive_facet_geometry( + n_groups=5, + facet_ncol=2, + vertical_threshold=5, + facet_fig_width=11, + facet_fig_height=4, + ) + + self.assertEqual(facet_layout["facet_ncol"], 2) + self.assertAlmostEqual(facet_layout["facet_height"], 1.6) + self.assertAlmostEqual(facet_layout["facet_aspect"], 2.0) + + def test_single_figure_size_hint_falls_back_to_defaults(self): + """A one-sided size hint should not partially derive facet geometry.""" + facet_layout = _derive_facet_geometry( + n_groups=4, + facet_fig_width=11, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, + ) + + self.assertEqual(facet_layout["facet_ncol"], 2) + self.assertEqual(facet_layout["facet_height"], 3.2) + self.assertEqual(facet_layout["facet_aspect"], 1.25) + + def test_none_inputs_fall_back_to_auto_and_default_geometry(self): + """Missing pre-normalized hints should use auto layout and defaults.""" + facet_layout = _derive_facet_geometry( + n_groups=3, + facet_ncol=None, + facet_fig_width=None, + facet_fig_height=None, + vertical_threshold=3, + default_height=3.2, + default_aspect=1.25, + ) + + self.assertEqual(facet_layout["facet_ncol"], 1) + self.assertEqual(facet_layout["facet_height"], 3.2) + self.assertEqual(facet_layout["facet_aspect"], 1.25) + + def test_explicit_column_count_is_clamped_to_group_count(self): + """Explicit facet_ncol should be clamped to n_groups if it exceeds it.""" + facet_layout = _derive_facet_geometry( + n_groups=2, + facet_ncol=10, + vertical_threshold=3, + ) + + self.assertEqual(facet_layout["facet_ncol"], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_visualization/test_histogram.py b/tests/test_visualization/test_histogram.py index f8ba95ea..3a6c36e3 100644 --- a/tests/test_visualization/test_histogram.py +++ b/tests/test_visualization/test_histogram.py @@ -38,6 +38,52 @@ def setUp(self): # Create default layer self.adata.layers['Default'] = X.astype(np.float32) + def tearDown(self): + # Closes all figures to prevent memory issues + plt.close('all') + + def _make_many_groups_adata(self, n_groups=25): + """Create a compact AnnData fixture with one row per unique group.""" + X = np.arange(1, n_groups + 1, dtype=np.float32).reshape(-1, 1) + obs = pd.DataFrame( + {'many_groups': [f'g{i}' for i in range(n_groups)]}, + index=[f'cell_{i}' for i in range(n_groups)], + ) + var = pd.DataFrame(index=['marker1']) + return anndata.AnnData(X, obs=obs, var=var) + + def _make_long_label_facet_adata(self, include_short=False): + """Create small categorical facet fixtures for long-label geometry tests.""" + obs = { + 'annotation2': ['g1', 'g1', 'g1', 'g1', + 'g2', 'g2', 'g2', 'g2', + 'g3', 'g3', 'g3', 'g3'], + } + if include_short: + obs['annotation_short'] = pd.Categorical( + ['A', 'B', 'C', 'D'] * 3, + categories=['A', 'B', 'C', 'D'], + ) + obs['annotation_long'] = pd.Categorical( + [ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ] * 3, + categories=[ + 'Activated T/B Cell', + 'Cytotoxic T Cell', + 'Follicular Dendritic Cell', + 'Regulatory T Cell', + ], + ) + return anndata.AnnData( + np.arange(1, 13, dtype=np.float32).reshape(-1, 1), + obs=pd.DataFrame(obs, index=[f'cell_{i}' for i in range(12)]), + var=pd.DataFrame(index=['marker1']), + ) + def test_both_feature_and_annotation(self): err_msg = ("Cannot pass both feature and annotation," " choose one.") @@ -53,8 +99,8 @@ def test_histogram_feature(self): bin_edges = list(np.linspace(0.5, 100.5, 101)) fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', bins=bin_edges ).values() @@ -79,7 +125,7 @@ def test_histogram_feature(self): def test_histogram_annotation(self): fig, ax, df = histogram( - self.adata, + self.adata, annotation='annotation1' ).values() total_annotation = len(self.adata.obs['annotation1']) @@ -147,8 +193,8 @@ def test_x_log_scale_transformation(self): ].flatten() fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', x_log_scale=True ).values() @@ -174,8 +220,8 @@ def test_negative_values_x_log_scale(self, mock_print): self.adata.X[0, self.adata.var_names.get_loc('marker1')] = -1 fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', x_log_scale=True ).values() @@ -192,21 +238,21 @@ def test_negative_values_x_log_scale(self, mock_print): def test_title(self): """Test that title changes based on 'layer' information""" fig, ax, df = histogram( - self.adata, + self.adata, feature='marker1' ).values() self.assertEqual(ax.get_title(), 'Layer: Original') fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', layer='Default' ).values() self.assertEqual(ax.get_title(), f'Layer: Default') fig, ax, df = histogram( - self.adata, - annotation='annotation1', + self.adata, + annotation='annotation1', layer='Default' ).values() self.assertEqual(ax.get_title(), '') @@ -214,17 +260,17 @@ def test_title(self): def test_y_log_scale_axis(self): """Test that y_log_scale sets y-axis to log scale.""" fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', y_log_scale=True ).values() self.assertEqual(ax.get_yscale(), 'log') def test_y_log_scale_label(self): """Test that y-axis label is updated when y_log_scale is True.""" - fig, ax, dfd = histogram( - self.adata, - feature='marker1', + fig, ax, df = histogram( + self.adata, + feature='marker1', y_log_scale=True ).values() self.assertEqual(ax.get_ylabel(), 'log(Count)') @@ -233,31 +279,31 @@ def test_y_axis_label_based_on_stat(self): """Test that y-axis label changes based on the 'stat' parameter.""" # Test default stat ('count') fig, ax, df = histogram( - self.adata, + self.adata, feature='marker1' ).values() self.assertEqual(ax.get_ylabel(), 'Count') # Test 'frequency' stat fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', stat='frequency' ).values() self.assertEqual(ax.get_ylabel(), 'Frequency') # Test 'density' stat fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', stat='density' ).values() self.assertEqual(ax.get_ylabel(), 'Density') # Test 'probability' stat fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', stat='probability' ).values() self.assertEqual(ax.get_ylabel(), 'Probability') @@ -298,6 +344,96 @@ def test_group_by_separate_with_y_log_scale(self): self.assertEqual(ax.get_yscale(), 'log') self.assertEqual(ax.get_ylabel(), 'log(Count)') + def test_group_by_max_groups_default_guardrail_rejects_excess_groups(self): + """Default threshold should reject excessive grouped facet plotting.""" + adata = self._make_many_groups_adata(n_groups=25) + with self.assertRaisesRegex(ValueError, "exceeds `max_groups`"): + histogram( + adata, + feature='marker1', + group_by='many_groups', + facet=True, + ) + + def test_group_by_max_groups_override_allows_grouped_plot(self): + """Custom positive max_groups should allow larger grouped plots.""" + n_groups = 25 + adata = self._make_many_groups_adata(n_groups=n_groups) + + fig, axs, _ = histogram( + adata, + feature='marker1', + group_by='many_groups', + facet=True, + max_groups=30, + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + self.assertEqual(len(axs), n_groups) + + def test_group_by_max_groups_unlimited_disables_guardrail(self): + """max_groups='unlimited' should disable grouped guardrail validation.""" + adata = self._make_many_groups_adata(n_groups=25) + + fig, ax, _ = histogram( + adata, + feature='marker1', + group_by='many_groups', + together=True, + max_groups='unlimited', + ).values() + self.assertIsNotNone(fig) + self.assertIsInstance(ax, mpl.axes.Axes) + + def test_group_by_max_groups_none_uses_default_threshold(self): + """Explicit None should resolve to default threshold behavior.""" + adata = self._make_many_groups_adata(n_groups=25) + with self.assertRaisesRegex(ValueError, "exceeds `max_groups`"): + histogram( + adata, + feature='marker1', + group_by='many_groups', + together=True, + max_groups=None, + ) + + def test_group_by_invalid_max_groups_raises_value_error(self): + """Invalid max_groups values should fail fast.""" + for value in [0, "bad", True]: + with self.subTest(max_groups=value): + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=True, + max_groups=value, + ) + + def test_non_grouped_max_groups_is_ignored(self): + """Non-grouped calls should ignore grouped-only max_groups hints.""" + baseline_fig, baseline_ax, _ = histogram( + self.adata, + feature='marker1', + ).values() + + fig, ax, _ = histogram( + self.adata, + feature='marker1', + max_groups=0, + ).values() + self.assertAlmostEqual( + fig.get_figwidth(), + baseline_fig.get_figwidth(), + places=6, + ) + self.assertAlmostEqual( + fig.get_figheight(), + baseline_fig.get_figheight(), + places=6, + ) + self.assertGreater(len(ax.patches), 0) + self.assertEqual(len(ax.patches), len(baseline_ax.patches)) + def test_overlay_options(self): fig, ax, df = histogram( self.adata, @@ -353,6 +489,7 @@ def test_layer(self): ) def test_ax_passed_as_argument(self): + # Supported mode 1: single-axes histogram with external ax. fig, ax = plt.subplots() returned_fig, returned_ax, df = histogram( self.adata, @@ -365,8 +502,53 @@ def test_ax_passed_as_argument(self): # Check that the passed fig is the one that is returned self.assertIs(fig, returned_fig) - # Check that returned_ax is an Axes object + # Supported mode 2: grouped+together histogram with external ax. + fig_grouped, ax_grouped = plt.subplots() + returned_grouped_fig, returned_grouped_ax, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=True, + ax=ax_grouped, + ).values() + + self.assertIs(fig_grouped, returned_grouped_fig) + self.assertIs(ax_grouped, returned_grouped_ax) + + # Check that returned axes are valid Axes objects. self.assertIsInstance(returned_ax, mpl.axes.Axes) + self.assertIsInstance(returned_grouped_ax, mpl.axes.Axes) + + def test_external_ax_guardrail_modes(self): + # Reject grouped-separate mode with external ax. + fig_1, ax_1 = plt.subplots() + with self.assertRaisesRegex( + ValueError, + "External ax is only supported for single-axes histogram" + ): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=False, + ax=ax_1, + ) + + # Reject facet mode with external ax. + fig_2, ax_2 = plt.subplots() + with self.assertRaisesRegex( + ValueError, + "External ax is only supported for single-axes histogram" + ): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ax=ax_2, + ) + + # Positive external-ax modes are covered in test_ax_passed_as_argument. def test_default_first_feature(self): with self.assertWarns(UserWarning) as warning: @@ -388,8 +570,8 @@ def test_histogram_feature_integer_bins(self): custom_bins = 10 # Specify number of bins as an integer fig, ax, df = histogram( - self.adata, - feature='marker1', + self.adata, + feature='marker1', bins=custom_bins ).values() @@ -401,17 +583,592 @@ def test_histogram_feature_integer_bins(self): self.assertIsInstance(ax, mpl.axes.Axes) def test_default_bins_calculation(self): + """No bins argument should use Rice-rule fallback.""" + expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) + # No bins parameter passed fig, ax, df = histogram(self.adata, feature='marker1').values() + self.assertEqual(len(ax.patches), expected_bins) + self.assertEqual(len(df), expected_bins) + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center'}, + ) - # Count the number of bins - bars = ax.patches - n_bins = len(bars) - - # Validate the number of bins based on default bin calculation logic - # Using 2 * (n ** 1/3) heuristic for default bins + def test_default_like_bins_calculation(self): + """Default-like bins values should use Rice-rule fallback.""" expected_bins = max(int(2 * (self.adata.shape[0] ** (1 / 3))), 1) - self.assertEqual(n_bins, expected_bins) + + for bins_value in [None, 'auto', 'none', '']: + with self.subTest(bins=bins_value): + fig, ax, df = histogram( + self.adata, + feature='marker1', + bins=bins_value, + ).values() + + self.assertEqual(len(ax.patches), expected_bins) + self.assertEqual(len(df), expected_bins) + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center'}, + ) + + def test_grouped_separate_ignores_multiple(self): + """Grouped separate mode should ignore irrelevant multiple settings.""" + fig, axs, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=False, + multiple="fill", + ).values() + self.assertIsInstance(fig, mpl.figure.Figure) + self.assertIsInstance(axs, list) + self.assertEqual( + len(axs), + self.adata.obs["annotation2"].dropna().nunique(), + ) + + def test_facet_requires_group_by(self): + """Test that facet mode requires group_by parameter""" + with self.assertRaisesRegex( + ValueError, + "group_by must be specified when facet=True." + ): + histogram( + self.adata, + feature='marker1', + facet=True, + ) + + def test_facet_conflicts_with_together_true(self): + """Test that facet mode conflicts with together=True""" + with self.assertRaisesRegex( + ValueError, + "Cannot use together=True with facet=True" + ): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + together=True, + facet=True, + ) + + def test_facet_plot_smoke_and_structure(self): + """Facet path returns expected structure and plotted content.""" + fig, ax, df = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ).values() + + # Basic structure checks + self.assertIsNotNone(fig) + self.assertIsNotNone(df) + self.assertIsInstance(ax, (list, np.ndarray), + "Facet output should be a multi-axis collection.") + + # Check the number of facet axes matches group count + unique_groups = self.adata.obs['annotation2'].dropna().unique() + self.assertEqual(len(ax), len(unique_groups), + f"Expected {len(unique_groups)}" + f" facet plots, got {len(ax)}.") + + # Lightweight bar-level presence checks only. + for i, axis in enumerate(ax): + self.assertGreater( + len(axis.patches), + 0, + f"Facet {i} should contain at least one bar patch." + ) + + def test_facet_plot_titles_and_label_policy(self): + """Facet titles map to groups and labels follow figure-level policy.""" + fig, ax, df = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + ).values() + + # Ensure ax is iterable for consistent handling + ax = ax if isinstance(ax, (list, np.ndarray)) else [ax] + unique_groups = self.adata.obs['annotation2'].dropna().unique() + + # Titles must map to expected groups and labels must be per-figure. + for i, axis in enumerate(ax): + title = axis.get_title() + self.assertTrue(title, f"Facet {i} is missing a title.") + self.assertTrue(any(str(group) in title + for group in unique_groups), + f"Title '{title}' does not contain" + f"any expected group names.") + self.assertEqual(axis.get_xlabel(), '', + f"Facet {i} x-label should be empty.") + self.assertEqual(axis.get_ylabel(), '', + f"Facet {i} y-label should be empty.") + + # Figure-level labels should be set in facet mode. + self.assertIsNotNone(fig._supxlabel) + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supxlabel.get_text(), 'marker1') + self.assertEqual(fig._supylabel.get_text(), 'Count') + + def test_facet_plot_density_stat_label_policy(self): + """Facet figure-level y label reflects non-default stat mapping.""" + fig, ax, df = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + stat='density', + ).values() + + # Check that figure-level y label reflects 'density' stat when specified. + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supylabel.get_text(), 'Density') + + def test_facet_plot_categorical_annotation(self): + """Test facet mode with categorical annotations""" + fig, axs, df = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + ).values() + + # Ensure axs is iterable for consistent handling + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + expected_groups = self.adata.obs['annotation2'].dropna().nunique() + self.assertEqual(len(axs), expected_groups) + + # Check that data is plotted in each facet + for axis in axs: + self.assertGreater(len(axis.patches), 0) + + # Check figure-level labels are set appropriately + self.assertIsNotNone(fig._supxlabel) + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supxlabel.get_text(), 'annotation1') + self.assertEqual(fig._supylabel.get_text(), 'Count') + + def test_facet_plot_numeric_annotation(self): + """Facet mode should support numeric annotations sourced from obs.""" + adata = self.adata.copy() + adata.obs['annotation_numeric'] = np.arange( + adata.n_obs, + dtype=np.float32, + ) + + fig, axs, _ = histogram( + adata, + annotation='annotation_numeric', + group_by='annotation2', + facet=True, + ).values() + + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + expected_groups = adata.obs['annotation2'].dropna().nunique() + self.assertIsNotNone(fig) + self.assertEqual(len(axs), expected_groups) + + for axis in axs: + self.assertGreater(len(axis.patches), 0) + + self.assertIsNotNone(fig._supxlabel) + self.assertIsNotNone(fig._supylabel) + self.assertEqual(fig._supxlabel.get_text(), 'annotation_numeric') + self.assertEqual(fig._supylabel.get_text(), 'Count') + + def test_facet_ncol_layout_hints(self): + """Facet ncol supports positive int and documented auto behavior.""" + # Explicit two-column layout should create two facet columns. + fig, axs, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol=2, + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + x_positions = {round(axis.get_position().x0, 4) for axis in axs} + self.assertGreaterEqual(len(x_positions), 2) + + # Documented default-like input should use auto layout (one column for 3 groups). + fig, axs, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol='auto', + ).values() + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + x_positions = {round(axis.get_position().x0, 4) for axis in axs} + self.assertEqual(len(x_positions), 1) + + # Invalid values should fail fast. + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol='bad', + ) + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_ncol=0, + ) + + def test_facet_figure_size_hints(self): + """Facet figure-size hints should accept valid values and sanitize invalid ones.""" + # Check that valid figure size hints are applied to the facet figure. + fig, _, _ = histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=11, + facet_fig_height=3.5, + ).values() + self.assertAlmostEqual(fig.get_figwidth(), 11.0, places=2) + self.assertAlmostEqual(fig.get_figheight(), 3.5, places=2) + + # Invalid hints should fail fast. + for width, height in [('wide', 'tall'), (-1, 0)]: + with self.subTest(facet_fig_width=width, facet_fig_height=height): + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=width, + facet_fig_height=height, + ) + + def test_facet_figure_size_hints_require_pair(self): + """One-sided facet figure-size hints should raise a ValueError.""" + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_width=11, + ) + with self.assertRaises(ValueError): + histogram( + self.adata, + feature='marker1', + group_by='annotation2', + facet=True, + facet_fig_height=3.5, + ) + + def test_facet_tick_rotation_zero_matches_default_behavior(self): + """Explicit zero rotation should match omitted rotation behavior.""" + fig_default, _, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + ).values() + fig_zero, _, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + facet_tick_rotation=0, + ).values() + + self.assertAlmostEqual(fig_default.get_figwidth(), fig_zero.get_figwidth(), places=6) + self.assertAlmostEqual(fig_default.get_figheight(), fig_zero.get_figheight(), places=6) + + def test_facet_long_label_geometry_adjustment_without_size_hints(self): + """Long rotated categorical labels should increase default facet geometry.""" + adata = self._make_long_label_facet_adata(include_short=True) + + fig_short, _, _ = histogram( + adata, + annotation='annotation_short', + group_by='annotation2', + facet=True, + facet_ncol=2, + facet_tick_rotation=45, + ).values() + fig_long, _, _ = histogram( + adata, + annotation='annotation_long', + group_by='annotation2', + facet=True, + facet_ncol=2, + facet_tick_rotation=45, + ).values() + + self.assertGreater(fig_long.get_figwidth(), fig_short.get_figwidth()) + self.assertGreater(fig_long.get_figheight(), fig_short.get_figheight()) + + def test_facet_long_label_geometry_respects_explicit_size_hints(self): + """Explicit facet figure-size hints should remain authoritative.""" + adata = self._make_long_label_facet_adata() + + fig, _, _ = histogram( + adata, + annotation='annotation_long', + group_by='annotation2', + facet=True, + facet_ncol=2, + facet_tick_rotation=60, + facet_fig_width=10, + facet_fig_height=4, + ).values() + + self.assertAlmostEqual(fig.get_figwidth(), 10.0, places=2) + self.assertAlmostEqual(fig.get_figheight(), 4.0, places=2) + + def test_non_facet_layout_hints_are_ignored(self): + """Non-facet calls should ignore all facet-only layout hints.""" + baseline_fig, baseline_ax, _ = histogram( + self.adata, + feature='marker1', + facet=False, + ).values() + + for hint_kwargs in ( + {'facet_fig_width': 8, 'facet_fig_height': 5}, + { + 'facet_ncol': 0, + 'facet_tick_rotation': 'bad', + 'facet_fig_width': 'wide', + 'facet_fig_height': 'tall', + }, + ): + with self.subTest(hints=hint_kwargs): + fig, ax, _ = histogram( + self.adata, + feature='marker1', + facet=False, + **hint_kwargs, + ).values() + self.assertAlmostEqual( + fig.get_figwidth(), + baseline_fig.get_figwidth(), + places=6, + ) + self.assertAlmostEqual( + fig.get_figheight(), + baseline_fig.get_figheight(), + places=6, + ) + self.assertGreater(len(ax.patches), 0) + self.assertEqual(len(ax.patches), len(baseline_ax.patches)) + + def test_facet_plot_shared_bins_consistency_numeric(self): + """Numeric facets keep shared bins for int/default-like bins inputs.""" + # Unbalanced groups: each group occupies only part of the global range. + # If bins are computed per-group (bad path), centers/ticks may diverge. + adata = anndata.AnnData( + np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]], + dtype=np.float32), + obs=pd.DataFrame( + {'annotation2': ['g1', 'g1', 'g1', 'g2', 'g2', 'g2']}, + index=[f'cell_{i}' for i in range(6)], + ), + var=pd.DataFrame(index=['marker1']), + ) + + # Test one explicit and one default-like bins path. + for bins_value in [4, None]: + with self.subTest(bins=bins_value): + fig, axs, df = histogram( + adata, + feature='marker1', + group_by='annotation2', + facet=True, + bins=bins_value, + ).values() + + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + + # Capture the bin centers and ticks from the first facet + first_xlim = np.round(np.array(axs[0].get_xlim()), 6) + first_xticks = np.round(np.array(axs[0].get_xticks()), 6) + first_yticks = np.round(np.array(axs[0].get_yticks()), 6) + first_centers = np.round( + np.array([ + patch.get_x() + patch.get_width() / 2 + for patch in axs[0].patches + ]), + 6 + ) + + # Check that all facets have the same bin centers and ticks + for axis in axs[1:]: + centers = np.round( + np.array([ + patch.get_x() + patch.get_width() / 2 + for patch in axis.patches + ]), + 6 + ) + self.assertTrue( + np.array_equal(centers, first_centers), + "Facet numeric bin centers should remain shared across panels." + ) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_xlim()), 6), first_xlim), + "Facet numeric x-limits should remain shared across panels." + ) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_xticks()), 6), first_xticks), + "Facet numeric x-ticks should remain shared across panels." + ) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), + "Facet numeric y-ticks should remain shared across panels." + ) + + # Check that the returned DataFrame has expected structure and content + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center', 'annotation2'}, + ) + self.assertNotIn('marker1', df.columns) + self.assertEqual(set(df['annotation2']), {'g1', 'g2'}) + self.assertEqual(df['count'].sum(), adata.n_obs) + grouped_edges = [ + ( + np.round(group_df['bin_left'].to_numpy(), 6), + np.round(group_df['bin_right'].to_numpy(), 6), + ) + for _, group_df in df.groupby('annotation2') + ] + self.assertEqual(len(grouped_edges), 2) + self.assertTrue(np.array_equal(grouped_edges[0][0], grouped_edges[1][0])) + self.assertTrue(np.array_equal(grouped_edges[0][1], grouped_edges[1][1])) + + def test_facet_plot_shared_bins_consistency_categorical(self): + """Facet categorical bins stay aligned even with missing labels.""" + # Build unbalanced facet groups where some labels are missing per group. + adata = anndata.AnnData( + np.arange(1, 10, dtype=np.float32).reshape(-1, 1), + obs=pd.DataFrame( + { + 'annotation1': pd.Categorical( + ['A', 'A', 'B', 'A', 'C', 'C', 'B', 'C', 'A'], + categories=['A', 'B', 'C'], + ), + 'annotation2': ['g1', 'g1', 'g1', + 'g2', 'g2', 'g2', + 'g3', 'g3', 'g3'], + }, + index=[f'cell_{i}' for i in range(9)], + ), + var=pd.DataFrame(index=['marker1']), + ) + + fig, axs, df = histogram( + adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + ).values() + + axs = axs if isinstance(axs, (list, np.ndarray)) else [axs] + + # Guardrail: this fixture must include missing labels per group. + group_uniques = adata.obs.groupby('annotation2')['annotation1'].nunique() + self.assertTrue(any(group_uniques < 3)) + + # Check that bin centers are shared across facets + global_centers = set() + for axis in axs: + global_centers.update( + np.round( + [patch.get_x() + patch.get_width() / 2 + for patch in axis.patches], + 6, + ) + ) + self.assertEqual( + len(global_centers), + 3, + "Expected 3 categorical slots (A/B/C) to be preserved globally." + ) + + # Check that ticks are shared across facets + first_xticks = np.round(axs[0].get_xticks(), 6) + first_yticks = np.round(np.array(axs[0].get_yticks()), 6) + for axis in axs[1:]: + self.assertTrue(np.array_equal(np.round(axis.get_xticks(), 6), first_xticks)) + self.assertTrue( + np.array_equal(np.round(np.array(axis.get_yticks()), 6), first_yticks), + "Facet categorical y-ticks should remain shared across panels." + ) + + # Check that the returned DataFrame has expected structure and content + self.assertEqual( + set(df.columns), + {'count', 'bin_left', 'bin_right', 'bin_center', 'annotation2'}, + ) + self.assertNotIn('annotation1', df.columns) + self.assertEqual(set(df['annotation2']), {'g1', 'g2', 'g3'}) + self.assertEqual(df['count'].sum(), adata.n_obs) + self.assertEqual( + {str(value) for value in df['bin_center'].unique()}, + {'A', 'B', 'C'}, + ) + for _, group_df in df.groupby('annotation2'): + self.assertEqual( + {str(value) for value in group_df['bin_center']}, + {'A', 'B', 'C'}, + ) + + def test_facet_plot_categorical_annotation_ignores_bins(self): + """Facet categorical annotations should ignore custom bins values.""" + fig_small, axs_small, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + bins=2, + ).values() + fig_large, axs_large, _ = histogram( + self.adata, + annotation='annotation1', + group_by='annotation2', + facet=True, + bins=99, + ).values() + + axs_small = axs_small if isinstance(axs_small, (list, np.ndarray)) else [axs_small] + axs_large = axs_large if isinstance(axs_large, (list, np.ndarray)) else [axs_large] + + self.assertEqual(len(axs_small), len(axs_large)) + + axis_small = axs_small[0] + axis_large = axs_large[0] + small_centers = [ + patch.get_x() + patch.get_width() / 2 + for patch in axis_small.patches + ] + large_centers = [ + patch.get_x() + patch.get_width() / 2 + for patch in axis_large.patches + ] + self.assertEqual(small_centers, large_centers) + self.assertEqual( + [tick.get_text() for tick in axis_small.get_xticklabels()], + [tick.get_text() for tick in axis_large.get_xticklabels()], + ) if __name__ == '__main__':