A Python package for tracing attention paths backward through transformer models, with interactive D3.js Sankey visualization.
- Reverse attention tracing: Trace which tokens most influence a target token by following attention paths backward
- Beam search: Efficiently explore multiple high-probability paths through the attention matrix
- Interactive visualization: D3.js-powered Sankey diagrams with zoom, pan, and click-to-highlight
- Qwen2 support: Optimized for Qwen2 family models (works with other HuggingFace transformers)
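The core idea: walk the attention matrix backward from the target token, and at each step keep only the strongest incoming edges (`top_k`) and the best-scoring partial paths (`top_beam`). A minimal, library-free sketch of this kind of backward beam search (the toy matrix and helper below are illustrative, not the package's own code):

```python
import math

# Toy lower-triangular attention matrix: attn[i][j] is the attention
# that token i pays to earlier token j (each row sums to 1).
attn = [
    [1.0, 0.0, 0.0, 0.0],
    [0.7, 0.3, 0.0, 0.0],
    [0.2, 0.5, 0.3, 0.0],
    [0.1, 0.1, 0.6, 0.2],
]

def trace_backward(attn, target, top_beam=2, top_k=2):
    """Beam search backward from `target`, accumulating log attention."""
    beams = [([target], 0.0)]  # (positions so far, cumulative log score)
    finished = []
    while beams:
        new_beams = []
        for path, score in beams:
            pos = path[-1]
            if pos == 0:  # reached the first token: the path is complete
                finished.append((path, score))
                continue
            # Top-k predecessors of the current position by attention weight
            preds = sorted(range(pos), key=lambda j: attn[pos][j], reverse=True)[:top_k]
            for j in preds:
                if attn[pos][j] > 0:
                    new_beams.append((path + [j], score + math.log(attn[pos][j])))
        # Keep only the highest-scoring partial paths
        beams = sorted(new_beams, key=lambda b: b[1], reverse=True)[:top_beam]
    return sorted(finished, key=lambda b: b[1], reverse=True)

paths = trace_backward(attn, target=3)
print(paths[0][0])  # best path of positions, ending at token 0
```

Here the best path hops 3 → 2 → 1 → 0 because each of those edges carries high attention, even though token 3 also attends weakly to token 0 directly.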
```bash
pip install reverse-attention
```

Or install from source:
```bash
git clone https://github.com/ovshake/rat
cd rat
pip install -e .
```

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from reverse_attention import ReverseAttentionTracer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# Create tracer
tracer = ReverseAttentionTracer(model, tokenizer)

# Trace attention from the last token
result = tracer.trace_text("The quick brown fox jumps over the lazy dog.")

# Print top attention paths
for i, path in enumerate(result.paths_text):
    print(f"Beam {i+1}: {path}")

# Generate interactive visualization
tracer.render_html(result, "output/", open_browser=True)
```

`ReverseAttentionTracer` is the main class for tracing attention paths.
```python
tracer = ReverseAttentionTracer(model, tokenizer, device=None, dtype=None)
```

- `model`: HuggingFace transformer model
- `tokenizer`: The corresponding tokenizer
- `device`: Device to run on (defaults to the model's device)
- `dtype`: Data type for computation (defaults to the model's dtype)
`trace()` traces attention paths backward from a target position.
```python
result = tracer.trace(
    input_ids,                  # Input token IDs [1, seq_len]
    target_pos=-1,              # Position to trace from (supports negative indexing)
    attention_mask=None,        # Optional attention mask
    layer=-1,                   # Layer index (supports negative indexing)
    top_beam=5,                 # Number of beams to keep
    top_k=5,                    # Top-k predecessors per step
    min_attn=0.0,               # Minimum attention threshold
    agg_heads="mean",           # Head aggregation: "mean", "max", "none"
    length_norm="avg_logprob",  # Score normalization
    stop_at_bos=True,           # Stop at BOS tokens
    bos_token_id=None,          # Override BOS token ID
)
```

`trace_text()` is a convenience method that tokenizes text before tracing.
```python
result = tracer.trace_text(
    "Your text here",
    target_pos=-1,
    **kwargs  # Same arguments as trace()
)
```

`render_html()` generates an interactive HTML visualization.
```python
html_path = tracer.render_html(
    result,               # TraceResult from trace()
    out_dir="output/",    # Output directory
    open_browser=False,   # Open in browser after generation
)
```

`TraceResult` is the result object returned by `trace()`:
- `seq_len`: Sequence length
- `target_pos`: Target position (resolved to a positive index)
- `layer`: Layer index (resolved to a positive index)
- `top_beam`: Number of beams used
- `top_k`: Top-k value used
- `tokens`: List of all tokens in the sequence
- `beams`: List of `BeamPath` objects
- `sankey`: `SankeyData` for visualization
- `paths_text`: Human-readable path descriptions
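For intuition, `paths_text` renders each beam's token path in readable form; a hypothetical formatting along these lines (the token strings, attention values, and arrow notation below are made up for illustration):

```python
# Hypothetical example: turning one beam's tokens and edge attentions
# into a readable path string (illustrative, not the package's own output format).
tokens = ["dog", "lazy", "the", "over"]
edge_attns = [0.42, 0.31, 0.27]  # one weight per edge between consecutive tokens

parts = [tokens[0]]
for tok, attn in zip(tokens[1:], edge_attns):
    parts.append(f"-[{attn:.2f}]-> {tok}")
print(" ".join(parts))  # dog -[0.42]-> lazy -[0.31]-> the -[0.27]-> over
```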
`BeamPath` represents a single attention path:

- `positions`: Token positions in the sequence
- `tokens`: Token strings
- `token_ids`: Token IDs
- `edge_attns`: Attention weights along the path's edges
- `score_raw`: Raw cumulative log score
- `score_norm`: Length-normalized score
The `length_norm` parameter controls how path scores are normalized:

- `"none"`: No normalization (raw cumulative log probability)
- `"avg_logprob"`: Divide by path length (geometric mean; the default)
- `"sqrt"`: Divide by sqrt(path length)
- `"pow:α"`: Divide by path_length^α (e.g., `"pow:0.7"`)
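These normalizations can be sketched as a small helper (an illustration of the formulas above, not the package's internal implementation):

```python
import math

def normalize_score(score_raw, path_len, length_norm="avg_logprob"):
    """Apply a length normalization to a raw cumulative log-probability score."""
    if length_norm == "none":
        return score_raw
    if length_norm == "avg_logprob":
        return score_raw / path_len          # log of the geometric mean
    if length_norm == "sqrt":
        return score_raw / math.sqrt(path_len)
    if length_norm.startswith("pow:"):
        alpha = float(length_norm.split(":", 1)[1])
        return score_raw / (path_len ** alpha)
    raise ValueError(f"unknown length_norm: {length_norm}")

raw = math.log(0.5) + math.log(0.25)     # a path with two edges (attention 0.5 and 0.25)
norm = normalize_score(raw, path_len=2)  # geometric-mean log probability
```

The practical effect: without normalization, longer paths are penalized simply for having more edges; `"avg_logprob"` removes that bias entirely, and `"sqrt"` / `"pow:α"` interpolate between the two.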
The `agg_heads` parameter controls how attention heads are combined:

- `"mean"`: Average attention across all heads (the default)
- `"max"`: Maximum attention across heads
- `"none"`: Keep all heads separate (returns a 3D attention tensor)
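A minimal sketch of these modes, using plain nested lists in place of attention tensors (illustrative, not the package's own code):

```python
def aggregate_heads(head_attn, mode="mean"):
    """head_attn: per-head attention rows, shape [heads][seq].
    Returns one combined row for "mean"/"max", or the stack unchanged for "none"."""
    if mode == "none":
        return head_attn
    n_heads = len(head_attn)
    seq_len = len(head_attn[0])
    if mode == "mean":
        return [sum(h[j] for h in head_attn) / n_heads for j in range(seq_len)]
    if mode == "max":
        return [max(h[j] for h in head_attn) for j in range(seq_len)]
    raise ValueError(f"unknown agg_heads mode: {mode}")

heads = [[0.25, 0.75], [0.75, 0.25]]       # two heads attending over two tokens
print(aggregate_heads(heads, "mean"))      # [0.5, 0.5]
print(aggregate_heads(heads, "max"))       # [0.75, 0.75]
```

Note the trade-off visible even in this toy case: `"mean"` can wash out a strong single-head edge (both tokens look equal), while `"max"` preserves it but no longer produces rows that sum to 1.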
Run the demo script:
```bash
python examples/demo_qwen2.py --text "Your text here" --open-browser
```

Options:
- `--model`: Model name or path (default: `Qwen/Qwen2-0.5B`)
- `--text`: Text to trace
- `--target-pos`: Target position (default: -1)
- `--layer`: Layer index (default: -1)
- `--top-beam`: Number of beams (default: 5)
- `--top-k`: Top-k predecessors (default: 5)
- `--output`: Output directory (default: `output`)
- `--open-browser`: Open visualization in browser
- `--device`: Device to use (default: auto)
The generated HTML visualization includes:
- Zoom/Pan: Scroll to zoom, drag to pan
- Click to highlight: Click nodes to highlight connected paths
- Beam filter: Dropdown to filter by specific beam
- Info panel: Click elements to see details (position, token, attention weights)
- Color coding: Beams are color-coded for easy identification
Install dev dependencies:
```bash
pip install -e ".[dev]"
```

Run tests:

```bash
pytest tests/ -v
```

MIT License
For detailed documentation, tutorials, and API reference, visit the documentation site.
