This is the repository for the paper *Discovering Interpretable Algorithms by Decompiling Transformers to RASP*. All of our scripts are in `crasp/scripts/patching/`. Below we briefly describe what they are used for.
- Model Training
  - `train_new_models.py` defines classes for the collator, models with BCE loss for training on formal languages, etc.
  - `patching_data.py` defines all datasets for algorithmic and formal-language tasks.
  - `train_new_models_search.py` searches over hyperparameter combinations and trains models for a given task.
  - `train_new_models_checkpoints.py` saves checkpoints during model training.
  - `train_new_models_archs.py` trains models with more hyperparameter combinations, because `train_new_models_search.py` may stop early once it finds a generalizing model.
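The early-stopping search can be illustrated with a minimal sketch. This is not the repo's API: `train_and_eval`, `GENERALIZATION_THRESHOLD`, and the grid values below are hypothetical stand-ins for the actual training and evaluation code.

```python
import itertools

# Hypothetical threshold: a model counts as "generalizing" above this accuracy.
GENERALIZATION_THRESHOLD = 0.99

def train_and_eval(lr, n_layers, seed):
    """Stand-in for training a model and returning its held-out accuracy.
    A real implementation would train on short sequences and evaluate on
    longer, out-of-distribution lengths."""
    return 1.0 if (lr, n_layers) == (1e-3, 2) else 0.5

def search(grid):
    """Iterate over all hyperparameter combinations; stop at the first
    combination that produces a generalizing model."""
    for lr, n_layers, seed in itertools.product(*grid):
        acc = train_and_eval(lr, n_layers, seed)
        if acc >= GENERALIZATION_THRESHOLD:
            return {"lr": lr, "n_layers": n_layers, "seed": seed, "acc": acc}
    return None  # no combination generalized

best = search([[1e-4, 1e-3], [1, 2], [0, 1]])
```

Because the loop returns on the first success, later combinations are never trained — which is why `train_new_models_archs.py` exists to cover the rest of the grid.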
- Causal Pruning
  - `pruning_model.py` defines classes for models that run with different kinds of computational graphs in the different pruning stages. To understand the code, check Algorithm 1 in Li and Janson (2024).
    - `PruningModelWithHooks` defines a model where each component is prunable (stage 1).
    - `PruningModelWithHooksFullPaths` defines a model where each path is prunable (stage 2).
    - `PruningModelWithHooksForQK` defines a model where each QK product (selector) is prunable (stage 3).
    - `MaskSampler`, `MaskSamplerFullPaths`, and `MaskSamplerForQK` define classes holding the parameters that control the distribution of masks.
    - `OptimalAblationVectors` and `OptimalQueryBiasVectors` contain various learned constants, e.g., optimal ablations.
    - There are also various functions for converting the computational graph (config) between stages.
  - `patching.py` runs the causal pruning.
  - `delineate_curve_for_model.py` implements the automatic search over pruning coefficients. It automatically submits jobs to HTCondor (a job queuing system on clusters) and determines the next coefficients to use based on previous results. Each submitted job is a run of `patching.py` with certain hyperparameters.
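The core idea of differentiable mask-based pruning can be sketched as follows. This is an illustrative, self-contained example in the spirit of the `MaskSampler` classes, not the repo's implementation: the relaxation, temperature, and the replacement of pruned components with learned ablation constants are generic choices, and all names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_masks(logits, temperature=0.5):
    """Relaxed Bernoulli (Gumbel-sigmoid) sample of one mask per prunable
    component; the logits are the learnable pruning parameters."""
    u = rng.uniform(1e-6, 1 - 1e-6, size=np.shape(logits))
    noise = np.log(u) - np.log1p(-u)          # standard logistic noise
    return 1.0 / (1.0 + np.exp(-(np.asarray(logits) + noise) / temperature))

def gate(activation, mask, ablation_value):
    """Interpolate between the live activation and a learned ablation
    constant (cf. OptimalAblationVectors) instead of ablating to zero."""
    return mask * activation + (1.0 - mask) * ablation_value

# Positive logit -> component likely kept; negative -> likely pruned.
m = sample_masks(np.array([4.0, -4.0]))
out = gate(np.array([1.0, 1.0]), m, np.array([0.0, 0.3]))
```

With masks near 0 or 1 the model follows a discrete computational graph, which is what the stage-wise pruning extracts.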
- Primitive Matching for MLPs
  - `convert_mlp.py` implements tracing activation variables and iterating over the primitives to match each MLP with the best one.
  - `patching_utils.py` defines the MLP primitives.
  - `show_heatmap.py` implements (1) the backup approach for explaining MLPs, that is, saving input activation variables and their effects on output logits or attention logits, and (2) tracing activation variables and the op matrix ($A$).
  - `run_convert_mlp.py` runs the above process on the saved pruned models and saves new results for this step.
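The matching step can be sketched as an exhaustive scoring loop. This is a simplified illustration, assuming boolean activation variables and a hypothetical primitive set; the actual primitives live in `patching_utils.py` and the real loop in `convert_mlp.py`.

```python
import numpy as np

# Hypothetical primitive set; the repo defines its own MLP primitives.
PRIMITIVES = {
    "and": lambda a, b: a & b,
    "or":  lambda a, b: a | b,
    "xor": lambda a, b: a ^ b,
}

def best_primitive(inputs, outputs):
    """inputs: (N, 2) boolean array of traced activation-variable values;
    outputs: (N,) boolean array of the MLP's (binarized) outputs.
    Returns the primitive whose truth table matches the trace best."""
    scores = {
        name: float(np.mean(f(inputs[:, 0], inputs[:, 1]) == outputs))
        for name, f in PRIMITIVES.items()
    }
    name = max(scores, key=scores.get)
    return name, scores[name]

x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=bool)
y = x[:, 0] ^ x[:, 1]          # a traced MLP that behaves like XOR
name, acc = best_primitive(x, y)
```

When no primitive reaches a perfect (or good enough) score, the backup approach in `show_heatmap.py` is used to explain the MLP instead.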
- Primitive Matching for Unembedding and Attention
  - `attention_primitives_hook.py` contains a modified attention forward pass that implements replacement with primitives.
  - `logits_primitives_hook.py` contains a forward hook for the unembedding projection that implements replacement with primitives.
  - `find_primitives_for_all_runs.py` is an entry point used to run the replacement of attention and unembedding with primitives.
  - `find_primitives.py` contains a function implementing the general logic around the replacement of attention and unembedding with primitives.
  - `try_primitives.py` contains a function implementing the forward pass of a model with primitives and calculating metrics.
  - `primitives_classes.py` defines dataclasses used for primitive replacement.
  - `primitives_for_coefficients.py` defines the set of primitives.
  - `primitives_helpers.py` contains various helper functions.
  - `primitives_search.py` implements the algorithms for primitive replacement: 'greedy' stands for replacement with predefined primitives, and 'round' for rounding the original heatmaps. In the paper we use 'greedy_then_round', which replaces with predefined primitives where possible and rounds the rest.
  - `round_primitive.py` implements a class and training loop for rounding heatmaps.
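The 'greedy_then_round' strategy can be sketched as a two-step fallback. This is an illustrative simplification: the tolerance, the matching criterion, and the primitive set below are hypothetical, and the real rounding involves a training loop (`round_primitive.py`) rather than plain `np.round`.

```python
import numpy as np

def greedy_then_round(heatmap, primitives, tol=0.05):
    """heatmap: 2D array; primitives: dict name -> 2D template of the same
    shape. Greedy step: accept the first predefined primitive that is
    elementwise close to the heatmap. Fallback: round the heatmap itself."""
    for name, template in primitives.items():
        if np.abs(heatmap - template).max() <= tol:   # greedy match
            return name, template
    return "rounded", np.round(heatmap)               # round the rest

prims = {"identity": np.eye(2)}                       # hypothetical primitive
label, result = greedy_then_round(
    np.array([[0.98, 0.01], [0.03, 1.00]]), prims
)
```

The first heatmap is close enough to the identity template and gets the greedy replacement; a heatmap far from every template would fall through to rounding.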
- Generating D-RASP Code
  - `convert_to_code.py` implements extracting D-RASP code from a model.
  - `get_programs_and_heatmaps.py` is an entry point used to copy all the relevant heatmaps to one place, generate D-RASP code, and produce a LaTeX file with the program visualization.
  - `plot_for_mlp_interpretation.py` contains a function for plotting interpretations of the MLPs that were not replaced with primitives; it is used in `get_programs_and_heatmaps.py`.
- Results Visualization
  - `streamlit_app/app.py` implements a Streamlit app that visualizes the decompilation results, where you can check the Pareto frontiers, decompiled programs, and various heatmaps.
  - `streamlit_app/draw_fig_xxx.py` and `streamlit_app/get_xxx.py` draw figures for the various Pareto frontiers and generalization results shown in the paper.
  - `plot_example_for_main_paper.py` is an entry point used for plotting attention and unembedding primitives for the main paper.
  - `plot_func_for_main_paper.py` implements the plotting function for `plot_example_for_main_paper.py`.
  - `plot_for_mlp_main_paper_d4.py` and `plot_for_mlp_main_paper_sort.py` are used to plot the MLP interpretation figures in the main paper.
  - `print_all_primitives.py` generates and plots examples of all the predefined primitives.
  - `run_pretty_example.py` is an entry point used to run a model on a specific short example to collect pretty heatmaps of activation variables.
The folder `share/saved_models` contains all the models on which we run our decompilation pipeline. We share them so that they can serve as a small benchmark: future research may compare new methods with ours by applying them to the same set of models.

Our implementation makes use of a job queuing system, so if your environment is different, running it directly without modification or adaptation will not work. Nonetheless, it is helpful to understand the basic running order of our scripts:
- (optional) Train models to be decompiled later, or directly use the models in `share/saved_models`.
- Run `delineate_curve_for_model.py` to perform pruning, which produces many pruned models with different degrees of sparsity. Check the frontiers with `streamlit_app/app.py` to see if LLNA holds. If the highest match accuracy becomes low in stage 2, set `--split_mlp=False` and run again.
- Run `run_convert_mlp.py` to replace MLPs with primitives.
- Look at the frontiers again and create a `good_models.json` file that contains a dictionary in the same format as shown in this repo. For each series, this file selects a specific pruning run, for which we will then generate D-RASP code. Make sure that this pruning run is successful in stage 3; otherwise it cannot be visualized with the current code. If none of the pruning runs are successful in stage 3, keep the entry for the series in `good_models.json` empty. For the paper, we selected the pruning runs with the smallest number of edges that achieved a match accuracy of at least 0.9 in stage 3.
- Run `find_primitives_for_all_runs.py` to replace attention and unembedding in the pruning runs from `good_models.json` with primitives. Specify the correct paths in this file before running.
- Run `run_pretty_example.py` to collect pretty plots of the heatmaps of activation variables, used later for program visualization. Update the paths and dictionaries in this file before running.
- Run `get_programs_and_heatmaps.py` to create LaTeX files with visualizations of the programs. This script copies the relevant heatmaps to a specified location, generates D-RASP code for each model, and creates LaTeX files with heatmaps and programs. The output of this script is used in the Appendix of the paper.