
Make tokenizer_path to be non-mandatory #3367

Open
A9isha wants to merge 1 commit into main from anisha-tokenizer_path

Conversation


@A9isha A9isha commented Mar 10, 2026

Description

Simplify the parameters for running on MaxText by removing tokenizer_path as a required argument.

FIXES: b/490520651

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Ran the following commands locally on a v5p-8:

# With tokenizer path
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml   model_name=llama3.1-8b-Instruct   tokenizer_path=meta-llama/Llama-3.1-8B-Instruct   load_parameters_path=/path/to/checkpoint   run_name=maz-8b-$RANDOM   base_output_directory=/path/to/gcs/bucket   hf_access_token=<HF_TOKEN> dataset_name=gsm8k steps=4 rollout_tensor_parallelism=-1 rollout_data_parallelism=1 rollout_expert_parallelism=1

# Without tokenizer path
python3 -m src.maxtext.trainers.post_train.rl.train_rl src/maxtext/configs/post_train/rl.yml   model_name=llama3.1-8b-Instruct   load_parameters_path=/path/to/checkpoint   run_name=maz-8b-$RANDOM   base_output_directory=/path/to/gcs/bucket   hf_access_token=<HF_TOKEN> dataset_name=gsm8k steps=4 rollout_tensor_parallelism=-1 rollout_data_parallelism=1 rollout_expert_parallelism=1

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [ ] I have run end-to-end tests and provided workload links above if applicable.
  • [ ] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@SurbhiJainUSC
Collaborator

Can we do a similar change for SFT?

@A9isha A9isha force-pushed the anisha-tokenizer_path branch from f3cf026 to e176cf9 Compare March 10, 2026 18:36

codecov bot commented Mar 10, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@A9isha A9isha force-pushed the anisha-tokenizer_path branch 5 times, most recently from a9659e3 to ffa2895 Compare March 10, 2026 21:01

A9isha commented Mar 10, 2026

Can we do a similar change for SFT?

Yes, done now.

new_value = ""

if key == "tokenizer_path" and new_value is None:
  new_value = HF_IDS.get(raw_keys["model_name"])
Collaborator

What if MODEL_NAME is not present in HF_IDS?

Collaborator Author

Please see my response here

"llama2-13b",
"llama2-70b",
"llama3-8b",
"llama3.1-8b-Instruct",
Collaborator

Instead of explicitly adding the Instruct model name to this list, can we somehow derive it from the base model? In the future, we might need an Instruct model for gemma, and then we would have to update this list.
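The reviewer's suggestion could be sketched roughly as follows. This is a hypothetical illustration, not the MaxText implementation: the `HF_IDS` entries shown and the `derive_hf_id` helper are assumptions, relying only on the convention that Hugging Face Instruct repos append an `-Instruct` suffix to the base model id.

```python
# Hypothetical sketch: derive the HF id of an "-Instruct" variant from the
# base model's HF_IDS entry instead of listing every variant explicitly.
HF_IDS = {
    "llama3.1-8b": "meta-llama/Llama-3.1-8B",
    "gemma2-9b": "google/gemma-2-9b",
}

def derive_hf_id(model_name):
    """Return the HF id for model_name, deriving Instruct variants from the base entry."""
    if model_name in HF_IDS:
        return HF_IDS[model_name]
    suffix = "-Instruct"
    if model_name.endswith(suffix):
        base_id = HF_IDS.get(model_name[: -len(suffix)])
        if base_id is not None:
            return base_id + suffix
    return None
```

With this scheme, `derive_hf_id("llama3.1-8b-Instruct")` would resolve without an explicit list entry, so adding a gemma Instruct model later would not require touching the list.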

Collaborator

You have added qwen3-omni-30b-a3b-Instruct to HF_IDS, but it is not present in ModelName list.

Collaborator Author

Yes, the idea of this PR is to allow on-the-fly tokenizer_path derivation for the popular models used in the tutorials, to make things easier for first-time users. Passing tokenizer_path explicitly is still supported. For models not covered by this on-the-fly derivation, we will fall back to the previous default, maxtext/assets/tokenizers/tokenizer.llama2, and users who intended something else can pass tokenizer_path.

Great point about qwen3-omni-30b-a3b-Instruct, I will remove that change.
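The fallback behavior described in the comment above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual MaxText code: the function name `resolve_tokenizer_path` and the single-entry `HF_IDS` dict are hypothetical; only the fallback default path and the precedence (explicit tokenizer_path wins, then HF id, then the llama2 default) come from the discussion.

```python
# Sketch of the described fallback: an explicit tokenizer_path always wins;
# known models resolve to their Hugging Face id; everything else keeps the
# historical llama2 tokenizer default.
DEFAULT_TOKENIZER = "maxtext/assets/tokenizers/tokenizer.llama2"
HF_IDS = {"llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct"}

def resolve_tokenizer_path(model_name, tokenizer_path=None):
    if tokenizer_path:  # user override always takes precedence
        return tokenizer_path
    hf_id = HF_IDS.get(model_name)
    return hf_id if hf_id is not None else DEFAULT_TOKENIZER
```

This keeps the change backward compatible: runs that previously relied on the llama2 default are unaffected, and only tutorial models gain the automatic resolution.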

@A9isha A9isha force-pushed the anisha-tokenizer_path branch 4 times, most recently from 4c0f0ec to ba2ad97 Compare March 10, 2026 22:49
@A9isha A9isha marked this pull request as draft March 10, 2026 22:55
@A9isha A9isha force-pushed the anisha-tokenizer_path branch from ba2ad97 to 2759471 Compare March 10, 2026 23:17
@A9isha A9isha marked this pull request as ready for review March 10, 2026 23:21