

Reverse Attention Tracer (RAT)

A Python package for tracing attention paths backward through transformer models, with interactive D3.js Sankey visualization.

Features

  • Reverse attention tracing: Trace which tokens most influence a target token by following attention paths backward
  • Beam search: Efficiently explore multiple high-probability paths through the attention matrix
  • Interactive visualization: D3.js-powered Sankey diagrams with zoom, pan, and click-to-highlight
  • Qwen2 support: Optimized for the Qwen2 model family, and compatible with other HuggingFace transformer models

Installation

pip install reverse-attention

Or install from source:

git clone https://github.com/ovshake/rat
cd rat
pip install -e .

Quick Start

from transformers import AutoModelForCausalLM, AutoTokenizer
from reverse_attention import ReverseAttentionTracer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# Create tracer
tracer = ReverseAttentionTracer(model, tokenizer)

# Trace attention from the last token
result = tracer.trace_text("The quick brown fox jumps over the lazy dog.")

# Print top attention paths
for i, path in enumerate(result.paths_text):
    print(f"Beam {i+1}: {path}")

# Generate interactive visualization
tracer.render_html(result, "output/", open_browser=True)

API Reference

ReverseAttentionTracer

The main class for tracing attention paths.

tracer = ReverseAttentionTracer(model, tokenizer, device=None, dtype=None)

Parameters

  • model: HuggingFace transformer model
  • tokenizer: Corresponding tokenizer
  • device: Device to run on (defaults to model's device)
  • dtype: Data type for computation (defaults to model's dtype)
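
For example, to force tracing onto a specific GPU in half precision (a minimal sketch; the values are illustrative, and both arguments can simply be omitted to inherit the model's settings):

import torch

# Pin the tracer to one CUDA device and compute in fp16.
tracer = ReverseAttentionTracer(model, tokenizer, device="cuda:0", dtype=torch.float16)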

trace()

Trace attention paths backward from a target position.

result = tracer.trace(
    input_ids,              # Input token IDs [1, seq_len]
    target_pos=-1,          # Position to trace from (supports negative indexing)
    attention_mask=None,    # Optional attention mask
    layer=-1,               # Layer index (supports negative indexing)
    top_beam=5,             # Number of beams to keep
    top_k=5,                # Top-k predecessors per step
    min_attn=0.0,           # Minimum attention threshold
    agg_heads="mean",       # Head aggregation: "mean", "max", "none"
    length_norm="avg_logprob",  # Score normalization
    stop_at_bos=True,       # Stop at BOS tokens
    bos_token_id=None,      # Override BOS token ID
)
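
A minimal usage sketch, assuming the model, tokenizer, and tracer from the Quick Start; input_ids has the [1, seq_len] shape noted above:

# Tokenize manually, then trace backward from the final token at the last layer.
inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
result = tracer.trace(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    target_pos=-1,
    layer=-1,
    top_beam=3,
)
print(result.paths_text[0])  # best-scoring path, human-readable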

trace_text()

Convenience method that tokenizes text before tracing.

result = tracer.trace_text(
    "Your text here",
    target_pos=-1,
    **kwargs  # Same as trace()
)

render_html()

Generate interactive HTML visualization.

html_path = tracer.render_html(
    result,                 # TraceResult from trace()
    out_dir="output/",      # Output directory
    open_browser=False,     # Open in browser after generation
)

TraceResult

The result object returned by trace():

  • seq_len: Sequence length
  • target_pos: Target position (resolved to positive index)
  • layer: Layer index (resolved to positive index)
  • top_beam: Number of beams used
  • top_k: Top-k value used
  • tokens: List of all tokens in sequence
  • beams: List of BeamPath objects
  • sankey: SankeyData for visualization
  • paths_text: Human-readable path descriptions
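
For example, a quick summary of a trace (a minimal sketch using the fields listed above, with result from the Quick Start call):

print(f"Traced from position {result.target_pos} at layer {result.layer}")
print(f"Sequence length: {result.seq_len}, beams kept: {len(result.beams)}")
print("Tokens:", result.tokens)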

BeamPath

A single attention path:

  • positions: Token positions in sequence
  • tokens: Token strings
  • token_ids: Token IDs
  • edge_attns: Attention weights along edges
  • score_raw: Raw cumulative log score
  • score_norm: Length-normalized score
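
A sketch of walking the top-scoring beam. This assumes edge_attns holds one weight per consecutive pair of positions (one fewer entry than tokens); the field names are from the list above:

best = result.beams[0]
print(f"raw score: {best.score_raw:.3f}, normalized: {best.score_norm:.3f}")
# Assumed layout: edge_attns[i] is the attention weight on the edge
# between positions[i] and positions[i + 1].
for i, w in enumerate(best.edge_attns):
    print(f"{best.tokens[i]!r} -> {best.tokens[i + 1]!r} (attn={w:.3f})")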

Score Normalization

The length_norm parameter controls how path scores are normalized:

  • "none": No normalization (raw cumulative log probability)
  • "avg_logprob": Divide by path length (geometric mean, default)
  • "sqrt": Divide by sqrt(path length)
  • "pow:α": Divide by path_length^α (e.g., "pow:0.7")

Head Aggregation

The agg_heads parameter controls how attention heads are combined:

  • "mean": Average attention across all heads (default)
  • "max": Maximum attention across heads
  • "none": Keep all heads separate (returns 3D attention tensor)

Example Script

Run the demo script:

python examples/demo_qwen2.py --text "Your text here" --open-browser

Options:

  • --model: Model name or path (default: Qwen/Qwen2-0.5B)
  • --text: Text to trace
  • --target-pos: Target position (default: -1)
  • --layer: Layer index (default: -1)
  • --top-beam: Number of beams (default: 5)
  • --top-k: Top-k predecessors (default: 5)
  • --output: Output directory (default: output)
  • --open-browser: Open visualization in browser
  • --device: Device to use (default: auto)
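
For example, combining several of the flags above (values are illustrative):

python examples/demo_qwen2.py \
    --model Qwen/Qwen2-0.5B \
    --text "The quick brown fox jumps over the lazy dog." \
    --layer -1 --top-beam 3 --top-k 5 \
    --output output/ --open-browser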

Visualization Features

The generated HTML visualization includes:

  • Zoom/Pan: Scroll to zoom, drag to pan
  • Click to highlight: Click nodes to highlight connected paths
  • Beam filter: Dropdown to filter by specific beam
  • Info panel: Click elements to see details (position, token, attention weights)
  • Color coding: Beams are color-coded for easy identification

Development

Install dev dependencies:

pip install -e ".[dev]"

Run tests:

pytest tests/ -v

License

MIT License

Documentation

For detailed documentation, tutorials, and API reference, visit the documentation site.
