Neuron to Graph
Interpreting Language Model Neurons at Scale
Alex Foote1*, Neel Nanda2, Esben Kran1, Ioannis Konstas3, Shay Cohen4, Fazl Barez1,4,5*
1 Apart Research
2 Independent
3 School of Mathematical and Computer Sciences, Heriot-Watt University
4 School of Informatics, University of Edinburgh
5 University of Oxford
* Equal contribution
Understanding the function of individual neurons within language models is essential for mechanistic interpretability research. We propose Neuron to Graph (N2G), a tool that takes a neuron and its dataset examples and automatically distills the neuron's behaviour on those examples into an interpretable graph. This is a less labour-intensive approach to interpreting neurons than current manual methods, and it will help such methods scale to large language models (LLMs). We use truncation and saliency methods to present only the important tokens, and we augment the dataset examples with more diverse samples to better capture the full extent of the neuron's behaviour. The resulting graphs can be visualised to aid manual interpretation by researchers. They can also predict token activations on text, and these predictions can be compared with the neuron's ground-truth activations for automatic validation. N2G represents a step towards scalable interpretability methods by allowing us to convert neurons in an LLM into interpretable representations of measurable quality.
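The truncation and saliency steps can be illustrated with a minimal sketch. It assumes a hypothetical `activation` callable that returns the target neuron's activation on the final token of a prompt, a 50% retention threshold for pruning (`keep_fraction`), and token importance measured as the drop in activation when a token is replaced by a hypothetical `<pad>` token. These are illustrative choices, not necessarily N2G's exact settings, and `toy_activation` is a made-up neuron used only for demonstration.

```python
from typing import Callable, List, Sequence

# Hypothetical stand-in for "run the model on these tokens and read the
# target neuron's activation on the final token".
ActivationFn = Callable[[Sequence[str]], float]


def prune_prompt(tokens: List[str], activation: ActivationFn,
                 keep_fraction: float = 0.5) -> List[str]:
    """Return the shortest suffix of `tokens` whose final-token activation is
    at least `keep_fraction` of the full prompt's activation.

    Assumes `tokens` already ends at the token where the neuron fires.
    """
    target = keep_fraction * activation(tokens)
    # Try suffixes from shortest to longest; the first one that keeps
    # enough activation is the pruned prompt.
    for start in range(len(tokens) - 1, -1, -1):
        suffix = tokens[start:]
        if activation(suffix) >= target:
            return suffix
    return tokens


def token_importance(tokens: List[str], activation: ActivationFn,
                     pad: str = "<pad>") -> List[float]:
    """Importance of each token = drop in activation when it is replaced by `pad`."""
    base = activation(tokens)
    scores = []
    for i in range(len(tokens)):
        masked = tokens[:i] + [pad] + tokens[i + 1:]
        scores.append(base - activation(masked))
    return scores


def toy_activation(tokens: Sequence[str]) -> float:
    """Made-up neuron: fires on a final 'except', more strongly after 'all'."""
    if not tokens or tokens[-1] != "except":
        return 0.0
    return 2.0 if "all" in tokens[:-1] else 0.5


pruned = prune_prompt(["We", "tried", "all", "options", "except"], toy_activation)
print(pruned)  # ['all', 'options', 'except']
print(token_importance(pruned, toy_activation))  # [1.5, 0.0, 2.0]
```

Pruning keeps "all" because dropping it would halve the toy neuron's activation below the threshold, while the saliency scores show that "options" contributes nothing and can be treated as unimportant.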
Figure 1: Overall architecture of N2G. Activations of the target neuron on the dataset examples are retrieved (neuron and activating tokens in red). Prompts are pruned, and the importance of each token for the neuron's activation is measured (important tokens in blue). Pruned prompts are augmented by replacing important tokens with high-probability substitutes from BERT. The augmented set of prompts is converted into a graph. The output graph shown is a real example that activates on the token “except” when it is preceded by any of the other tokens.
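The final step of the pipeline, converting prompts into a graph that can predict activations, can be sketched as a trie read backwards from the activating token. This is a simplified illustration under stated assumptions, not N2G's actual implementation: unimportant tokens become a hypothetical wildcard `IGNORE` that matches exactly one arbitrary token, the importance threshold of 0.5 is arbitrary, and `TrieNode`, `to_pattern`, `add_pattern` and `predict` are names invented for this example. The importance scores in the usage lines are hand-written illustrative values.

```python
from typing import Dict, List, Optional, Sequence

IGNORE = "*"  # hypothetical wildcard for tokens judged unimportant


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[str, "TrieNode"] = {}
        self.activation: Optional[float] = None  # set where a pattern ends


def to_pattern(tokens: Sequence[str], importance: Sequence[float],
               threshold: float = 0.5) -> List[str]:
    """Keep important tokens, wildcard the rest; always keep the activating token."""
    pattern = [t if s >= threshold else IGNORE for t, s in zip(tokens, importance)]
    pattern[-1] = tokens[-1]
    return pattern


def add_pattern(root: TrieNode, pattern: Sequence[str], activation: float) -> None:
    """Insert a pattern read backwards, starting from the activating (last) token."""
    node = root
    for tok in reversed(pattern):
        node = node.children.setdefault(tok, TrieNode())
    # Keep the strongest activation if several patterns end at the same node.
    if node.activation is None or activation > node.activation:
        node.activation = activation


def match(node: TrieNode, tokens: Sequence[str], pos: int) -> float:
    """Best activation of any stored pattern, scanning leftwards from tokens[pos]."""
    best = node.activation if node.activation is not None else 0.0
    if pos < 0:
        return best
    # Follow both the exact-token edge and the wildcard edge, if present.
    for key in (tokens[pos], IGNORE):
        child = node.children.get(key)
        if child is not None:
            best = max(best, match(child, tokens, pos - 1))
    return best


def predict(root: TrieNode, tokens: Sequence[str]) -> List[float]:
    """Predicted activation for every token position in `tokens`."""
    return [match(root, tokens, i) for i in range(len(tokens))]


root = TrieNode()
add_pattern(root, to_pattern(["all", "options", "except"], [1.5, 0.0, 2.0]), 2.0)
add_pattern(root, ["everything", "except"], 1.0)
print(predict(root, ["take", "all", "items", "except", "this"]))
# [0.0, 0.0, 0.0, 2.0, 0.0]
```

Because the wildcard stands in for exactly one token, "all items except" matches the stored pattern "all * except", while an unseen context such as "take except" would predict no activation.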
Figure 2: An example of a graph built from Neuron 2 of Layer 1 of the model.
Table 1: Precision, recall and F1-score of the neuron graphs' token-level predictions of neuron firing, compared with ground truth on held-out test data, for 50 random neurons from each layer of the model. Tokens on which the real neuron fired and tokens on which it did not fire are evaluated separately: a neuron typically fails to fire on far more tokens than it fires on, so always predicting that the neuron will not fire would trivially achieve near-perfect scores.
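The separate evaluation of firing and non-firing tokens can be made concrete with a short sketch. It assumes a token "fires" when its activation exceeds a hypothetical absolute `threshold`; the activation values below are invented for illustration and are not data from Table 1.

```python
from typing import Dict, List, Sequence


def binarize(acts: Sequence[float], threshold: float) -> List[bool]:
    """A token counts as firing when its activation exceeds `threshold` (assumed rule)."""
    return [a > threshold for a in acts]


def prf(pred: Sequence[bool], truth: Sequence[bool], positive: bool) -> Dict[str, float]:
    """Precision, recall and F1, treating `positive` as the class of interest."""
    tp = sum(p == positive and t == positive for p, t in zip(pred, truth))
    fp = sum(p == positive and t != positive for p, t in zip(pred, truth))
    fn = sum(p != positive and t == positive for p, t in zip(pred, truth))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Invented ground truth: the neuron fires on 2 of 10 tokens.
truth = binarize([0.0, 3.1, 0.2, 0.0, 2.5, 0.0, 0.0, 0.1, 0.0, 0.0], threshold=1.0)

# A trivial baseline that never predicts firing.
always_off = [False] * len(truth)
print(prf(always_off, truth, positive=True))  # {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
print(prf(always_off, truth, positive=False)["precision"])  # 0.8
```

The baseline looks strong on non-firing tokens (precision 0.8, recall 1.0) yet scores zero on firing tokens, which is why reporting the two classes separately exposes predictors that simply never fire.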
Work in progress