A Partial Replication of Distributed Alignment Search

This will probably be the first of many mechinterp-focused posts in the next few days. There are a couple of applications that I care about that are due very soon, and I want to share the work and thinking I’ve done as relevant to those applications.

The Colab Notebook and Slide Deck are linked here for convenience.

Disclaimer: The Colab notebook was coded by Claude Code. I separately presented the DAS paper for a class with Logan Graves and Siddarth Bhatia.

In November of this year, I applied to do research under the mentorship of Aryaman Arora, a third-year Computer Science PhD at Stanford and part-time researcher at Transluce. As part of the application process, I had to replicate a result from any paper of my choice. Since I had already read the entirety of the distributed alignment search (DAS) paper for a final presentation in a neuroscience x mechanistic interpretability class at Stanford, I thought I may as well try and replicate it for the application.

Causal Abstraction

Distributed Alignment Search (DAS) is a method for causal abstraction for mechanistic interpretability; thus, it makes sense to tldr causal abstraction before explaining how DAS works. The notation used to formally describe DAS and causal abstraction is obscenely verbose, and I hope to avoid this by hand-waving some of the formality in favor of intuitive simplicity.

In causal abstraction, we have a low-level model L that successfully achieves a particular task, and we hypothesize that L implements a high-level causal model H that identically achieves the same task. Though L and H are functionally equivalent—for any input x, we have L(x)=H(x) by assumption—it remains to be shown that L and H are computationally equivalent—that L and H perform the same algorithm under the hood.

A brief tangent: why do we care about causal abstraction in mechanistic interpretability? One of the primary goals of mechanistic interpretability for large neural networks is to explain how these networks perform so well on such a wide variety of tasks. For a model to consistently perform well on a given task across a variety of inputs, it likely implements some generalized algorithm that completes the task. Causal abstraction in mechanistic interpretability aims to abstract away specific details (such as how the model might represent or manipulate the inputs internally) into a high-level causal model. Different neural networks might represent the same input in different ways, but if they perform the same computations and algorithms, they have the same high-level causal model. By performing causal abstraction correctly, we get a clean view of how the neural network completes the task without worrying about the specific representations or other auxiliary operations.

So how do we perform causal abstraction and show computational equivalence between L and H? Well, if L and H really do implement the exact same algorithm under the hood, then you should be able to find a direct correspondence between how L and H represent intermediate variables (resulting from intermediate computations). An example goes a long way in clarifying this. Say you have some model L that can autocomplete any expression of the form “a + b + c = “. You might hypothesize that L somehow implements a high-level model H that does “add a + b, then add c.” If this were true, then that means we should be able to find some representation of “a + b” inside of L somewhere. Notice that multiple high-level models could complete this task - there might be another model H’ that adds “b + c” first. This is a critical point and is discussed later in this article.
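To make that concrete, here is a minimal toy sketch (my own illustration, not code from the paper) of the hypothesized high-level model H for this task, with its intermediate variable made explicit:

```python
# A toy high-level causal model H for the "a + b + c =" task.
# The intermediate variable "a + b" is recorded explicitly.

def high_level_model(a: int, b: int, c: int) -> dict:
    """Run H and return its intermediate variable alongside the output."""
    s = a + b      # intermediate variable: "a + b"
    out = s + c    # final output
    return {"intermediate": s, "output": out}

print(high_level_model(1, 3, 5))  # {'intermediate': 4, 'output': 9}
```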

So how do we know that we’ve actually found a true correspondence between intermediate variables of L and H? Again, an example is illustrative. Going back to our “a + b + c =” autocomplete task, let’s say we have two examples: “1 + 3 + 5 =” and “9 + 7 + 5 =”. If L actually adds the first two variables first, then L(“1 + 3 + 5 =”) should have some representation of “4” (coming from 1 + 3) inside of its computations. Likewise, L(“9 + 7 + 5 =”) should have some representation of “16” (coming from 9 + 7). Say we think we know where L stores the intermediate variable “a + b”. Then if we run L(“1 + 3 + 5 =”) but replace the representation of “4” with the representation of “16”, we should see L output 21. This is called an interchange intervention, and it’s the same idea underlying claims of causation in science: if we change one thing, hold everything else constant, and see a change in output, then the thing we changed must be responsible for the change in output.
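Here is a toy sketch of that interchange intervention, again with made-up names of my own; it just splices the intermediate value from a source run into a base run:

```python
# Interchange intervention on the toy high-level model: take the intermediate
# value computed on a "source" input and patch it into a run on a "base" input.

def run_with_intervention(a, b, c, patched_intermediate=None):
    s = a + b if patched_intermediate is None else patched_intermediate
    return s + c

base = (1, 3, 5)     # "1 + 3 + 5 ="
source = (9, 7, 5)   # "9 + 7 + 5 ="

source_intermediate = source[0] + source[1]  # 16
patched_output = run_with_intervention(*base, patched_intermediate=source_intermediate)
print(patched_output)  # 21 -- the output we expect from L if it really stores "a + b"
```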

If we do many interchange interventions across all of the intermediate variables and find that the outputs of L change exactly as we expect, we say that we have “high interchange intervention accuracy”. If we have high interchange intervention accuracy, then we have good evidence to believe that L indeed implements H via the intermediate variables we identified.
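In code, the bookkeeping might look something like this sketch, where `run_L_patched` and `run_H_patched` are hypothetical helpers (not from the paper) that perform the same intervention on L and H and return the resulting outputs:

```python
# Sketch of how interchange intervention accuracy (IIA) is tallied.
# `pairs` is an iterable of (base, source) input pairs; the two helpers are
# assumed to patch the same intermediate variable in L and H respectively.

def interchange_intervention_accuracy(pairs, run_L_patched, run_H_patched):
    """Fraction of (base, source) pairs where L's patched output matches H's."""
    hits = sum(
        run_L_patched(base, source) == run_H_patched(base, source)
        for base, source in pairs
    )
    return hits / len(pairs)
```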

DAS

So what does DAS do? DAS is a specific method for finding the correct internal representation of intermediate variables in L when L is a neural network. It assumes that the low-level model stores these intermediate variables in disjoint orthogonal subspaces of activation space, but that the default neuron basis is not a privileged one. DAS posits that the correct rotation matrix will rotate the neuron basis into exactly the orthogonal basis whose subspaces store the intermediate variables. Thus, the approach trains a rotation matrix at each layer to maximize interchange intervention accuracy (various losses work here; cross-entropy is a natural one) in order to find which layers might store relevant intermediate variables. We designate in advance which subspaces of the rotated basis should hold which intermediate variables, so that the rotation matrix learns to align those subspaces with the high-level variables we care about. Then we see which layers and neurons have the highest interchange intervention accuracy after being rotated into this new basis, and we claim that this is how L stores the corresponding high-level intermediate variable.
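To make this less abstract, here is a rough sketch of what a single DAS training step might look like in PyTorch. The toy model, its split into `model_head` and `model_tail`, and the dimensions and subspace size are all illustrative assumptions on my part, not the paper's (or my notebook's) actual code:

```python
# Sketch of one DAS training step: learn an orthogonal rotation so that swapping
# a fixed subspace of rotated activations reproduces the counterfactual behavior.
import torch
import torch.nn as nn

d_in, d_hidden, n_classes, k = 8, 16, 2, 4  # k = size of the subspace reserved for one variable

# Stand-ins for a frozen, already-trained L, split at the layer we want to probe.
model_head = nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU())  # input -> hidden
model_tail = nn.Linear(d_hidden, n_classes)                       # hidden -> logits
for p in list(model_head.parameters()) + list(model_tail.parameters()):
    p.requires_grad_(False)

# The learnable rotation, kept orthogonal via PyTorch's parametrization utility.
rotation = nn.utils.parametrizations.orthogonal(nn.Linear(d_hidden, d_hidden, bias=False))
opt = torch.optim.Adam(rotation.parameters(), lr=1e-3)

def das_step(base_x, source_x, counterfactual_y):
    """One interchange-intervention training step for the rotation matrix."""
    h_base, h_source = model_head(base_x), model_head(source_x)
    R = rotation.weight                              # orthogonal d_hidden x d_hidden matrix
    r_base, r_source = h_base @ R.T, h_source @ R.T  # rotate into the learned basis
    # Swap the first k rotated coordinates: the hypothesized variable's subspace.
    patched = torch.cat([r_source[:, :k], r_base[:, k:]], dim=-1)
    logits = model_tail(patched @ R)                 # rotate back, finish the forward pass
    loss = nn.functional.cross_entropy(logits, counterfactual_y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

# Example call with random data; counterfactual_y is the label H would give after the swap.
das_step(torch.randn(32, d_in), torch.randn(32, d_in), torch.randint(0, n_classes, (32,)))
```

The cross-entropy loss against the counterfactual label is exactly the "maximize interchange intervention accuracy" objective in differentiable form: the rotation is rewarded when the patched low-level model produces whatever the high-level model would produce after the same swap.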

I mentioned that hypothesis-based causal abstraction (DAS is an example) has a major weakness, which is apparent in the name: a hypothesis for the high-level causal model H is required in order to execute the causal abstraction method. In other words, DAS and other hypothesis-based causal abstraction methods are verification protocols for causal abstraction. I’ve not personally read up on the causal abstraction literature, but I imagine there are approaches that discover high-level models in neural networks (or other equivalent low-level models) in an unsupervised manner (many people are interested in building interpretability agents for such a process). Scaling up such an approach efficiently would be revolutionary for mechanistic interpretability.

What I Replicated

I decided to replicate their results for the hierarchical equality task (see this for visual reference, or the DAS paper), which has a natural high-level causal model H. I started with a randomly initialized network (our L) and showed that training DAS for 10 epochs did not discover any intermediate variables. I then trained L on the hierarchical equality task until it achieved above 99.9% accuracy. Training a new rotation matrix on this trained network, DAS found intermediate variables with interchange intervention accuracy consistently above 95% across different initializations, showing that DAS does find the hypothesized intermediate variables once the network has learned the task.
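For reference, here is roughly how hierarchical equality examples can be generated; this is my own paraphrase of the task (the label is true exactly when the two pairs agree about equality), not the exact data pipeline from the paper or my notebook:

```python
# Toy generator for the hierarchical equality task: input is four object
# vectors (a, b, c, d); the label is True iff (a == b) == (c == d).
import numpy as np

rng = np.random.default_rng(0)
dim = 4  # dimensionality of each object vector (illustrative)

def make_example():
    # Sample four object vectors; within each pair, duplicate with probability 0.5.
    a = rng.normal(size=dim)
    b = a.copy() if rng.random() < 0.5 else rng.normal(size=dim)
    c = rng.normal(size=dim)
    d = c.copy() if rng.random() < 0.5 else rng.normal(size=dim)
    # Label: True exactly when the two pairs agree about equality.
    label = np.array_equal(a, b) == np.array_equal(c, d)
    return np.concatenate([a, b, c, d]), label

x, y = make_example()
print(x.shape, y)  # (16,) and a boolean label
```

The natural high-level model H here has two intermediate variables, "are the first two objects equal?" and "are the last two objects equal?", which is what makes the task a clean testbed for DAS.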

Thoughts and Reflections

The most difficult part of this project by far was understanding the methodology; even without Claude Code, it would have been far easier to code up the logic than to understand causal abstraction and the motivations of DAS. This was my first real work tinkering with models and reading a mechinterp paper in depth. Causal abstraction seems really important for mechanistic interpretability and is exciting to me, but I’m not the most convinced by DAS. The original paper showed that expanding the hidden dimension of the neural network allowed DAS to find small but non-negligible structure even in randomly initialized neural networks (64% interchange intervention accuracy on the hierarchical equality task). Moreover, DAS relies on a rotation matrix, which preserves dimensionality when looking for semantic meaning. However, the superposition literature strongly indicates that models hold far more concepts than they have dimensions, so it’s not clear that DAS should be sufficient for more complicated tasks, especially those with more intermediate variables. The reliance on a hypothesized high-level model, discussed above under supervised versus unsupervised causal abstraction, is another weakness of DAS.

For a mere replication project that was mostly about understanding a limited methodology, I enjoyed myself and the process far more than I expected. Seeing the successful replication scratched the same kind of itch I get when successfully coding something up, and the debugging process (even if Claude did most of the debugging) felt very familiar and comfortable. I want to do more and see if this joy scales when I move on to more substantial projects.