Skip to content

Fix: Correct Optax Adam state index out of bounds in notebook 01#5

Open
monsur161 wants to merge 1 commit intorcrowe-google:mainfrom
monsur161:patch-1
Open

Fix: Correct Optax Adam state index out of bounds in notebook 01#5
monsur161 wants to merge 1 commit intorcrowe-google:mainfrom
monsur161:patch-1

Conversation

@monsur161
Copy link
Copy Markdown

While working through the 01 - JAX AI Stack.ipynb notebook, I noticed a minor bug in the Optax optimizer state verification block.

Currently, the script checks adam_optax_internal_state[1] for the Adam mu state. However, because optax.adam chains ScaleByAdam and ScaleByLearningRate, the momentum tracking variables (count, mu, nu) are actually located in the ScaleByAdamState at index [0]. Index [1] currently returns an EmptyState(), which causes the hasattr check to fail and triggers the warning block.

I updated the hardcoded indices from [1] to [0] so the state validation correctly passes.

Updated hardcoded index from [1] to [0] to correctly access the ScaleByAdamState, as the Optax internal state tuple currently returns EmptyState() at index [1].
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant