class TestDownstreamClassConnectivityHierarchyRollup:
    """Regression tests for rolling connectivity up the partner-side class
    hierarchy: an edge onto a child class must also be credited to every
    ancestor class inside the Neuron subtree, using set-union semantics so
    FBbt multi-inheritance cannot double-count.
    """

    @pytest.fixture(scope='class')
    def result(self):
        # One shared, cache-bypassing query per test class keeps things fast.
        return get_downstream_class_connectivity(
            TEST_CLASS, return_dataframe=False, force_refresh=True,
        )

    @pytest.mark.integration
    def test_parent_class_appears_with_sensible_counts(self, result):
        """Set-union semantics bound a parent row's connected_n between the
        max and the sum of its descendant rows' connected_n.
        """
        from vfbquery.vfb_queries import vc, get_dict_cursor

        rows = result["rows"]
        ids = [r["id"] for r in rows]
        assert ids, "Expected at least one row to test against"

        # Locate one (parent, child) pair among the returned row ids.
        pair_query = (
            f"MATCH (p:Class)<-[:SUBCLASSOF*1..]-(c:Class) "
            f"WHERE p.short_form IN {ids} AND c.short_form IN {ids} "
            f"RETURN p.short_form AS parent, c.short_form AS child LIMIT 1"
        )
        pair_rows = get_dict_cursor()(vc.nc.commit_list([pair_query]))
        if not pair_rows:
            pytest.skip("No parent/child pair among result rows for this class")

        parent_id = pair_rows[0]["parent"]
        child_id = pair_rows[0]["child"]
        parent_row = next(r for r in rows if r["id"] == parent_id)

        # Compare against ALL descendant rows of the parent, not just the
        # single child found above.
        descendant_query = (
            f"MATCH (p:Class {{short_form: '{parent_id}'}})"
            f"<-[:SUBCLASSOF*1..]-(c:Class) "
            f"WHERE c.short_form IN {ids} "
            f"RETURN collect(DISTINCT c.short_form) AS descs"
        )
        descendant_result = get_dict_cursor()(vc.nc.commit_list([descendant_query]))
        descendant_ids = descendant_result[0]["descs"] if descendant_result else [child_id]
        descendant_rows = [r for r in rows if r["id"] in descendant_ids]
        max_child = max(r["connected_n"] for r in descendant_rows)
        sum_child = sum(r["connected_n"] for r in descendant_rows)
        assert parent_row["connected_n"] >= max_child, (
            f"Parent {parent_id} connected_n={parent_row['connected_n']} should "
            f"be >= max descendant connected_n={max_child}"
        )
        assert parent_row["connected_n"] <= sum_child, (
            f"Parent {parent_id} connected_n={parent_row['connected_n']} should "
            f"be <= sum of descendant connected_n={sum_child}"
        )

    @pytest.mark.integration
    def test_total_n_is_constant_across_rows(self, result):
        """``total_n`` is the queried-side instance count, so every row must
        carry the same positive value (regression for the old behaviour that
        summed it across subclasses).
        """
        rows = result["rows"]
        assert rows, "Expected at least one row"
        distinct_totals = {r["total_n"] for r in rows}
        assert len(distinct_totals) == 1, (
            f"Expected total_n to be constant across rows, got: {distinct_totals}"
        )
        assert next(iter(distinct_totals)) > 0

    @pytest.mark.integration
    def test_no_rows_above_neuron_root(self, result):
        """Every returned class id must lie inside the Neuron subtree: the
        partner-side ancestor walk is bounded by FBbt_00005106 by design.
        """
        from vfbquery.vfb_queries import vc, get_dict_cursor, NEURON_ROOT_SHORT_FORM

        ids = [r["id"] for r in result["rows"]]
        assert ids, "Expected at least one row"
        subtree_query = (
            f"MATCH (root:Class {{short_form: '{NEURON_ROOT_SHORT_FORM}'}})"
            f"<-[:SUBCLASSOF*0..]-(c:Class) "
            f"WHERE c.short_form IN {ids} "
            f"RETURN collect(DISTINCT c.short_form) AS in_neuron"
        )
        subtree_rows = get_dict_cursor()(vc.nc.commit_list([subtree_query]))
        in_neuron = set(subtree_rows[0]["in_neuron"]) if subtree_rows else set()
        offenders = [i for i in ids if i not in in_neuron]
        assert not offenders, (
            f"Found {len(offenders)} row(s) outside the Neuron subtree: "
            f"{offenders[:5]}"
        )
class TestUpstreamClassConnectivityHierarchyRollup:
    """Regression tests for the partner-side hierarchy rollup behaviour:
    connections from a child class also count toward each ancestor class
    within the Neuron subtree, without double-counting under FBbt
    multi-inheritance.
    """

    @pytest.fixture(scope='class')
    def result(self):
        return get_upstream_class_connectivity(
            TEST_CLASS, return_dataframe=False, force_refresh=True,
        )

    @pytest.mark.integration
    def test_parent_class_appears_with_sensible_counts(self, result):
        """A row keyed on a parent class should have connected_n at least as
        large as any of its descendant rows (set-union semantics) and at most
        the sum of descendant connected_n.
        """
        from vfbquery.vfb_queries import vc, get_dict_cursor

        rows = result["rows"]
        ids = [r["id"] for r in rows]
        assert ids, "Expected at least one row to test against"

        # Find any (parent, child) pair among the row ids.
        q = (
            "MATCH (p:Class)<-[:SUBCLASSOF*1..]-(c:Class) "
            "WHERE p.short_form IN %s AND c.short_form IN %s "
            "RETURN p.short_form AS parent, c.short_form AS child LIMIT 1"
            % (ids, ids)
        )
        pairs = get_dict_cursor()(vc.nc.commit_list([q]))
        if not pairs:
            pytest.skip("No parent/child pair among result rows for this class")

        parent_id = pairs[0]["parent"]
        # BUGFIX: bind the child id so it can seed the fallback below. The
        # previous version fell back to [] when the descendants query came
        # back empty, which made max() raise ValueError instead of testing
        # anything. The downstream twin test already used [child_id].
        child_id = pairs[0]["child"]
        parent_row = next(r for r in rows if r["id"] == parent_id)
        # Sum connected_n across all descendant rows (not just the one returned).
        desc_q = (
            "MATCH (p:Class {short_form: '%s'})<-[:SUBCLASSOF*1..]-(c:Class) "
            "WHERE c.short_form IN %s "
            "RETURN collect(DISTINCT c.short_form) AS descs"
            % (parent_id, ids)
        )
        desc_rows = get_dict_cursor()(vc.nc.commit_list([desc_q]))
        # `or [child_id]` also guards the case where the query succeeds but
        # collects nothing; child_id is guaranteed to be among the row ids.
        descendant_ids = (desc_rows[0]["descs"] if desc_rows else []) or [child_id]
        descendant_rows = [r for r in rows if r["id"] in descendant_ids]
        max_child = max(r["connected_n"] for r in descendant_rows)
        sum_child = sum(r["connected_n"] for r in descendant_rows)
        assert parent_row["connected_n"] >= max_child, (
            f"Parent {parent_id} connected_n={parent_row['connected_n']} should "
            f"be >= max descendant connected_n={max_child}"
        )
        assert parent_row["connected_n"] <= sum_child, (
            f"Parent {parent_id} connected_n={parent_row['connected_n']} should "
            f"be <= sum of descendant connected_n={sum_child}"
        )

    @pytest.mark.integration
    def test_total_n_is_constant_across_rows(self, result):
        """`total_n` is the queried-side instance count and must be the same
        for every output row.
        """
        rows = result["rows"]
        assert rows, "Expected at least one row"
        total_ns = {r["total_n"] for r in rows}
        assert len(total_ns) == 1, (
            f"Expected total_n to be constant across rows, got: {total_ns}"
        )
        assert next(iter(total_ns)) > 0

    @pytest.mark.integration
    def test_no_rows_above_neuron_root(self, result):
        """The partner-side ancestor walk should stop at the Neuron class
        (FBbt_00005106). No row id should be a class outside the Neuron
        subtree.
        """
        from vfbquery.vfb_queries import vc, get_dict_cursor, NEURON_ROOT_SHORT_FORM

        ids = [r["id"] for r in result["rows"]]
        assert ids, "Expected at least one row"
        q = (
            "MATCH (root:Class {short_form: '%s'})<-[:SUBCLASSOF*0..]-(c:Class) "
            "WHERE c.short_form IN %s "
            "RETURN collect(DISTINCT c.short_form) AS in_neuron"
            % (NEURON_ROOT_SHORT_FORM, ids)
        )
        result_rows = get_dict_cursor()(vc.nc.commit_list([q]))
        in_neuron = set(result_rows[0]["in_neuron"]) if result_rows else set()
        offenders = [i for i in ids if i not in in_neuron]
        assert not offenders, (
            f"Found {len(offenders)} row(s) outside the Neuron subtree: "
            f"{offenders[:5]}"
        )
def _num(v):
    """Coerce ``v`` to a float, returning 0 for anything unparseable."""
    try:
        return float(v)
    except (TypeError, ValueError):
        return 0


# Bound for the partner-side ancestor walk. Only classes at or below this
# root (the Neuron class) may receive aggregated connectivity rows; the
# bound keeps generic anatomy classes out of the output.
NEURON_ROOT_SHORT_FORM = 'FBbt_00005106'


def _get_partner_class_ancestors(direct_partner_ids, neuron_root=NEURON_ROOT_SHORT_FORM):
    """Collect each direct partner class plus all of its SUBCLASSOF
    ancestors that sit at or below ``neuron_root``.

    :param direct_partner_ids: iterable of partner-class short_forms
    :param neuron_root: short_form of the class bounding the upward walk
    :return: ``(class_ids, labels)`` — the in-scope id set and an
        id -> human-readable label map
    """
    if not direct_partner_ids:
        return set(), {}
    seed_list = sorted(direct_partner_ids)
    ancestor_query = (
        f"MATCH (root:Class {{short_form: '{neuron_root}'}})"
        f"<-[:SUBCLASSOF*0..]-(c:Class)<-[:SUBCLASSOF*0..]-(d:Class) "
        f"WHERE d.short_form IN {seed_list} "
        f"RETURN DISTINCT c.short_form AS id, c.label AS label"
    )
    try:
        records = get_dict_cursor()(vc.nc.commit_list([ancestor_query]))
    except Exception as e:
        print(f"Partner class hierarchy query failed: {e}")
        # Degrade gracefully: keep the direct partners (labelled by their own
        # ids) so callers still produce some output.
        return set(direct_partner_ids), {pid: pid for pid in direct_partner_ids}
    labels = {}
    for record in records:
        class_id = record.get('id')
        if class_id:
            labels[class_id] = record.get('label') or class_id
    return set(labels), labels


def _build_partner_instance_class_membership(class_ids):
    """Map each connectivity-bearing partner instance to the set of in-scope
    classes it belongs to, via one batched Cypher query with SUBCLASSOF
    closure.

    An instance typed under several classes appears in every matching set,
    which is exactly what set-union aggregation across hierarchy levels
    needs. Batching avoids one round-trip per class.

    :param class_ids: iterable of partner-class short_forms
    :return: dict of instance short_form -> set of class short_forms
    """
    if not class_ids:
        return {}
    membership_query = (
        "MATCH (c:Class)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-"
        "(n:Individual:has_neuron_connectivity) "
        f"WHERE c.short_form IN {sorted(class_ids)} "
        "RETURN c.short_form AS cid, collect(DISTINCT n.short_form) AS iids"
    )
    try:
        records = get_dict_cursor()(vc.nc.commit_list([membership_query]))
    except Exception as e:
        print(f"Partner class membership query failed: {e}")
        return {}
    membership = {}
    for record in records:
        class_id = record.get('cid')
        for instance_id in record.get('iids') or []:
            membership.setdefault(instance_id, set()).add(class_id)
    return membership
def _bulk_fetch_per_instance_connectivity(instance_ids):
    """Fetch cached ``neuron_neuron_connectivity_query`` results for many
    instances in at most two Solr round-trips.

    Looks up the ``_dataframe_False`` cache variant first (its rows are easy
    to parse), then retries any still-missing instances against
    ``_dataframe_True``.

    :param instance_ids: iterable of instance short_forms
    :return: ``(found, missing)`` — instance_id -> list of partner rows,
        plus the instance ids with no usable cache entry
    """
    if not instance_ids:
        return {}, []
    instance_ids = list(instance_ids)
    prefix = 'vfb_query_neuron_neuron_connectivity_query_'
    hits = {}

    def _rows_from_payload(payload):
        # Cached payloads vary in shape: 'result' may be a JSON string, a
        # dict carrying a 'rows' key, or a bare list of rows.
        outer = json.loads(payload)
        inner = outer.get('result')
        if isinstance(inner, str):
            inner = json.loads(inner)
        if isinstance(inner, dict):
            return inner.get('rows', [])
        if isinstance(inner, list):
            return inner
        return []

    for suffix in ('_dataframe_False', '_dataframe_True'):
        pending = [i for i in instance_ids if i not in hits]
        if not pending:
            break
        cache_ids = [f'{prefix}{i}{suffix}' for i in pending]
        try:
            response = vfb_solr.search(
                q='*:*',
                fq='{!terms f=id}' + ','.join(cache_ids),
                fl='id,cache_data',
                rows=len(cache_ids),
            )
        except Exception as e:
            print(f"Bulk per-instance cache fetch failed ({suffix}): {e}")
            continue
        for doc in response.docs:
            doc_id = doc.get('id')
            payload = doc.get('cache_data')
            # Solr may return multi-valued fields as lists; unwrap.
            if isinstance(payload, list):
                payload = payload[0] if payload else None
            if not doc_id or not payload:
                continue
            if not (doc_id.startswith(prefix) and doc_id.endswith(suffix)):
                continue
            iid = doc_id[len(prefix):-len(suffix)]
            try:
                hits[iid] = _rows_from_payload(payload)
            except Exception as e:
                print(f"Failed to parse cached connectivity for {iid}: {e}")
    missing = [i for i in instance_ids if i not in hits]
    return hits, missing
Queried-side instances (subclass closure via Neo4j — Owlery's + # get_instances has been observed to hang for some classes, while a + # SUBCLASSOF traversal in Cypher is fast and equivalent here). + queried_q = ( + "MATCH (n:Individual:has_neuron_connectivity)-[:INSTANCEOF]->" + "(:Class)-[:SUBCLASSOF*0..]->(:Class {short_form: '%s'}) " + "RETURN DISTINCT n.short_form AS sf" % short_form + ) + try: + results = vc.nc.commit_list([queried_q]) + rows = get_dict_cursor()(results) + queried_instances = [r['sf'] for r in rows if r.get('sf')] + except Exception as e: + print(f"Queried-side instance query failed for {short_form}: {e}") + return [] + if not queried_instances: + return [] + queried_instance_set = set(queried_instances) + total_n_queried = len(queried_instance_set) + + # 2. Per-instance edges from cache. Cache misses are skipped with a warning; + # the resulting connected_n / pairwise / total_weight will be a slight + # underestimate when this happens. + found_edges, missing = _bulk_fetch_per_instance_connectivity(queried_instances) + if missing: + print( + f"Warning: per-instance connectivity cache missing for " + f"{len(missing)}/{total_n_queried} instances of {short_form}; " + f"those will be skipped (results may be a slight underestimate)." + ) + if not found_edges: + return [] + + weight_key = 'outputs' if direction == 'downstream' else 'inputs' + + # 3. Direct partner classes from the existing class-level connectivity + # field (already cached) — used as the seed set for the partner-side + # ancestor walk. 
+ solr_field = ( + 'downstream_connectivity_query' if direction == 'downstream' + else 'upstream_connectivity_query' + ) + class_entries = _fetch_connectivity_entries(short_form, solr_field) + direct_partner_ids = set() + for entry in class_entries: obj = entry.get('object', {}) - pid = obj.get('short_form', cc.get(partner_id_key, '')) - plabel = obj.get('label', cc.get(partner_label_key, '')) - if not pid: + pid = obj.get('short_form') + if pid: + direct_partner_ids.add(pid) + + # 4. Walk SUBCLASSOF up from each direct partner to ``neuron_root``. + partner_class_ids, class_labels = _get_partner_class_ancestors( + direct_partner_ids, neuron_root, + ) + if not partner_class_ids: + return [] + + # 5. Build partner_instance_id -> {class_ids it belongs to}, restricted + # to in-scope partner classes. + instance_to_classes = _build_partner_instance_class_membership(partner_class_ids) + + # 6. Aggregate edges into per-class buckets via set-union semantics. + buckets = defaultdict(lambda: { + 'edges': set(), 'weight_sum': 0.0, 'connected_n1': set(), + }) + for n1, partner_rows in found_edges.items(): + if n1 not in queried_instance_set: continue - if pid not in merged: - merged[pid] = { - 'label': plabel, - 'total_n': 0, - 'connected_n': 0, - 'pairwise_connections': 0, - 'total_weight': 0, - } - m = merged[pid] - m['total_n'] += _num(cc.get('total_upstream_count', 0)) - m['connected_n'] += _num(cc.get('connected_upstream_count', 0)) - m['pairwise_connections'] += _num(cc.get('pairwise_connections', 0)) - m['total_weight'] += _num(cc.get('total_weight', 0)) + for prow in partner_rows or []: + n2 = prow.get('id') + w = prow.get(weight_key) + if not n2 or not w: + continue + try: + w_num = float(w) + except (TypeError, ValueError): + continue + if w_num <= 0: + continue + for c in instance_to_classes.get(n2, ()): + b = buckets[c] + b['edges'].add((n1, n2)) + b['weight_sum'] += w_num + b['connected_n1'].add(n1) + # 7. 
def _format_class_connectivity_rows(rows, partner_key):
    """Turn internal aggregation rows into caller-facing rows.

    Replaces the private ``_label`` field with a markdown-link column named
    ``partner_key`` (e.g. ``'downstream_class'``), falling back to the row
    id when no label is present.

    :param rows: row dicts containing at least 'id', optionally '_label'
    :param partner_key: output column name for the partner link
    :return: a new list of new dicts; the input rows are left untouched
    """
    def _render(row):
        rendered = dict(row)
        label = rendered.pop('_label', rendered['id'])
        rendered[partner_key] = f"[{label}]({rendered['id']})"
        return rendered

    return [_render(row) for row in rows]
+ Uses OWLERY to expand subclasses of the queried class, fetches per-instance + connectivity from the Solr cache, and aggregates by partner class with + set-union semantics on partner instance memberships. The partner-side + hierarchy is walked up to ``NEURON_ROOT_SHORT_FORM`` so that connections to + a child class also count toward each ancestor class's row, without + double-counting under FBbt multi-inheritance. Matching criteria: Class + Neuron @@ -3205,20 +3418,13 @@ def get_downstream_class_connectivity(short_form: str, return_dataframe=True, li :param limit: maximum number of results to return (default -1, returns all results) :return: Downstream partner neuron classes with connectivity statistics """ - entries = _fetch_connectivity_entries(short_form, 'downstream_connectivity_query') - if not entries: + rows = _aggregate_class_connectivity(short_form, 'downstream') + if not rows: if return_dataframe: return pd.DataFrame() return {'headers': {}, 'rows': [], 'count': 0} - rows = _merge_connectivity_rows( - entries, - partner_key='downstream_class', - partner_id_key='downstream_class_id', - partner_label_key='downstream_class', - ) - - # Sort by pairwise_connections descending + rows = _format_class_connectivity_rows(rows, partner_key='downstream_class') rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) total_count = len(rows) @@ -3248,9 +3454,12 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi """ Retrieves upstream connectivity classes for the specified neuron class. - Uses OWLERY to expand subclasses of the queried class, fetches the - upstream_connectivity_query Solr field for each, and merges results - by upstream partner class. + Uses OWLERY to expand subclasses of the queried class, fetches per-instance + connectivity from the Solr cache, and aggregates by partner class with + set-union semantics on partner instance memberships. 
The partner-side + hierarchy is walked up to ``NEURON_ROOT_SHORT_FORM`` so that connections + from a child class also count toward each ancestor class's row, without + double-counting under FBbt multi-inheritance. Matching criteria: Class + Neuron @@ -3259,20 +3468,13 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi :param limit: maximum number of results to return (default -1, returns all results) :return: Upstream partner neuron classes with connectivity statistics """ - entries = _fetch_connectivity_entries(short_form, 'upstream_connectivity_query') - if not entries: + rows = _aggregate_class_connectivity(short_form, 'upstream') + if not rows: if return_dataframe: return pd.DataFrame() return {'headers': {}, 'rows': [], 'count': 0} - rows = _merge_connectivity_rows( - entries, - partner_key='upstream_class', - partner_id_key='upstream_class_id', - partner_label_key='upstream_class', - ) - - # Sort by pairwise_connections descending + rows = _format_class_connectivity_rows(rows, partner_key='upstream_class') rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) total_count = len(rows)