
Reuse minClusterAndDistance Helper for Balanced KMeans#2001

Open
tarang-jain wants to merge 29 commits into rapidsai:main from tarang-jain:hierarchical-helpers

Conversation

Contributor

@tarang-jain tarang-jain commented Apr 8, 2026

  • The norm computation + fused reduction is already present in the minClusterDistanceCompute function. We can reuse that for balanced kmeans.

  • Furthermore, this PR updates the minClusterDistanceCompute function to also use the fused kernel for the cosine metric.

  • Skip redundant centroid-norm computation for cosine k-means: kmeans_balanced::fit now L2-normalizes the returned centroids for CosineExpanded, letting minClusterAndDistanceCompute / minClusterDistanceCompute skip the per-call centroid-norm reduction (fills the norms buffer with 1s instead). Also drops the now-redundant manual row_normalize of cluster centers in ivf_pq_build.

Binary size savings:
conda-cpp-build check (CUDA 12.9.1, x86_64): main 660.72 MB, this PR 648.83 MB
conda-cpp-build check (CUDA 13.1.1, amd): main 305.70 MB, this PR 300.63 MB


copy-pr-bot Bot commented Apr 8, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@aamijar aamijar moved this to In Progress in Unstructured Data Processing Apr 8, 2026
@aamijar aamijar added the non-breaking (Introduces a non-breaking change) and improvement (Improves an existing functionality) labels Apr 8, 2026
@tarang-jain tarang-jain marked this pull request as ready for review April 9, 2026 00:37
@tarang-jain tarang-jain requested a review from a team as a code owner April 9, 2026 00:37
Member

@aamijar aamijar left a comment

Thanks for the refactor Tarang! Could you add a description to the PR?

Contributor

@jinsolp jinsolp left a comment

Thanks @tarang-jain ! Suggesting small changes and adding a question:

Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu Outdated
Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu Outdated
Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu Outdated
Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu
@tarang-jain tarang-jain requested a review from jinsolp April 10, 2026 21:14
Comment on lines 194 to 197
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
handle,
raft::make_device_matrix_view<const DataT, IndexT>(
centroids.data_handle(), centroids.extent(0), centroids.extent(1)),
Contributor

should we be computing norms like this here too?

if (metric == cuvs::distance::DistanceType::CosineExpanded) {
      raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
        handle, centroids, centroidsNorm, raft::sqrt_op{});
    } else {
      raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
        handle, centroids, centroidsNorm);
    }

Contributor Author

Good catch! I have fixed this.

n_clusters,
n_features,
(void*)workspace.data(),
metric != cuvs::distance::DistanceType::L2Expanded,
Contributor

Looking at the deleted code in kmeans_balanced.cuh, this used to be false for the CosineExpanded metric. However, this condition now evaluates to true for CosineExpanded, so the distances output gets sqrt-ed.

Contributor Author

Yes, that is because cosine is now supported in our fused kernel.

Contributor Author

This wasn't being caught because we were lacking tests for the cosine metric. We always test L2 (default). I am trying to add cosine inputs to the tests next.

@achirkin
Contributor

Could you please refresh my understanding on a couple of math questions?

  1. By the time we call minClusterDistanceCompute, the data is normalized, so we can use the same L2 norm everywhere, right?
  2. Why is sqrt needed for the cosine case?

@tarang-jain
Contributor Author

tarang-jain commented Apr 14, 2026

By the time we call minClusterDistanceCompute, the data is normalized, so we can use the same L2 norm everywhere, right?

Cosine is only supported in balanced kmeans. When we call minClusterAndDistance, the data is not normalized, but the centroids are.

Why is sqrt needed for the cosine case?

sqrt is needed because the cosine distance op directly divides by the norms (it does not do the sqrt):

return static_cast<AccT>(1.0) - static_cast<AccT>(accVal / (aNorm * bNorm));

Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu Outdated
Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu Outdated

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough

Summary by CodeRabbit

  • Refactor

    • Improved k-means prediction and distance computation for expanded distance metrics (L2-Expanded, L2-Sqrt-Expanded, Cosine-Expanded) for better performance and consistent results.
    • Simplified center normalization handling in IVF-PQ index building.
  • Tests

    • Added test coverage for k-means balanced clustering using the Cosine-Expanded metric.
  • Chores

    • Removed an unused internal include.

Walkthrough

Refactors fused-distance handling for expanded-L2 and CosineExpanded metrics by centralizing min-cluster-and-distance computation, removes per-metric workspace and centroid normalization in predict paths, eliminates cosine special-case normalization in IVF-PQ build, and adds CosineExpanded test inputs for balanced k-means.

Changes

Cohort / File(s) Summary
K-means predict & fused reducer
cpp/src/cluster/detail/kmeans_balanced.cuh, cpp/src/cluster/detail/minClusterDistanceCompute.cu
Reworked predict logic to stop using per-metric fusedDistanceNNMinReduce setup; preallocate RMM buffers and matrix/vector views and call centralized minClusterAndDistanceCompute. Treat CosineExpanded as fused, compute centroid norms (sqrt for cosine) before fused reduction, and simplify non-fused tiled path.
Header cleanup
cpp/src/cluster/detail/kmeans_common.cuh
Removed an unused include (fused_distance_nn.cuh).
IVF-PQ build
cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
Removed special-case row-normalization of k-means centers for CosineExpanded; centers are passed un-normalized into cuvs::cluster::kmeans::predict.
Tests — kmeans balanced
cpp/tests/cluster/kmeans_balanced.cu
Added get_kmeans_balanced_cosine_inputs<MathT,IdxT>(), introduced inputsf_cosine_i32, and added a KB test instantiation to exercise CosineExpanded distance.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Title check: ✅ Passed. The title 'Reuse minClusterAndDistance Helper for Balanced KMeans' accurately describes the main objective of the PR, which is to refactor balanced KMeans to reuse the minClusterAndDistance helper function.
Description check: ✅ Passed. The description clearly explains the rationale (reusing existing norm computation and fused reduction), the updates made (supporting the cosine metric in minClusterDistanceCompute), and includes quantified impact (binary size savings), directly addressing the changes in the PR.
Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/src/cluster/detail/kmeans.cuh (1)

531-554: ⚠️ Potential issue | 🟠 Major

Preserve the cosine metric when reclustering KMeans|| candidates.

This function now prepares cosine norms and uses params.metric for candidate sampling, but the weighted reclustering step builds default_params and only copies n_clusters. For CosineExpanded with oversampling_factor > 0, that can switch the final candidate reclustering back to the default metric.

🐛 Proposed fix
-    cuvs::cluster::kmeans::params default_params;
+    cuvs::cluster::kmeans::params default_params = params;
     default_params.n_clusters = params.n_clusters;

Also applies to: 660-670

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/kmeans.cuh` around lines 531 - 554, The reclustering
step creates a local default_params and only copies n_clusters, which causes the
metric to fall back to the default (breaking CosineExpanded when
oversampling_factor>0); update the code path that builds default_params (the
params struct used for the weighted reclustering step after
minClusterDistanceCompute and before kmeans::detail::weightedRecluster or the
candidate recluster call) to also copy params.metric (and any metric-related
buffers like L2NormX/L2NormBuf_OR_DistBuf usage flags if present) so the
reclustering preserves cuvs::distance::DistanceType::CosineExpanded behavior.
🧹 Nitpick comments (2)
cpp/src/cluster/detail/kmeans_balanced.cuh (1)

184-191: Include CosineExpanded in the fused-memory minibatch estimate.

predict_core now routes cosine through the fused helper, but calc_minibatch_size still falls through to the distance-matrix estimate for cosine. That can force unnecessarily small minibatches for large n_clusters.

♻️ Proposed update
-    // fusedL2NN needs a mutex and a key-value pair for each row.
+    // Fused NN reduction needs a mutex and a key-value pair for each row.
     case distance::DistanceType::L2Expanded:
-    case distance::DistanceType::L2SqrtExpanded: {
+    case distance::DistanceType::L2SqrtExpanded:
+    case distance::DistanceType::CosineExpanded: {
       mem_per_row += sizeof(int);
       mem_per_row += sizeof(raft::KeyValuePair<IdxT, MathT>);
     } break;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/kmeans_balanced.cuh` around lines 184 - 191, The
minibatch memory estimate in calc_minibatch_size treats cosine as a
distance-matrix path but predict_core uses the fused helper; update the switch
in calc_minibatch_size to include distance::DistanceType::CosineExpanded
alongside DistanceType::L2Expanded and DistanceType::L2SqrtExpanded so
CosineExpanded uses the fused-memory per-row accounting (the same branch that
adds sizeof(int) and sizeof(raft::KeyValuePair<IdxT, MathT>)), ensuring
minibatch sizes aren't underestimated for cosine.
cpp/tests/cluster/kmeans_balanced.cu (1)

250-252: Add cosine coverage for separate fit + predict.

The new final centroid normalization is specifically needed by downstream predict, but this cosine instantiation only exercises fit_predict. Add a SeparateFitPredict=true cosine case so regressions in returned centroid normalization are caught.

🧪 Proposed test coverage
 KB_TEST((KmeansBalancedTest<float, float, uint32_t, int, raft::identity_op, false>),
         KmeansBalancedTestCosineFFU32I32,
         inputsf_cosine_i32);
+KB_TEST((KmeansBalancedTest<float, float, int, int, raft::identity_op, true>),
+        KmeansBalancedTestCosineFFI32I32_SEP,
+        inputsf_cosine_i32);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tests/cluster/kmeans_balanced.cu` around lines 250 - 252, The cosine
instantiation currently only runs fit_predict; add a second KB_TEST that
instantiates KmeansBalancedTest with SeparateFitPredict=true (i.e.,
KmeansBalancedTest<..., raft::identity_op, true>) using the same
inputsf_cosine_i32 to exercise the separate fit + predict path and ensure the
final centroid normalization used by predict is validated; give the test a
distinct name (e.g., KmeansBalancedTestCosineFFU32I32_SeparateFitPredict) so it
runs alongside the existing fit_predict test.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/src/cluster/detail/minClusterDistanceCompute.cu`:
- Around line 44-50: The helper currently assumes centroids are unit length for
metric == cuvs::distance::DistanceType::CosineExpanded by filling centroidsNorm
with 1, which corrupts non-balanced k‑means paths; update
minClusterDistanceCompute (and minClusterAndDistanceCompute) to compute actual
L2 norms for centroids when CosineExpanded is requested (use
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle,
centroids, centroidsNorm)) instead of filling ones, or alternatively add and
thread an explicit boolean flag (e.g., centroids_are_normalized) from balanced
callers and only skip norm computation when that flag is true—prefer the safer
default of computing norms in these functions.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 8e18f9fc-8562-4db1-aa4a-563e9a606fc0

📥 Commits

Reviewing files that changed from the base of the PR and between f2bffb6 and c605b97.

📒 Files selected for processing (6)
  • cpp/src/cluster/detail/kmeans.cuh
  • cpp/src/cluster/detail/kmeans_balanced.cuh
  • cpp/src/cluster/detail/kmeans_common.cuh
  • cpp/src/cluster/detail/minClusterDistanceCompute.cu
  • cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh
  • cpp/tests/cluster/kmeans_balanced.cu
💤 Files with no reviewable changes (2)
  • cpp/src/cluster/detail/kmeans_common.cuh
  • cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh

Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu
Contributor

@jinsolp jinsolp left a comment

Thanks @tarang-jain !

Member

@dantegd dantegd left a comment

This refactor is really nice, just had a few comments

Comment on lines +44 to +49
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
  // Centroids are L2-normalized for cosine metric
  raft::matrix::fill(handle, centroidsNorm, DataT{1});
} else {
  raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(
    handle, centroids, centroidsNorm);
Member

I think the CodeRabbit comment was indeed wrong; non-balanced kmeans::fit isn't actually a cosine caller, and this PR doesn't change that.

One thing I wasn't sure this resolves, though: what about the public kmeans_balanced::predict (kmeans_balanced.cuh:130)? If a user calls it with their own centroids under CosineExpanded, which seems reasonable since cosine is scale-invariant, would they now silently get wrong labels, given the fused helper divides by yn=1 instead of ‖c‖? I don't see the unit-norm precondition documented on the signature or in /include/.../kmeans.hpp, so I'm wondering whether it's worth either threading a centroids_are_unit_norm flag (default false, balanced inner loop passes true) or normalizing into a scratch buffer at the public entry. Would either fit here?

Comment thread cpp/src/cluster/detail/minClusterDistanceCompute.cu
// is the distance between the sample and the 'centroid[key]'
// is the distance between the sample and the 'centroids[key]'.
//
// NB: (CosineExpanded): `centroids` rows must be L2-normalized when the cosine metric is used.
Member

These comments are only visible when reading the implementation. Other translation units and Doxygen see only kmeans_common.cuh, no? Which has no mention of the cosine unit-norm precondition. It'd be a good idea to duplicate the NB onto both declarations there so it is visible at the call site.

Comment on lines +250 to +252
KB_TEST((KmeansBalancedTest<float, float, uint32_t, int, raft::identity_op, false>),
KmeansBalancedTestCosineFFU32I32,
inputsf_cosine_i32);
Member

Quick one on the test coverage, the new cosine case uses SeparateFitPredict = false, so it only exercises fit_predict. That path keeps using the same centers fit just finished normalizing, so it won't catch it if the post-loop row_normalize in balancing_em_iters ever gets lost in a refactor (which is the whole invariant this PR is leaning on now that ivf_pq_build dropped its manual normalize). Would it be easy to also add a SeparateFitPredict = true variant? Something like

KB_TEST((KmeansBalancedTest<float, float, uint32_t, int, raft::identity_op, true>),
        KmeansBalancedTestCosineSeparateFFU32I32,
        inputsf_cosine_i32);

Comment on lines -1332 to -1334
if (impl->metric() == distance::DistanceType::CosineExpanded) {
  raft::linalg::row_normalize<raft::linalg::L2Norm>(handle, centers_const_view, centers_view);
}
Member

Removing this normalize is correct today because kmeans_balanced::fit now guarantees unit-norm centers for CosineExpanded, right?

But there's no comment or assertion tying these two files together: if balancing_em_iters's post-loop normalize ever gets lost in a refactor, IVF-PQ silently regresses recall with no failing test anywhere. It would be a good idea to leave a short comment at both sites (here and in balancing_em_iters) noting the dependency.

Contributor

@viclafargue viclafargue left a comment

Thanks! One comment:

@@ -5,7 +5,6 @@

Contributor

Please edit the docstrings of the predict overloads taking a balanced_params parameter in cpp/include/cuvs/cluster/kmeans.hpp to clearly state that the centroids should be normalized when using the cosine metric. A user may have trained centroids elsewhere and may be attempting to GPU-accelerate the prediction.

@tarang-jain
Contributor Author

@dantegd @viclafargue I got rid of the unit-norm assumption and restored the previous version. We can revisit it later, since the assumption would make for a difficult user experience.

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/src/cluster/detail/minClusterDistanceCompute.cu (1)

13-27: ⚠️ Potential issue | 🟡 Minor

Document the L2NormX contract for CosineExpanded.

The header comment describes the return value but does not state the key precondition that this function now silently relies on: for CosineExpanded, the caller must pass L2NormX as the elementwise L2 norm of X (i.e. ||x_i||, i.e. sqrt already applied), because the fused cosine op divides by the supplied norms without taking a square root. The same contract applies to minClusterDistanceCompute. Since these are the only two entry points for balanced / non-balanced k‑means into the cosine fused path, a short NB here makes it much harder for a future caller to wire cosine in with L2Norm(X) (squared) and silently ship wrong distances.

📝 Suggested doc note
 // Calculates a <key, value> pair for every sample in input 'X' where key is an
 // index to an sample in 'centroids' (index of the nearest centroid) and 'value'
 // is the distance between the sample and the 'centroids[key]'.
+// NB: for CosineExpanded, `L2NormX` must contain the elementwise L2 norm
+// (sqrt-applied) of the rows of `X`, because the fused cosine op divides by
+// the supplied norms without applying sqrt internally. `centroids` need not
+// be normalized; their norms are computed inside this function.
 template <typename DataT, typename IndexT>
 void minClusterAndDistanceCompute(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/minClusterDistanceCompute.cu` around lines 13 - 27,
The comment must document that for the cosine fused path the caller must pass
L2NormX as the elementwise L2 norm (||x_i|| — square root already applied)
because the fused cosine op (used when metric ==
cuvs::distance::DistanceType::CosineExpanded) divides by the supplied norms
without taking a square root; update the header comment for both
minClusterAndDistanceCompute and minClusterDistanceCompute to state this
precondition explicitly and warn not to pass squared norms (L2Norm(X) squared)
to avoid silently incorrect distances.
🧹 Nitpick comments (3)
cpp/src/cluster/detail/minClusterDistanceCompute.cu (3)

33-35: Hoist is_fused into a single helper.

The three-way metric check is duplicated verbatim between the two functions and will be easy to drift (e.g. when a fourth expanded metric is added). A tiny inline helper in kmeans_common.cuh keeps the truth in one place.

♻️ Suggested helper
// in kmeans_common.cuh
inline bool is_fused_distance(cuvs::distance::DistanceType metric) {
  return metric == cuvs::distance::DistanceType::L2Expanded ||
         metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
         metric == cuvs::distance::DistanceType::CosineExpanded;
}

Then replace both occurrences with bool is_fused = is_fused_distance(metric);.

Also applies to: 183-185

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/minClusterDistanceCompute.cu` around lines 33 - 35,
Create an inline helper is_fused_distance in kmeans_common.cuh that checks if a
given cuvs::distance::DistanceType is one of L2Expanded, L2SqrtExpanded, or
CosineExpanded, then replace the duplicate three-way checks (previously
assigning bool is_fused = metric == ... || ... || ...) in
minClusterDistanceCompute.cu with bool is_fused = is_fused_distance(metric); so
both occurrences use the single helper and avoid divergence when adding new
expanded metrics.

23-27: batch_samples / batch_centroids are now dead on the fused path.

Since is_fused additionally covers CosineExpanded, the tiled batching parameters are effectively ignored for all three expanded metrics and only apply to the non-fused branch. This is a silent behavior change for callers that rely on those knobs to bound transient memory (e.g. large n_samples × n_clusters problems previously saw per-tile dispatch). If this is intentional (because fusedDistanceNNMinReduce scales fine over the full dataset), consider at least documenting it on the signature. Otherwise, the batch bounds should be honored on the fused path too.

📝 Suggested doc note on the params
   rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
   cuvs::distance::DistanceType metric,
-  int batch_samples,
-  int batch_centroids,
+  int batch_samples,   // only used when metric is not fused (L2Expanded / L2SqrtExpanded / CosineExpanded)
+  int batch_centroids, // only used when metric is not fused
   rmm::device_uvector<char>& workspace)

Also applies to: 71-73

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/minClusterDistanceCompute.cu` around lines 23 - 27,
The tiled batching parameters batch_samples and batch_centroids are ignored when
is_fused is true (which also covers CosineExpanded), so callers lose the
per-tile memory bounding; either restore batching on the fused path or document
the behavior. Fix by updating the fused branch in minClusterDistanceCompute (the
branch that calls fusedDistanceNNMinReduce) to honor batch_samples and
batch_centroids by splitting the fused work into tiles analogous to the
non-fused branch, or if intentionally global, add a clear doc comment on the
function signature and above the is_fused/CosineExpanded handling stating that
batch_samples and batch_centroids are only applied to the non-fused path and
that fusedDistanceNNMinReduce operates over the full dataset. Ensure references
to is_fused, CosineExpanded, fusedDistanceNNMinReduce, and the parameters
batch_samples/batch_centroids are present so reviewers can locate the changes.

169-169: Make centroids parameter const in minClusterDistanceCompute to eliminate redundant view rewraps.

minClusterAndDistanceCompute already takes centroids as device_matrix_view<const DataT, IndexT> and the function never mutates centroids. The current non-const signature forces unnecessary rewraps at the raft::linalg::norm call sites (and in the non-fused path where centroidsView is created). Aligning the signature removes this boilerplate and matches the const pattern already established in the helper function.

♻️ Proposed change
 template <typename DataT, typename IndexT>
 void minClusterDistanceCompute(raft::resources const& handle,
                                raft::device_matrix_view<const DataT, IndexT> X,
-                               raft::device_matrix_view<DataT, IndexT> centroids,
+                               raft::device_matrix_view<const DataT, IndexT> centroids,
                                raft::device_vector_view<DataT, IndexT> minClusterDistance,

This change requires updating the explicit template instantiations (lines 275–286) and the corresponding declaration in kmeans_common.cuh (line 394).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/minClusterDistanceCompute.cu` at line 169, Change the
centroids parameter of minClusterDistanceCompute to a const view to avoid
redundant rewraps: update the function signature from
raft::device_matrix_view<DataT, IndexT> centroids to
raft::device_matrix_view<const DataT, IndexT> centroids in
minClusterDistanceCompute, then adjust the explicit template instantiations for
minClusterDistanceCompute (previously lines ~275–286) and the declaration in
kmeans_common.cuh (previously line ~394) to use the const DataT view; this
aligns minClusterDistanceCompute with minClusterAndDistanceCompute which already
accepts raft::device_matrix_view<const DataT, IndexT> and removes the extra
rewraps at raft::linalg::norm call sites and the non-fused path centroidsView
creation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a180279f-14ed-4245-9d55-7603a29b89a2

📥 Commits

Reviewing files that changed from the base of the PR and between 4462178 and 940c76d.

📒 Files selected for processing (1)
  • cpp/src/cluster/detail/minClusterDistanceCompute.cu

@tarang-jain tarang-jain removed the request for review from achirkin April 24, 2026 20:46

Labels

improvement: Improves an existing functionality
non-breaking: Introduces a non-breaking change

Projects

Status: In Progress


6 participants