published · 2025-12-24

Attention Alignment Outperforms Logit Distillation for LLM Compression

When compressing a large language model into a smaller one, should you just teach the small model to match the large model's outputs, or also show it what the large model is paying attention to? We ran the comparison systematically and found that attention maps transfer surprisingly well, cutting the error rate by 26% relative to output-only distillation. But adding hidden-state alignment on top of attention adds nothing.

Jack Large, Madelyn Sarbin

Running a 7-billion-parameter language model in production is expensive. It requires serious hardware, significant memory, and meaningful compute cost per query. One well-established solution is knowledge distillation: train a smaller model to behave like a larger one, so you get most of the capability at a fraction of the cost.

The question this paper asks is: what exactly should you transfer? The obvious answer is outputs — just teach the small model to produce the same predictions the big model would. But large models encode rich intermediate reasoning: which words they're attending to, how they're representing meaning internally. Should you transfer those signals too?

two ways to distill

Black-box distillation treats the teacher model as a black box. You run the teacher on your training data, record its output probability distribution over the vocabulary at each token position, and train the student to match those distributions. The student never sees inside the teacher.
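A minimal sketch of the black-box objective for a single token position, assuming the common Hinton-style formulation: soften both models' logits with a temperature T, then take the KL divergence from the teacher's distribution to the student's. The temperature value and the T² scaling are assumptions; the article does not say which variant was used.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of a list of logits at a given temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_kl_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions for one position.

    Multiplied by T**2 so gradient magnitudes stay comparable across temperatures
    (a common convention, assumed here rather than taken from the article).
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2

# Toy 3-token vocabulary; the logit values are illustrative only.
teacher = [2.0, 1.0, 0.1]
student_close = [1.9, 1.1, 0.2]
student_far = [0.1, 1.0, 2.0]
print(kd_kl_loss(teacher, teacher))
print(kd_kl_loss(teacher, student_close) < kd_kl_loss(teacher, student_far))
```

For a language model this per-position loss would be averaged over every token position in the batch; the student is never told which token was "correct", only how the teacher spread its probability.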

White-box distillation opens the box. In addition to matching outputs, the student is also trained to match internal signals from the teacher: its hidden states (the vector representations of meaning at each layer), its attention maps (which tokens the model is "looking at" when processing each word), or both.
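To make "attention map" concrete: for a single head in a decoder model, the map is the row-wise softmax of scaled query-key dot products, with a causal mask so each token attends only to itself and earlier tokens. The sketch below computes one from toy query and key vectors; the vectors are arbitrary illustrative values, not taken from either model.

```python
import math

def causal_attention_map(queries, keys):
    """Return the seq x seq attention map for one head.

    Row i is a probability distribution over positions 0..i: how much token i
    "looks at" each earlier token. Future positions get weight 0 (causal mask).
    """
    d = len(queries[0])
    n = len(queries)
    attn = []
    for i in range(n):
        # Scaled dot-product scores against positions 0..i only.
        scores = [sum(q * k for q, k in zip(queries[i], keys[j])) / math.sqrt(d)
                  for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (n - i - 1)
        attn.append(row)
    return attn

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
A = causal_attention_map(q, k)
for row in A:
    print([round(x, 3) for x in row])
```

Each row sums to 1, so a map is a set of distributions over positions rather than a vector of features; that is what makes maps from two different models comparable entry by entry.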

We compared four configurations:

  1. Black-box only (KL divergence on output logits)
  2. Black-box + hidden state alignment
  3. Black-box + attention map alignment
  4. Black-box + hidden states + attention (combined)
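The four configurations can be read as one objective with switches. A minimal sketch, assuming the auxiliary terms are simply added to the KL loss with scalar weights; the weight names `alpha` and `beta` and their default values are hypothetical, since the article does not report how the terms were weighted.

```python
# The four configurations, expressed as one loss function.
VALID_CONFIGS = ("black_box", "hidden", "attention", "combined")

def distillation_loss(kl, hidden_mse, attn_mse, config, alpha=1.0, beta=1.0):
    """Combine precomputed per-batch loss terms according to a configuration.

    kl:          KL divergence on output distributions (always used).
    hidden_mse:  hidden-state alignment loss ("hidden" and "combined").
    attn_mse:    attention-map alignment loss ("attention" and "combined").
    alpha, beta: hypothetical weights on the two auxiliary terms.
    """
    if config not in VALID_CONFIGS:
        raise ValueError(f"unknown config: {config!r}")
    loss = kl
    if config in ("hidden", "combined"):
        loss += alpha * hidden_mse
    if config in ("attention", "combined"):
        loss += beta * attn_mse
    return loss

for cfg in VALID_CONFIGS:
    print(cfg, round(distillation_loss(0.5, 0.2, 0.1, cfg), 3))
```

Framing the variants this way makes clear that every white-box configuration still includes the black-box term; the comparison isolates what each auxiliary signal adds on top of it.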

setup

We compressed Llama-2-7B (7 billion parameters) into TinyLlama-1.1B (1.1 billion) — a 6.4× size reduction. We evaluated on three tasks: sentiment classification (SST-2), broad reasoning across subjects (MMLU), and grade-school math word problems (GSM8K). Each configuration was run with 7 different random seeds to get reliable variance estimates, not just single-run results.

what we found

Attention alignment wins, clearly and significantly.

Method          Accuracy   Std Dev   Error Reduction
Attention       95.56%     1.43%     26%
Combined        95.56%     1.43%     26%
Hidden state    94.71%     1.05%     12%
Black-box       93.98%     2.98%     (baseline)

The attention method reduces the error rate from 6.02% to 4.44%, a 26% relative reduction (1.58 points absolute). That difference is statistically significant (p < 0.01 across 7 seeds per method). Hidden-state alignment helps somewhat on its own (12% error reduction), but adding it to attention adds nothing: the combined method matches attention-only exactly, in both mean accuracy and standard deviation.
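The error-reduction column is relative: the fraction of the black-box baseline's errors that each method removes. Recomputing it from the accuracies in the table:

```python
def relative_error_reduction(baseline_acc, method_acc):
    """Fraction of the baseline's error rate removed by a method (accuracies in %)."""
    baseline_err = 100.0 - baseline_acc
    method_err = 100.0 - method_acc
    return (baseline_err - method_err) / baseline_err

baseline = 93.98  # black-box accuracy from the table
for name, acc in [("attention", 95.56), ("hidden state", 94.71)]:
    print(f"{name}: {relative_error_reduction(baseline, acc):.0%}")
# attention: 26%
# hidden state: 12%
```

Because the baseline error is already small (6.02%), a 1.58-point absolute gain becomes a 26% relative reduction; both numbers describe the same result.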

why attention works better than hidden states

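The two models have different hidden-state widths: TinyLlama uses 2048-dimensional vectors, Llama-2-7B uses 4096. To align them, you have to learn a linear projection from one space to the other. That projection may not transfer meaning cleanly, because it is bottlenecked by the mismatch: a linear map between a 2048- and a 4096-dimensional space has rank at most 2048, so the alignment can only ever use a subspace of at most 2048 dimensions of the teacher's representation.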
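A minimal sketch of hidden-state alignment, assuming the projection maps student vectors up to the teacher's width (one common choice; the article does not say which direction was used). Toy widths of 2 and 4 stand in for 2048 and 4096, and the weight matrix `W` is fixed here for illustration, whereas in training it would be learned jointly with the student.

```python
def project(weight, student_h):
    """Linear projection; weight has shape (teacher_dim, student_dim)."""
    return [sum(w * x for w, x in zip(row, student_h)) for row in weight]

def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def hidden_state_loss(weight, student_h, teacher_h):
    """MSE between the projected student hidden state and the teacher hidden state."""
    projected = project(weight, student_h)
    if len(projected) != len(teacher_h):
        raise ValueError("projection output must match teacher width")
    return mse(projected, teacher_h)

# Toy widths: student 2 (stands in for 2048), teacher 4 (stands in for 4096).
W = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0],
     [1.0, -1.0]]
student_h = [0.5, -0.5]
teacher_h = [0.5, -0.5, 0.0, 1.0]
print(hidden_state_loss(W, student_h, teacher_h))
```

With this 4×2 `W`, every projected vector has the form [a, b, a+b, a−b], so a teacher state such as [0, 0, 1, 0] can never be matched exactly, whatever the student produces. That is the rank limit in miniature.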

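Attention maps, on the other hand, require no projection. Both models have 32 attention heads per layer and share the same tokenizer, so each head produces a map over the same token sequence and the maps can be compared entry by entry. (The models do differ in depth, 32 layers versus 22, so student layers still have to be paired with teacher layers.) Attention captures relational information (which words are relevant to which other words), and that may transfer more naturally between architectures than the specific numeric representations of meaning.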
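A minimal sketch of the attention-alignment term: plain MSE between student and teacher maps, with no learned parameters. Two choices here are assumptions, not details from the article: student head h is paired with teacher head h, and because TinyLlama has 22 layers and Llama-2-7B has 32, student layers are paired with teacher layers by a hypothetical uniform mapping (`uniform_layer_map`).

```python
def uniform_layer_map(n_student_layers, n_teacher_layers):
    """Map each student layer to an evenly spaced teacher layer (hypothetical scheme)."""
    if n_student_layers == 1:
        return [n_teacher_layers - 1]
    step = (n_teacher_layers - 1) / (n_student_layers - 1)
    return [round(i * step) for i in range(n_student_layers)]

def attention_alignment_loss(student_maps, teacher_maps, layer_map):
    """MSE between attention maps, averaged over paired layers, heads, and entries.

    student_maps[l][h] and teacher_maps[l][h] are seq x seq lists of lists.
    No projection is needed because the head counts and sequence lengths match.
    """
    total, count = 0.0, 0
    for s_layer, t_layer in enumerate(layer_map):
        s_heads, t_heads = student_maps[s_layer], teacher_maps[t_layer]
        if len(s_heads) != len(t_heads):
            raise ValueError("head counts must match for direct comparison")
        for s_head, t_head in zip(s_heads, t_heads):
            for s_row, t_row in zip(s_head, t_head):
                for s, t in zip(s_row, t_row):
                    total += (s - t) ** 2
                    count += 1
    return total / count

# Toy maps: 2-token sequences, 2 heads; teacher has 3 layers, student has 2.
eye = [[1.0, 0.0], [0.0, 1.0]]
causal = [[1.0, 0.0], [0.5, 0.5]]
teacher = [[eye, causal], [causal, causal], [eye, eye]]
student = [[eye, causal], [eye, eye]]
layer_map = uniform_layer_map(2, 3)
print(layer_map, attention_alignment_loss(student, teacher, layer_map))
```

For the real pair, `uniform_layer_map(22, 32)` pairs the first and last student layers with the first and last teacher layers and spaces the rest evenly in between.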

That's our hypothesis. It's plausible but not proven — you'd need probing experiments and visualization work to confirm it.

where it doesn't help: math

The improvement is not uniform across tasks. On SST-2 (sentiment), attention distillation improves accuracy by 2.40 points. On MMLU (general reasoning), by 0.79 points. On GSM8K (math), by 0.09 points — essentially nothing.

This probably isn't a coincidence. One plausible explanation is that mathematical reasoning in these models unfolds as a sequential chain of computation through hidden states, with each step building on the previous one. Attention patterns capture global token relationships, which matters a lot for understanding sentiment or contextual meaning, but may not capture the step-by-step arithmetic trace that math problems require. The same signal that transfers well for language understanding may simply not be the relevant signal for math.

what to do with this

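If you're compressing a language model and your teacher and student have matched attention head counts (which is increasingly common in open model families), use attention distillation. The training overhead is modest: one extra MSE loss term, plus the cost of having both models return their attention weights, which rules out fused kernels such as FlashAttention for the aligned layers because they never materialize the full attention map. In return you get a statistically significant improvement, concentrated on language-understanding tasks rather than math. Don't bother adding hidden-state alignment on top; it adds complexity for no measurable gain.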

The caveat is that this was tested on one teacher-student pair. Whether it generalizes to mismatched architectures, different model families, or very different task distributions is open.