TransformerLens is a Python library for Mechanistic Interpretability. It has some great tutorials, but they're all fairly verbose. Here's a cheatsheet of the most common things you'll want from the library. Click the links for more details.
Setup
!pip install git+https://github.com/TransformerLensOrg/TransformerLens
import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils
Creating a model
model = HookedTransformer.from_pretrained("gpt2-small")
cfg = HookedTransformerConfig(...)
model = HookedTransformer.from_config(cfg)
Models have arguments and methods very similar to standard PyTorch modules.
Full list of pretrained models. Example parameters for HookedTransformerConfig.
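For instance, a small toy config might look like this (the parameter values here are illustrative, not library defaults):
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    n_ctx=256,
    d_head=32,
    n_heads=4,
    d_vocab=50257,
    act_fn="gelu",
)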
Running a model
some_text = "..."
logits = model(some_text)                        # default return_type="logits"
logits, loss = model(some_text, return_type="both")
logits, cache = model.run_with_cache(some_text)  # cache is an ActivationCache
model.generate(some_text, max_new_tokens=50, temperature=0.7, prepend_bos=True)
Weights
model.blocks[0].attn.W_Q  # shape (n_heads, d_model, d_head)
model.W_Q                 # shape (n_layers, n_heads, d_model, d_head), stacked across layers
model.b_Q                 # shape (n_layers, n_heads, d_head)
Weight matrices multiply on the right, i.e. they have shape [input, output]. See diagram or reference for all weights.
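For instance, a minimal sketch of the convention (layer 0 and head 4 are arbitrary picks here):
import torch
x = torch.randn(1, 5, model.cfg.d_model)   # (batch, seq, d_model)
q = x @ model.W_Q[0, 4] + model.b_Q[0, 4]  # activations multiply on the left; result is (batch, seq, d_head)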
Working with ActivationCache
# Fully qualified
cache["blocks.0.attn.hook_pattern"]
# Short code using utils.get_act_name(name, layer[, layer_type])
layer = 0
cache["embed"] # token embeddings
cache["q", layer] # Query vectors for nth transformer block
# These two require you to specify which LayerNorm they refer to
cache["normalized", layer, "ln1"]
cache["scale", layer, "ln2"]
# Final LayerNorm
cache["normalized"]
cache["scale"]
You can use layer=-1 for the last layer.
Full set of short names:
| Name | aka | shape (excl. batch) |
|---|---|---|
| embed | | seq d_model |
| pos_embed | | seq d_model |
| resid_pre layer | | seq d_model |
| scale layer ln1 | | seq 1 |
| normalized layer ln1 | | seq d_model |
| q layer | query | seq head_idx d_head |
| k layer | key | seq head_idx d_head |
| v layer | value | seq head_idx d_head |
| attn_scores layer | attn_logits | head_idx seqQ seqK |
| pattern layer | attn | head_idx seqQ seqK |
| z layer | | seq head_idx d_head |
| result layer | | seq head_idx d_model |
| attn_out layer | | seq d_model |
| resid_mid layer | | seq d_model |
| scale layer ln2 | | seq 1 |
| normalized layer ln2 | | seq d_model |
| pre layer | mlp_pre | seq d_mlp (4*d_model) |
| post layer | mlp_post | seq d_mlp (4*d_model) |
| mlp_out layer | | seq d_model |
| resid_post layer | | seq d_model |
| scale | | seq 1 |
| normalized | | seq d_model |
See diagram for what each activation means.
All tensors start with a batch dimension, unless you ran the model with remove_batch_dim=True.
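As a quick shape check (a sketch, reusing the cache from run_with_cache above):
layer = 0
cache["pattern", layer].shape  # (batch, head_idx, seqQ, seqK)
cache["q", layer].shape        # (batch, seq, head_idx, d_head)
cache["resid_post", -1].shape  # (batch, seq, d_model)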
CircuitsVis
!pip install circuitsvis
import circuitsvis as cv
attn_pattern = cache["pattern", 0]  # circuitsvis expects (heads, seqQ, seqK); index out the batch dim first if present
tokens = model.to_str_tokens(some_text)
cv.attention.attention_patterns(tokens=tokens, attention=attn_pattern)
Hooks
import torch
from transformer_lens.hook_points import HookPoint

head_index_to_ablate = 4

def head_ablation_hook(value: torch.Tensor, hook: HookPoint) -> torch.Tensor:
    value[:, :, head_index_to_ablate, :] = 0.
    return value
model.run_with_hooks(some_text, fwd_hooks=[
    ("blocks.0.attn.hook_v", head_ablation_hook)
])

# Multiple hook points
model.run_with_hooks(some_text, fwd_hooks=[
    (lambda name: name.endswith("v"), head_ablation_hook)
])

# Using partial
from functools import partial

def head_ablation_hook2(value: torch.Tensor, hook: HookPoint, head_index: int) -> torch.Tensor:
    value[:, :, head_index, :] = 0.
    return value

model.run_with_hooks(some_text, fwd_hooks=[
    ("blocks.0.attn.hook_v", partial(head_ablation_hook2, head_index=4))
])
Don’t forget you can use utils.get_act_name to get hook names easily.
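For example:
utils.get_act_name("v", 0)        # "blocks.0.attn.hook_v"
utils.get_act_name("pattern", 2)  # "blocks.2.attn.hook_pattern"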
Tokens
some_text = "The fat cat"
model.get_token_position(" cat", some_text)  # position of the token " cat"
tokens = model.to_tokens(some_text)          # tensor of token ids, shape (batch, pos)
model.to_string(tokens)                      # tokens back to a string
model.to_str_tokens(some_text)               # list of token strings
Many token methods accept str, List[str], or a Long tensor of shape (batch, pos).
Note that a BOS (beginning of sequence) token is prepended during tokenization; this is a common gotcha. Consider using the kwarg prepend_bos=False.
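For example, with gpt2-small (output shown as an illustration):
model.to_str_tokens(some_text)                     # ['<|endoftext|>', 'The', ' fat', ' cat'] (note the BOS token)
model.to_str_tokens(some_text, prepend_bos=False)  # ['The', ' fat', ' cat']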