Conversation
Here are some instructions to test the Multi-GPU Batched KMeans API with RAFT comms (to be used with Ray/Dask): RAFT comms (Ray/Dask) demo code, compilation command, and launch command.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
Can we get rid of this file entirely and combine with regular mg kmeans (just as we are doing in PR #2015)? Is that possible?
Also, MNMG should be able to reuse the snmg_fit function (for a single worker) as is, right? Except that the nccl reduce macro will be replaced by something like comms.allreduce()
I refactored the code to work cleanly with the single-GPU refactor. The process_batch function is now used, the weight normalization is improved (rel_tol and zero/invalid checks), and the inertia_check field is no longer used.
However, I do not feel confident implementing a massive unifying refactor in this PR, which was originally dedicated to introducing MG Batched KMeans. Would it be fine to leave things as is for now? We could come back to it in a dedicated follow-up PR.
Referencing this comment here: #2015 (comment)
No actionable comments were generated in the recent review. 🎉
Configuration used: Path: .coderabbit.yaml | Review profile: CHILL | Plan: Pro Plus
📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 Walkthrough: Adds a new multi-GPU batched K-Means implementation and OpenMP wrapper, refactors k-means internals (weight handling, batch processing, centroid norm caching), introduces runtime RAFT handle dispatch for MG/MNMG vs single-GPU, updates distance/cost APIs, and adds MG batched tests and CMake registration.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/src/cluster/detail/kmeans_mg.cuh (1)
474-499: ⚠️ Potential issue | 🟠 Major. Normalize weights against the global sample count, not the local shard size.
After the allreduce, `wt_sum` is global, but `target` is still this rank's `n_samples`. That gives each rank a different scale factor and breaks equivalence with single-GPU weighted k-means whenever partitions are uneven or weights vary by rank. Reduce the sample counts as well and use `global_n_samples / global_wt_sum` here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/cluster/detail/kmeans_mg.cuh` around lines 474 - 499, The code normalizes weights using a per-rank n_samples (local target) instead of the global sample count, causing inconsistent scaling after comm.allreduce of wt_aggr; update the logic to also reduce the local sample count to a global count (e.g., perform an allreduce on n_samples into a global_n_samples variable after or alongside raft::linalg::mapThenSumReduce/comm.allreduce), then set target = static_cast<DataT>(global_n_samples) and compute the scale factor as target / wt_sum before calling raft::linalg::map to rescale weight; modify references to target and ensure synchronization (raft::resource::sync_stream) remains correct.
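The fix the comment describes can be modeled on the host as follows. This is a minimal sketch, not code from the PR: the function name `global_weight_scale` is made up, and a plain sum over ranks stands in for `comm.allreduce`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Each rank contributes its local sample count and local weight sum; summing
// both across ranks (standing in for comm.allreduce) yields one scale factor,
//   scale = global_n_samples / global_wt_sum,
// identical on every rank regardless of how samples are partitioned.
template <typename T>
T global_weight_scale(const std::vector<std::size_t>& local_n_samples,
                      const std::vector<T>& local_wt_sums)
{
  std::size_t global_n =
    std::accumulate(local_n_samples.begin(), local_n_samples.end(), std::size_t{0});
  T global_wt = std::accumulate(local_wt_sums.begin(), local_wt_sums.end(), T{0});
  // target / wt_sum, but with both quantities reduced globally first.
  return static_cast<T>(global_n) / global_wt;
}
```

With uneven shards (3 and 5 samples) and uneven weight totals (6 and 10), every rank derives the same factor 8 / 16 = 0.5, whereas dividing by the local `n_samples` would give each rank a different one.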
🧹 Nitpick comments (1)
cpp/tests/cluster/kmeans_mg_batched.cu (1)
121-132: Add a zero-total-weight regression case.The new MG path now rejects
wt_sum == 0, but this fixture only exercises positive-weight modes. A dedicated all-zero host-weight case would pin the new validation behavior and guard against future regressions.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/cluster/kmeans_mg_batched.cu` around lines 121 - 132, Add a regression case that produces all-zero host weights so the MG path rejection of wt_sum == 0 is exercised: inside the weight setup for the kmeans_mg_batched fixture (where testparams_.weight_mode, h_sample_weight, h_sw and n_samples are handled) add a branch for a new weight_mode value (e.g., weight_mode == 4) that sets every h_sample_weight[i] to T(0) and constructs h_sw via raft::make_host_vector_view<const T, int64_t>(h_sample_weight.data(), n_samples); ensure the new mode is included in the test parameterization so the test runs this all-zero-weight case.
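The behavior such a regression case would pin can be sketched in isolation. The helper name `is_valid_weight_total` is hypothetical, standing in for the new MG-path validation:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// An all-zero sample-weight vector must be detected as an invalid total
// before any normalization divides by it; positive totals pass.
template <typename T>
bool is_valid_weight_total(const std::vector<T>& sample_weight)
{
  T wt_sum = T{0};
  for (T w : sample_weight) wt_sum += w;
  return std::isfinite(wt_sum) && wt_sum > T{0};
}
```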
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/include/cuvs/cluster/kmeans.hpp`:
- Around line 170-175: Update the Doxygen doc block for the k-means public API
in cpp/include/cuvs/cluster/kmeans.hpp so it consistently documents the
RAFT-comms path (include "MPI-backed RAFT comms" alongside Dask/Ray) and mirrors
the detailed host-data/dispatch notes added to the float overload for the double
overload as well; specifically edit the comment around the kmeans function
overloads (the multi-GPU dispatch paragraph currently at lines ~170-175 and the
brief at ~206-207) to mention MPI-backed RAFT comms and copy the expanded
host-data behavior text from the float overload into the double overload's
Doxygen so both public overloads have identical, complete documentation.
In `@cpp/src/cluster/detail/kmeans_batched.cuh`:
- Around line 109-118: Guard against zero/invalid wt_sum before dividing: before
computing and returning target / wt_sum, check if wt_sum is <= 0 or not finite
(use std::isfinite) and handle it (e.g., RAFT_LOG_ERROR/RAFT_LOG_WARN that
weights are zero/invalid and return T{1} to skip scaling) so you never perform
target / wt_sum; keep the existing RAFT_LOG_DEBUG path and only do the division
after the validity check of wt_sum and the existing tolerance check on wt_sum vs
target.
In `@cpp/src/cluster/detail/kmeans_mg_batched.cuh`:
- Around line 183-207: Replace the exact-equality check and add a
positive-weight validation: after computing global_n and global_wt (from
d_n_local/d_wt), compute a tolerance tol = static_cast<T>(global_n) *
std::numeric_limits<T>::epsilon(); use if (std::abs(global_wt -
static_cast<T>(global_n)) > tol) to decide scaling instead of global_wt !=
static_cast<T>(global_n); before computing weight_scale assert/validate that
global_wt > T{0} (e.g. RAFT_EXPECTS or the project's preferred check) so
zero/negative total weights fail early; update references to weight_scale,
sample_weight, SNMG_ALLREDUCE, d_wt, global_wt and global_n accordingly.
- Around line 338-346: The call to
cuvs::cluster::kmeans::detail::accumulate_batch_centroids<T, IdxT> is passing
weight_per_cluster where the callee expects cluster_counts (sample counts per
cluster), causing wrong centroid accumulation; fix by supplying a proper
cluster_counts view (e.g., create or compute a cluster_counts device_vector/view
separate from weight_per_cluster) and pass that as the 6th argument instead of
weight_per_cluster, or if weighted k-means is intended, update the
accumulate_batch_centroids signature and its internal use to accept and handle
weight_per_cluster consistently; adjust callers accordingly (referenced symbols:
accumulate_batch_centroids, weight_per_cluster, batch_counts, centroid_sums).
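Why per-cluster weight totals and raw sample counts are not interchangeable divisors can be shown with a 1-D host sketch (an assumed simplification of `accumulate_batch_centroids`, following the "handle weights consistently" option):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// 1-D centroids for brevity; labels[i] assigns sample i to a cluster.
// Dividing weighted sums by raw counts gives the weighted mean only when all
// weights are 1; the consistent divisor is the per-cluster weight total.
std::vector<double> weighted_centroids(const std::vector<double>& x,
                                       const std::vector<double>& w,
                                       const std::vector<std::size_t>& labels,
                                       std::size_t k)
{
  std::vector<double> sums(k, 0.0), wt(k, 0.0);
  for (std::size_t i = 0; i < x.size(); ++i) {
    sums[labels[i]] += w[i] * x[i];  // weighted centroid sums
    wt[labels[i]]   += w[i];         // per-cluster weight totals
  }
  for (std::size_t c = 0; c < k; ++c)
    if (wt[c] > 0.0) sums[c] /= wt[c];
  return sums;
}
```

For samples {0, 10} with weights {1, 3} in one cluster, the weighted mean is 30 / 4 = 7.5; dividing by the raw count 2 would instead give 15.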
In `@cpp/src/cluster/detail/kmeans_mg.cuh`:
- Around line 655-672: Gate the inertia-based stopping test behind
params.inertia_check and fix the comparison to measure relative improvement:
only run the check when params.inertia_check is true and n_iter[0] > 1 and
priorClusteringCost > DataT{0}; if curClusteringCost < priorClusteringCost
compute DataT rel = (priorClusteringCost - curClusteringCost) /
priorClusteringCost and set done = true when rel <= params.tol; ensure that cost
increases (curClusteringCost >= priorClusteringCost) do not count as convergence
(leave done false), and still update priorClusteringCost = curClusteringCost
afterwards.
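The stopping test this prompt describes can be sketched in isolation (the name `inertia_converged` is illustrative):

```cpp
#include <cassert>

// Convergence requires a genuine decrease in cost, with the relative
// improvement at or below tol; a cost increase never counts as convergence.
template <typename T>
bool inertia_converged(T prior_cost, T cur_cost, T tol)
{
  if (prior_cost <= T{0} || cur_cost >= prior_cost) return false;
  T rel = (prior_cost - cur_cost) / prior_cost;
  return rel <= tol;
}
```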
---
Outside diff comments:
In `@cpp/src/cluster/detail/kmeans_mg.cuh`:
- Around line 474-499: The code normalizes weights using a per-rank n_samples
(local target) instead of the global sample count, causing inconsistent scaling
after comm.allreduce of wt_aggr; update the logic to also reduce the local
sample count to a global count (e.g., perform an allreduce on n_samples into a
global_n_samples variable after or alongside
raft::linalg::mapThenSumReduce/comm.allreduce), then set target =
static_cast<DataT>(global_n_samples) and compute the scale factor as target /
wt_sum before calling raft::linalg::map to rescale weight; modify references to
target and ensure synchronization (raft::resource::sync_stream) remains correct.
---
Nitpick comments:
In `@cpp/tests/cluster/kmeans_mg_batched.cu`:
- Around line 121-132: Add a regression case that produces all-zero host weights
so the MG path rejection of wt_sum == 0 is exercised: inside the weight setup
for the kmeans_mg_batched fixture (where testparams_.weight_mode,
h_sample_weight, h_sw and n_samples are handled) add a branch for a new
weight_mode value (e.g., weight_mode == 4) that sets every h_sample_weight[i] to
T(0) and constructs h_sw via raft::make_host_vector_view<const T,
int64_t>(h_sample_weight.data(), n_samples); ensure the new mode is included in
the test parameterization so the test runs this all-zero-weight case.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 715dc4bb-4fb5-4063-b0f2-203992e27484
📒 Files selected for processing (11)
- cpp/include/cuvs/cluster/kmeans.hpp
- cpp/src/cluster/detail/kmeans_batched.cuh
- cpp/src/cluster/detail/kmeans_common.cuh
- cpp/src/cluster/detail/kmeans_mg.cuh
- cpp/src/cluster/detail/kmeans_mg_batched.cuh
- cpp/src/cluster/detail/minClusterDistanceCompute.cu
- cpp/src/cluster/kmeans.cuh
- cpp/src/cluster/kmeans_fit_double.cu
- cpp/src/cluster/kmeans_fit_float.cu
- cpp/tests/CMakeLists.txt
- cpp/tests/cluster/kmeans_mg_batched.cu
Closes #1989.
Adds multi-GPU support to KMeans fit for host-resident data, with two modes: multi-node multi-GPU via RAFT comms (Dask/Ray), and single-node multi-GPU via `device_resources_snmg`. Both modes share the same core Lloyd's loop, batched streaming of host data, NCCL/comms allreduce of centroid sums and counts, and synchronized convergence. Supports sample weights, `n_init` best-of-N restarts, KMeansPlusPlus initialization, and float/double. Falls back to single-GPU when neither multi-GPU resources nor comms are present.
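The shared core described above can be modeled roughly as follows. This is a hedged host-only sketch, not the cuVS implementation: 1-D data for brevity, a plain vector per rank standing in for a shard of host-resident data, the inner loop over batches standing in for batched streaming, and the loop over ranks standing in for the NCCL/comms allreduce of centroid sums and counts.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One Lloyd iteration across ranks: each rank streams its shard in batches,
// accumulates per-cluster sums and counts, the simulated allreduce combines
// them, and all ranks derive identical updated centroids.
std::vector<double> mg_lloyd_step(const std::vector<std::vector<double>>& shards,
                                  std::vector<double> centroids,
                                  std::size_t batch_size)
{
  std::size_t k = centroids.size();
  std::vector<double> global_sums(k, 0.0);
  std::vector<std::size_t> global_counts(k, 0);
  for (const auto& shard : shards) {  // one pass per "rank"
    std::vector<double> sums(k, 0.0);
    std::vector<std::size_t> counts(k, 0);
    for (std::size_t b = 0; b < shard.size(); b += batch_size) {  // batched streaming
      std::size_t end = std::min(shard.size(), b + batch_size);
      for (std::size_t i = b; i < end; ++i) {
        // assign sample to its nearest centroid
        std::size_t best = 0;
        double best_d = std::abs(shard[i] - centroids[0]);
        for (std::size_t c = 1; c < k; ++c) {
          double d = std::abs(shard[i] - centroids[c]);
          if (d < best_d) { best_d = d; best = c; }
        }
        sums[best] += shard[i];
        counts[best] += 1;
      }
    }
    for (std::size_t c = 0; c < k; ++c) {  // simulated allreduce(sum)
      global_sums[c] += sums[c];
      global_counts[c] += counts[c];
    }
  }
  for (std::size_t c = 0; c < k; ++c)
    if (global_counts[c] > 0) centroids[c] = global_sums[c] / global_counts[c];
  return centroids;
}
```

Because every rank sees the same reduced sums and counts, all ranks compute identical centroids each iteration, which is what keeps the multi-GPU result equivalent to a single-GPU run over the concatenated data.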