
Add Indexer training loss#3364

Open
RissyRan wants to merge 1 commit into main from indexer_loss

Conversation

@RissyRan
Collaborator

@RissyRan RissyRan commented Mar 10, 2026

Description

Add Indexer training loss for DeepSeek V3.2:

  • Add a loss for the Indexer module based on the reference paper (another reference: Megatron-LM - link).
  • Rename all Index occurrences to Indexer for consistency across MaxText.
  • Full integration will take 3 stages; this PR is the 1st.
  • Add unit tests as a sanity check.
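The indexer loss trains the Indexer's selection scores to match the main attention distribution. A minimal sketch of this idea (illustrative only, not the actual MaxText implementation; `indexer_kl_loss` and its arguments are hypothetical names, and the exact masking/normalization in this PR may differ):

```python
import jax
import jax.numpy as jnp

def indexer_kl_loss(indexer_logits, attention_logits, mask, scaling_factor=0.01):
    """KL(attention || indexer) summed over keys, averaged over queries.

    Shapes: [batch, q_len, kv_len]; mask is True at valid key positions.
    """
    neg_inf = jnp.finfo(attention_logits.dtype).min
    # Mask once, right before softmax, so large-negative biases never accumulate.
    target = jax.nn.softmax(jnp.where(mask, attention_logits, neg_inf), axis=-1)
    # The indexer chases the frozen attention distribution, not vice versa.
    target = jax.lax.stop_gradient(target)
    log_pred = jax.nn.log_softmax(jnp.where(mask, indexer_logits, neg_inf), axis=-1)
    log_target = jnp.log(jnp.clip(target, 1e-9, 1.0))
    # Masked positions contribute zero to the per-query KL sum.
    kl = jnp.sum(jnp.where(mask, target * (log_target - log_pred), 0.0), axis=-1)
    return scaling_factor * jnp.mean(kl)
```

The scaling factor corresponds to the `indexer_loss_scaling_factor=0.01` flag in the run command below; the scaled loss would then be added to the total loss alongside the existing MTP/MoE auxiliary losses.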

Tests

  • All tests are expected to be green.
  • Verified deepseek32_vs_reference_test matches the main branch (b/491486716).
  • End-to-end training test is functional - link
  • TensorBoard shows the loss is recorded properly - link

Run cmd

export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/no-mhc-change/$(date +%Y-%m-%d)
export RUN_NAME=ds-loss-verify

python3 -m MaxText.train maxtext/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} run_name=${RUN_NAME} \
  per_device_batch_size=8 enable_checkpointing=false async_checkpointing=false \
  model_name=deepseek-custom ici_fsdp_parallelism=4 steps=100 max_target_length=1024 \
  dtype=bfloat16 weight_dtype=bfloat16 scan_layers=True attention=flash \
  train_split=train dataset_type=hf hf_path='HuggingFaceFW/fineweb-edu' hf_name=default hf_access_token=<> \
  tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
  mhc_expansion_rate=4 data_shuffle_seed=1234 use_tokamax_splash=True \
  indexer_loss_scaling_factor=0.01 log_period=5 enable_tensorboard=True

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 14.28571% with 48 lines in your changes missing coverage. Please review.

Files with missing lines             | Patch %  | Lines
src/maxtext/layers/attention_mla.py  | 10.34%   | 26 Missing ⚠️
src/maxtext/trainers/pre_train/train.py | 15.38% | 10 Missing and 1 partial ⚠️
src/maxtext/layers/attention_op.py   | 18.18%   | 7 Missing and 2 partials ⚠️
src/maxtext/common/metric_logger.py  | 0.00%    | 2 Missing ⚠️


@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


## 📋 Review Summary

This PR successfully implements the DeepSeek V3.2 Indexer training loss within MaxText. The core logic for calculating and propagating the KL divergence loss correctly integrates with the existing MTP and MoE loss aggregation patterns.

🔍 General Feedback

  • The indexer_loss is accumulated and scaled correctly through train.py and gradient accumulation mechanisms.
  • Tests thoroughly cover both dense and sparse configurations.
  • Be mindful of mask aggregation; applying attention_mask multiple times via indexer_mask can lead to numerical instability (-inf resulting in NaN loss) during softmax operations, which has been highlighted inline.
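The masking pitfall flagged above can be reproduced in a few lines. This is an illustrative sketch, not MaxText's attention code: `naive_softmax` is a hypothetical stand-in, and the additive large-negative mask bias mimics what happens when `attention_mask` and `indexer_mask` both bias the same positions.

```python
import jax.numpy as jnp

def naive_softmax(x):
    # Standard shift-by-max softmax; NaNs out when a row is entirely -inf,
    # because -inf - (-inf) = NaN before the exp.
    z = x - jnp.max(x, axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / jnp.sum(e, axis=-1, keepdims=True)

logits = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.float32)
mask_bias = jnp.finfo(jnp.float32).min  # ~ -3.4e38, the usual "minus infinity" stand-in

once = naive_softmax(logits + mask_bias)              # bias applied once: still finite
twice = naive_softmax(logits + mask_bias + mask_bias) # applied twice: float32 overflows to -inf -> NaN
```

Applying the bias once keeps every entry finite (the row stays at the float32 minimum, so the softmax is well defined), while adding it a second time overflows to -inf and poisons the whole row, which is exactly the NaN-loss failure mode to guard against when combining masks.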

@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions bot Mar 10, 2026
@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions bot Mar 10, 2026
@RissyRan RissyRan force-pushed the indexer_loss branch 2 times, most recently from 4fbb511 to 0b401fa Compare March 10, 2026 22:46
Collaborator

@shuningjin shuningjin left a comment


Thanks for the thoughtful implementation of indexer loss and comprehensive testing! Overall looks good to me. Left some minor comments.

@RissyRan RissyRan force-pushed the indexer_loss branch 3 times, most recently from eceefcb to ce58515 Compare March 11, 2026 18:03
