Skip to content

Multi-GPU Batched KMeans#2017

Open
viclafargue wants to merge 8 commits into rapidsai:main from
viclafargue:mg-batched-kmeans
Open

Multi-GPU Batched KMeans#2017
viclafargue wants to merge 8 commits into rapidsai:main from
viclafargue:mg-batched-kmeans

Conversation

@viclafargue
Copy link
Copy Markdown
Contributor

Closes #1989.

Adds multi-GPU support to KMeans fit for host-resident data, with two modes:

  • OpenMP (cuVS SNMG): A single process drives all local GPUs via OMP threads and raw NCCL. Activated automatically when the handle is a device_resources_snmg.
  • RAFT comms (Ray / Dask / MPI): Each rank is a separate process that calls fit with its own data shard and an initialized RAFT communicator. Coordination uses the RAFT comms.

Both modes share the same core Lloyd's loop, batched streaming of host data, NCCL/comms allreduce of centroid sums and counts, and synchronized convergence. Supports sample weights, n_init best-of-N restarts, KMeansPlusPlus initialization, and float/double. Falls back to single-GPU when neither multi-GPU resources nor comms are present.

@viclafargue viclafargue self-assigned this Apr 13, 2026
@viclafargue viclafargue requested review from a team as code owners April 13, 2026 14:34
@viclafargue viclafargue added improvement Improves an existing functionality non-breaking Introduces a non-breaking change labels Apr 13, 2026
@viclafargue
Copy link
Copy Markdown
Contributor Author

Here are some instructions to test the Multi-GPU Batched KMeans API with RAFT comms (to be used with Ray/Dask) :

RAFT comms (Ray/Dask) demo code
#include <cuvs/cluster/kmeans.hpp>

#include <raft/comms/std_comms.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/comms.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <cuda_runtime.h>
#include <mpi.h>
#include <nccl.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <random>
#include <vector>

// Abort the entire MPI job if a CUDA runtime call fails, printing the
// CUDA error string together with the file/line of the failing call.
#define CHECK_CUDA(call)                                                 \
  do {                                                                   \
    const cudaError_t status_ = (call);                                  \
    if (status_ != cudaSuccess) {                                        \
      std::fprintf(stderr,                                               \
                   "CUDA error %s @ %s:%d\n",                           \
                   cudaGetErrorString(status_),                          \
                   __FILE__,                                             \
                   __LINE__);                                            \
      MPI_Abort(MPI_COMM_WORLD, 1);                                      \
    }                                                                    \
  } while (0)

// Abort the entire MPI job if an NCCL call fails, printing the NCCL
// error string together with the file/line of the failing call.
#define CHECK_NCCL(call)                                                 \
  do {                                                                   \
    const ncclResult_t status_ = (call);                                 \
    if (status_ != ncclSuccess) {                                        \
      std::fprintf(stderr,                                               \
                   "NCCL error %s @ %s:%d\n",                           \
                   ncclGetErrorString(status_),                          \
                   __FILE__,                                             \
                   __LINE__);                                            \
      MPI_Abort(MPI_COMM_WORLD, 1);                                      \
    }                                                                    \
  } while (0)

int main(int argc, char** argv)
{
  // Demo driver: each MPI rank owns one GPU and one shard of the host data,
  // runs the multi-GPU batched KMeans fit via RAFT comms, and rank 0 then
  // predicts on the full dataset and scores clustering quality against the
  // synthetic ground-truth labels.
  MPI_Init(&argc, &argv);

  int rank, num_ranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &num_ranks);

  // One GPU per rank, device ordinal == rank.
  // NOTE(review): assumes num_ranks <= number of visible GPUs — confirm.
  CHECK_CUDA(cudaSetDevice(rank));

  // NCCL bootstrap: rank 0 creates the unique id and broadcasts it over MPI;
  // every rank then joins the same NCCL communicator with its rank index.
  ncclUniqueId nccl_id;
  if (rank == 0) CHECK_NCCL(ncclGetUniqueId(&nccl_id));
  MPI_Bcast(&nccl_id, sizeof(nccl_id), MPI_BYTE, 0, MPI_COMM_WORLD);

  ncclComm_t nccl_comm;
  CHECK_NCCL(ncclCommInitRank(&nccl_comm, num_ranks, nccl_id, rank));

  // Attach the NCCL communicator to the RAFT handle so the cuvs KMeans fit
  // can coordinate ranks through RAFT comms collectives.
  raft::resources handle;
  raft::comms::build_comms_nccl_only(&handle, nccl_comm, num_ranks, rank);

  // --- Demo parameters ---
  constexpr int64_t n_samples       = 100'000;
  constexpr int64_t n_features      = 32;
  constexpr int     n_clusters      = 10;
  constexpr int64_t streaming_batch = 10'000;  // rows streamed to GPU per batch
  constexpr float   cluster_spread  = 1.0f;    // per-cluster Gaussian noise sigma
  constexpr float   center_range    = 30.0f;   // true centers drawn from [-30, 30]

  if (rank == 0) {
    std::printf("=== Multi-GPU KMeans Demo (%d ranks) ===\n", num_ranks);
    std::printf("Samples: %ld | Features: %ld | k: %d | batch: %ld\n\n",
                long(n_samples), long(n_features), n_clusters, long(streaming_batch));
  }

  // Generate synthetic blobs with well-separated cluster centers.
  // Every rank uses the same fixed seed, so all ranks generate an identical
  // full dataset and then slice out their own shard below.
  std::vector<float> h_data(n_samples * n_features);
  std::vector<int>   h_true_labels(n_samples);
  std::vector<float> cluster_centers(n_clusters * n_features);
  {
    std::mt19937 gen(12345);
    std::uniform_real_distribution<float> center_dist(-center_range, center_range);
    std::normal_distribution<float> noise(0.0f, cluster_spread);

    // Draw the k ground-truth cluster centers.
    for (int c = 0; c < n_clusters; ++c)
      for (int d = 0; d < n_features; ++d)
        cluster_centers[c * n_features + d] = center_dist(gen);

    // Assign samples round-robin to clusters and add Gaussian noise.
    for (int64_t i = 0; i < n_samples; ++i) {
      int label = static_cast<int>(i % n_clusters);
      h_true_labels[i] = label;
      for (int d = 0; d < n_features; ++d)
        h_data[i * n_features + d] = cluster_centers[label * n_features + d] + noise(gen);
    }

    // Shuffle so labels aren't just sequential runs
    std::vector<int64_t> perm(n_samples);
    std::iota(perm.begin(), perm.end(), 0);
    std::shuffle(perm.begin(), perm.end(), gen);

    // Apply the permutation via temporary copies of data and labels.
    std::vector<float> tmp_data(h_data);
    std::vector<int>   tmp_labels(h_true_labels);
    for (int64_t i = 0; i < n_samples; ++i) {
      std::memcpy(h_data.data() + i * n_features,
                  tmp_data.data() + perm[i] * n_features,
                  n_features * sizeof(float));
      h_true_labels[i] = tmp_labels[perm[i]];
    }
  }

  // Contiguous row-range sharding: the first `rem` ranks get one extra row
  // so the split covers all n_samples even when it doesn't divide evenly.
  int64_t base    = n_samples / num_ranks;
  int64_t rem     = n_samples % num_ranks;
  int64_t offset  = rank * base + std::min<int64_t>(rank, rem);
  int64_t n_local = base + (rank < rem ? 1 : 0);

  std::printf("[rank %d / GPU %d]  rows [%ld .. %ld)  (%ld samples)\n",
              rank, rank, long(offset), long(offset + n_local), long(n_local));

  // Host-resident view of this rank's shard; fit streams it to the GPU in
  // batches of streaming_batch rows.
  auto X_local = raft::make_host_matrix_view<const float, int64_t>(
    h_data.data() + offset * n_features, n_local, n_features);

  auto d_centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);

  cuvs::cluster::kmeans::params params;
  params.n_clusters           = n_clusters;
  params.max_iter             = 50;
  params.tol                  = 1e-4;
  params.init                 = cuvs::cluster::kmeans::params::KMeansPlusPlus;
  params.rng_state.seed       = 42;
  params.inertia_check        = true;
  params.streaming_batch_size = streaming_batch;

  float   inertia = 0.0f;
  int64_t n_iter  = 0;

  // Collective fit: every rank passes its own host shard; no sample weights.
  // Dispatch to the MNMG path happens inside fit because the handle carries
  // an initialized RAFT communicator.
  cuvs::cluster::kmeans::fit(handle,
                             params,
                             X_local,
                             std::nullopt,
                             d_centroids.view(),
                             raft::make_host_scalar_view(&inertia),
                             raft::make_host_scalar_view(&n_iter));

  // Ensure all device work from fit is complete before reading results.
  auto stream = raft::resource::get_cuda_stream(handle);
  CHECK_CUDA(cudaStreamSynchronize(stream));

  if (rank == 0) {
    // --- Predict labels on the full dataset (on rank 0) ---
    auto d_X = raft::make_device_matrix<float, int64_t>(handle, n_samples, n_features);
    CHECK_CUDA(cudaMemcpy(d_X.data_handle(), h_data.data(),
                          sizeof(float) * n_samples * n_features, cudaMemcpyHostToDevice));

    auto d_labels = raft::make_device_vector<int64_t, int64_t>(handle, n_samples);
    float predict_inertia = 0.0f;

    // `false` here is predict's normalize_weight argument — TODO confirm
    // against the cuvs predict signature.
    cuvs::cluster::kmeans::predict(
      handle, params,
      raft::make_device_matrix_view<const float, int64_t>(d_X.data_handle(), n_samples, n_features),
      std::nullopt,
      raft::make_device_matrix_view<const float, int64_t>(
        d_centroids.data_handle(), n_clusters, n_features),
      d_labels.view(),
      false,
      raft::make_host_scalar_view(&predict_inertia));
    CHECK_CUDA(cudaStreamSynchronize(stream));

    std::vector<int64_t> h_labels(n_samples);
    CHECK_CUDA(cudaMemcpy(h_labels.data(), d_labels.data_handle(),
                          sizeof(int64_t) * n_samples, cudaMemcpyDeviceToHost));

    // --- Quality: permutation-invariant accuracy via majority voting ---
    // For each predicted cluster, find which true label appears most often.
    std::vector<std::vector<int64_t>> confusion(n_clusters, std::vector<int64_t>(n_clusters, 0));
    for (int64_t i = 0; i < n_samples; ++i)
      confusion[h_labels[i]][h_true_labels[i]]++;

    // Greedy matching: assign each predicted cluster to its dominant true label
    // (largest remaining confusion entry first, one-to-one assignment).
    std::vector<int> pred_to_true(n_clusters, -1);
    std::vector<bool> true_taken(n_clusters, false);
    for (int round = 0; round < n_clusters; ++round) {
      int64_t best_count = -1;
      int best_pred = -1, best_true = -1;
      for (int p = 0; p < n_clusters; ++p) {
        if (pred_to_true[p] >= 0) continue;  // already matched
        for (int t = 0; t < n_clusters; ++t) {
          if (true_taken[t]) continue;       // true label already claimed
          if (confusion[p][t] > best_count) {
            best_count = confusion[p][t];
            best_pred = p;
            best_true = t;
          }
        }
      }
      // best_pred/best_true are always set: counts are >= 0 > initial -1.
      pred_to_true[best_pred] = best_true;
      true_taken[best_true] = true;
    }

    // Overall and per-cluster accuracy under the greedy label mapping.
    int64_t correct = 0;
    std::vector<int64_t> cluster_sizes(n_clusters, 0);
    std::vector<int64_t> cluster_correct(n_clusters, 0);
    for (int64_t i = 0; i < n_samples; ++i) {
      int p = static_cast<int>(h_labels[i]);
      cluster_sizes[p]++;
      if (h_true_labels[i] == pred_to_true[p]) {
        ++correct;
        ++cluster_correct[p];
      }
    }
    double accuracy = 100.0 * correct / n_samples;

    // --- Compute centroid-to-true-center distances ---
    std::vector<float> h_centroids(n_clusters * n_features);
    CHECK_CUDA(cudaMemcpy(h_centroids.data(), d_centroids.data_handle(),
                          sizeof(float) * n_clusters * n_features, cudaMemcpyDeviceToHost));

    std::printf("\n============ Multi-GPU KMeans Results ============\n");
    std::printf("  Ranks             : %d\n", num_ranks);
    std::printf("  Total samples     : %ld\n", long(n_samples));
    std::printf("  Features          : %ld\n", long(n_features));
    std::printf("  Clusters (k)      : %d\n", n_clusters);
    std::printf("  Streaming batch   : %ld\n", long(streaming_batch));
    std::printf("  Lloyd iterations  : %ld\n", long(n_iter));
    std::printf("  Final inertia     : %.6f\n", double(inertia));
    std::printf("  Predict inertia   : %.6f\n", double(predict_inertia));
    std::printf("\n  --- Clustering Quality ---\n");
    std::printf("  Overall accuracy  : %.2f%% (%ld / %ld)\n",
                accuracy, long(correct), long(n_samples));

    std::printf("\n  Per-cluster breakdown:\n");
    std::printf("  %6s  %10s  %10s  %8s  %12s\n",
                "Pred", "TrueLabel", "Size", "Acc%", "CentroidErr");
    for (int p = 0; p < n_clusters; ++p) {
      int t = pred_to_true[p];
      double pct = cluster_sizes[p] > 0
                     ? 100.0 * cluster_correct[p] / cluster_sizes[p]
                     : 0.0;

      // L2 distance between learned centroid and ground truth center
      double dist2 = 0.0;
      for (int d = 0; d < n_features; ++d) {
        double diff = h_centroids[p * n_features + d] - cluster_centers[t * n_features + d];
        dist2 += diff * diff;
      }
      std::printf("  %6d  %10d  %10ld  %7.2f%%  %12.4f\n",
                  p, t, long(cluster_sizes[p]), pct, std::sqrt(dist2));
    }

    std::printf("\n  Expected accuracy for well-separated blobs: >99%%\n");
    if (accuracy >= 99.0)
      std::printf("  PASS: Clustering quality is high.\n");
    else if (accuracy >= 90.0)
      std::printf("  WARN: Clustering quality is acceptable but not ideal.\n");
    else
      std::printf("  FAIL: Clustering quality is poor!\n");

    std::printf("==================================================\n");
  }

  // Tear down in reverse order of creation: NCCL communicator, then MPI.
  // NOTE(review): the RAFT handle still references nccl_comm here; it is
  // destroyed before `handle` goes out of scope — confirm this ordering is
  // safe with build_comms_nccl_only's ownership semantics.
  CHECK_NCCL(ncclCommDestroy(nccl_comm));
  MPI_Finalize();
  return 0;
}
Compilation command
nvcc -std=c++17 -x cu --extended-lambda -arch=native       \
 -I$CONDA_PREFIX/include/rapids                            \
 -I$CONDA_PREFIX/include                                   \
 demo_mg_kmeans_raft_comms.cu                              \
 -L$CONDA_PREFIX/lib -lcuvs -lnccl -lrmm -lmpi             \
 -lucxx -lucp -lucs                                       \
 -Xlinker=-rpath,$CONDA_PREFIX/lib                         \
 -o demo_mg_kmeans
Launch command

mpirun -np 2 ./demo_mg_kmeans

@viclafargue viclafargue requested a review from tarang-jain April 13, 2026 14:42
Comment thread cpp/src/cluster/detail/kmeans_mg_batched.cuh Outdated
Comment thread cpp/src/cluster/detail/kmeans_mg_batched.cuh
Comment thread cpp/src/cluster/detail/kmeans_mg_batched.cuh
Comment thread cpp/src/cluster/detail/kmeans_mg_batched.cuh Outdated
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get rid of this file entirely and combine with regular mg kmeans (just as we are doing in PR #2015)? Is that possible?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, MNMG should be able to reuse the snmg_fit function (for a single worker) as is, right? Except that the nccl reduce macro will be replaced by something like comms.allreduce()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the code to work flawlessly with the single GPU refactor. The process_batch function is now being used. The weight normalization is improved (rel_tol and zero/invalid check). The inertia_check field is not used anymore.

However, I do not feel confident implementing a massive unifying refactor in this PR that is originally dedicated to introducing MG Batched KMeans. Would it be fine to leave things as is for now? We could come back to it in a dedicated follow-up PR.

@tarang-jain
Copy link
Copy Markdown
Contributor

tarang-jain commented Apr 22, 2026

Referencing this comment here: #2015 (comment)
@viclafargue please ensure to incorporate the inertia related changes:

  1. rel_tol with the batched weight addition
  2. Removal of the inertia check parameter
  3. wt_sum = 0 check if all-zero sample weights are passed. Check if you can reuse the weightSum helper from [Cleanup] Combine Batched and Regular KMeans Impl #2015.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Apr 23, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: d3c18dc8-572e-4418-bdbc-a7fe4d7d095b

📥 Commits

Reviewing files that changed from the base of the PR and between 0a6748d and 7055272.

📒 Files selected for processing (3)
  • cpp/src/cluster/detail/kmeans_mg_batched.cuh
  • cpp/src/cluster/kmeans_fit_double.cu
  • cpp/src/cluster/kmeans_fit_float.cu
🚧 Files skipped from review as they are similar to previous changes (2)
  • cpp/src/cluster/detail/kmeans_mg_batched.cuh
  • cpp/src/cluster/kmeans_fit_double.cu

📝 Walkthrough

Summary by CodeRabbit

  • New Features

    • Added multi-GPU batched k-means with automatic runtime dispatch (SNMG, initialized communicator, or single-GPU fallback).
  • API Changes

    • Split cluster cost API into device- and host-targeted overloads for clearer semantics.
  • Documentation

    • Clarified k-means docs to describe automatic multi‑GPU dispatch behavior and cases.
  • Bug Fixes / Improvements

    • Improved weight validation, per-iteration cost reporting, and convergence handling.
  • Tests

    • Added end-to-end multi‑GPU batched k-means tests.

Walkthrough

Adds a new multi‑GPU batched K‑Means implementation and OpenMP wrapper, refactors k‑means internals (weight handling, batch processing, centroid norm caching), introduces runtime RAFT handle dispatch for MG/MNMG vs single‑GPU, updates distance/cost APIs, and adds MG batched tests and CMake registration.

Changes

Cohort / File(s) Summary
Public API docs
cpp/include/cuvs/cluster/kmeans.hpp
Doxygen for host-data batched fit updated to document automatic multi‑GPU dispatch cases and adjusted handle annotation.
MG Batched implementation
cpp/src/cluster/detail/kmeans_mg_batched.cuh
New SNMG/MNMG batched K‑Means core and OpenMP wrapper: NCCL/RAFT collectives, streaming batch loop, per-iter centroid accumulators, weight normalization, init/broadcast, grouped allreduce, and convergence sync. (Adds public mnmg_fit and batched_fit_omp templates.)
MG single-node refactor
cpp/src/cluster/detail/kmeans_mg.cuh
Rewrote MG fit loop to use per-iteration device accumulators, process_batch, finalize_centroids, allreduce-based merging, and unconditional inertia computation/use.
Common algorithm utilities
cpp/src/cluster/detail/kmeans_common.cuh
Added weightSum(...), process_batch(...), optional precomputed_centroid_norms parameters, reset_sums flag for centroid adjustments, and replaced some syncs with raft::resource::sync_stream(handle).
Distance compute changes
cpp/src/cluster/detail/minClusterDistanceCompute.cu
Added std::optional precomputed centroid norms parameter; avoid recomputing norms when provided and route kernels to use unified norms view; updated extern/template macros.
Cluster cost API
cpp/src/cluster/kmeans.cuh
Split cluster_cost into device-targeting overload (device_scalar_view) and host convenience overload (allocates device scalar then copies to host); removed previous device-allocation/copy pattern from device overload.
Fit dispatch logic
cpp/src/cluster/kmeans_fit_double.cu, cpp/src/cluster/kmeans_fit_float.cu
Runtime dispatch added (behind CUVS_BUILD_MG_ALGOS) to choose batched_fit_omp, mnmg_fit, or single‑GPU detail::fit based on raft::resources state.
Tests & build integration
cpp/tests/CMakeLists.txt, cpp/tests/cluster/kmeans_mg_batched.cu
Added MG batched GPU test target and comprehensive SNMG/MNMG parameterized tests validating weights, inits, inertia/ARI, and centroid agreement; CMake test added with NCCL dependency.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main change: adding multi-GPU batched support to KMeans. It is concise, specific, and accurately reflects the primary objective of the changeset.
Description check ✅ Passed The description directly relates to the changeset by explaining the multi-GPU support implementation, including both OpenMP (SNMG) and RAFT comms modes, batching, and fallback behavior. It provides meaningful context about the changes.
Linked Issues check ✅ Passed The PR addresses the core objectives from issue #1989: extending multi-GPU KMeans to accept host-resident matrices per rank, enabling out-of-core batching, and allowing distributed ranks to stream local data batches. Both SNMG and RAFT comms modes are implemented with proper weight handling and convergence logic.
Out of Scope Changes check ✅ Passed All code changes are scoped to KMeans multi-GPU batched implementation. The API extensions (new functions, updated signatures), weight handling, centroid computations, and convergence mechanisms are all directly aligned with the linked issue objectives. No unrelated modifications detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/src/cluster/detail/kmeans_mg.cuh (1)

474-499: ⚠️ Potential issue | 🟠 Major

Normalize weights against the global sample count, not the local shard size.

After the allreduce, wt_sum is global, but target is still this rank's n_samples. That gives each rank a different scale factor and breaks equivalence with single-GPU weighted k-means whenever partitions are uneven or weights vary by rank. Reduce the sample counts as well and use global_n_samples / global_wt_sum here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/src/cluster/detail/kmeans_mg.cuh` around lines 474 - 499, The code
normalizes weights using a per-rank n_samples (local target) instead of the
global sample count, causing inconsistent scaling after comm.allreduce of
wt_aggr; update the logic to also reduce the local sample count to a global
count (e.g., perform an allreduce on n_samples into a global_n_samples variable
after or alongside raft::linalg::mapThenSumReduce/comm.allreduce), then set
target = static_cast<DataT>(global_n_samples) and compute the scale factor as
target / wt_sum before calling raft::linalg::map to rescale weight; modify
references to target and ensure synchronization (raft::resource::sync_stream)
remains correct.
🧹 Nitpick comments (1)
cpp/tests/cluster/kmeans_mg_batched.cu (1)

121-132: Add a zero-total-weight regression case.

The new MG path now rejects wt_sum == 0, but this fixture only exercises positive-weight modes. A dedicated all-zero host-weight case would pin the new validation behavior and guard against future regressions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tests/cluster/kmeans_mg_batched.cu` around lines 121 - 132, Add a
regression case that produces all-zero host weights so the MG path rejection of
wt_sum == 0 is exercised: inside the weight setup for the kmeans_mg_batched
fixture (where testparams_.weight_mode, h_sample_weight, h_sw and n_samples are
handled) add a branch for a new weight_mode value (e.g., weight_mode == 4) that
sets every h_sample_weight[i] to T(0) and constructs h_sw via
raft::make_host_vector_view<const T, int64_t>(h_sample_weight.data(),
n_samples); ensure the new mode is included in the test parameterization so the
test runs this all-zero-weight case.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/include/cuvs/cluster/kmeans.hpp`:
- Around line 170-175: Update the Doxygen doc block for the k-means public API
in cpp/include/cuvs/cluster/kmeans.hpp so it consistently documents the
RAFT-comms path (include "MPI-backed RAFT comms" alongside Dask/Ray) and mirrors
the detailed host-data/dispatch notes added to the float overload for the double
overload as well; specifically edit the comment around the kmeans function
overloads (the multi-GPU dispatch paragraph currently at lines ~170-175 and the
brief at ~206-207) to mention MPI-backed RAFT comms and copy the expanded
host-data behavior text from the float overload into the double overload's
Doxygen so both public overloads have identical, complete documentation.

In `@cpp/src/cluster/detail/kmeans_batched.cuh`:
- Around line 109-118: Guard against zero/invalid wt_sum before dividing: before
computing and returning target / wt_sum, check if wt_sum is <= 0 or not finite
(use std::isfinite) and handle it (e.g., RAFT_LOG_ERROR/RAFT_LOG_WARN that
weights are zero/invalid and return T{1} to skip scaling) so you never perform
target / wt_sum; keep the existing RAFT_LOG_DEBUG path and only do the division
after the validity check of wt_sum and the existing tolerance check on wt_sum vs
target.

In `@cpp/src/cluster/detail/kmeans_mg_batched.cuh`:
- Around line 183-207: Replace the exact-equality check and add a
positive-weight validation: after computing global_n and global_wt (from
d_n_local/d_wt), compute a tolerance tol = static_cast<T>(global_n) *
std::numeric_limits<T>::epsilon(); use if (std::abs(global_wt -
static_cast<T>(global_n)) > tol) to decide scaling instead of global_wt !=
static_cast<T>(global_n); before computing weight_scale assert/validate that
global_wt > T{0} (e.g. RAFT_EXPECTS or the project's preferred check) so
zero/negative total weights fail early; update references to weight_scale,
sample_weight, SNMG_ALLREDUCE, d_wt, global_wt and global_n accordingly.
- Around line 338-346: The call to
cuvs::cluster::kmeans::detail::accumulate_batch_centroids<T, IdxT> is passing
weight_per_cluster where the callee expects cluster_counts (sample counts per
cluster), causing wrong centroid accumulation; fix by supplying a proper
cluster_counts view (e.g., create or compute a cluster_counts device_vector/view
separate from weight_per_cluster) and pass that as the 6th argument instead of
weight_per_cluster, or if weighted k-means is intended, update the
accumulate_batch_centroids signature and its internal use to accept and handle
weight_per_cluster consistently; adjust callers accordingly (referenced symbols:
accumulate_batch_centroids, weight_per_cluster, batch_counts, centroid_sums).

In `@cpp/src/cluster/detail/kmeans_mg.cuh`:
- Around line 655-672: Gate the inertia-based stopping test behind
params.inertia_check and fix the comparison to measure relative improvement:
only run the check when params.inertia_check is true and n_iter[0] > 1 and
priorClusteringCost > DataT{0}; if curClusteringCost < priorClusteringCost
compute DataT rel = (priorClusteringCost - curClusteringCost) /
priorClusteringCost and set done = true when rel <= params.tol; ensure that cost
increases (curClusteringCost >= priorClusteringCost) do not count as convergence
(leave done false), and still update priorClusteringCost = curClusteringCost
afterwards.

---

Outside diff comments:
In `@cpp/src/cluster/detail/kmeans_mg.cuh`:
- Around line 474-499: The code normalizes weights using a per-rank n_samples
(local target) instead of the global sample count, causing inconsistent scaling
after comm.allreduce of wt_aggr; update the logic to also reduce the local
sample count to a global count (e.g., perform an allreduce on n_samples into a
global_n_samples variable after or alongside
raft::linalg::mapThenSumReduce/comm.allreduce), then set target =
static_cast<DataT>(global_n_samples) and compute the scale factor as target /
wt_sum before calling raft::linalg::map to rescale weight; modify references to
target and ensure synchronization (raft::resource::sync_stream) remains correct.

---

Nitpick comments:
In `@cpp/tests/cluster/kmeans_mg_batched.cu`:
- Around line 121-132: Add a regression case that produces all-zero host weights
so the MG path rejection of wt_sum == 0 is exercised: inside the weight setup
for the kmeans_mg_batched fixture (where testparams_.weight_mode,
h_sample_weight, h_sw and n_samples are handled) add a branch for a new
weight_mode value (e.g., weight_mode == 4) that sets every h_sample_weight[i] to
T(0) and constructs h_sw via raft::make_host_vector_view<const T,
int64_t>(h_sample_weight.data(), n_samples); ensure the new mode is included in
the test parameterization so the test runs this all-zero-weight case.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 715dc4bb-4fb5-4063-b0f2-203992e27484

📥 Commits

Reviewing files that changed from the base of the PR and between f2bffb6 and a461fe9.

📒 Files selected for processing (11)
  • cpp/include/cuvs/cluster/kmeans.hpp
  • cpp/src/cluster/detail/kmeans_batched.cuh
  • cpp/src/cluster/detail/kmeans_common.cuh
  • cpp/src/cluster/detail/kmeans_mg.cuh
  • cpp/src/cluster/detail/kmeans_mg_batched.cuh
  • cpp/src/cluster/detail/minClusterDistanceCompute.cu
  • cpp/src/cluster/kmeans.cuh
  • cpp/src/cluster/kmeans_fit_double.cu
  • cpp/src/cluster/kmeans_fit_float.cu
  • cpp/tests/CMakeLists.txt
  • cpp/tests/cluster/kmeans_mg_batched.cu

Comment thread cpp/include/cuvs/cluster/kmeans.hpp
Comment thread cpp/src/cluster/detail/kmeans_batched.cuh Outdated
Comment thread cpp/src/cluster/detail/kmeans_mg_batched.cuh
Comment thread cpp/src/cluster/detail/kmeans_mg_batched.cuh
Comment thread cpp/src/cluster/detail/kmeans_mg.cuh
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

improvement Improves an existing functionality non-breaking Introduces a non-breaking change

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[FEA] Multi-node Multi-GPU Kmeans (C++) to support new out-of-core batching

2 participants