AVnorm


“Attention normalizes the wrong norm” showed that attention’s output variance collapses as context grows. An L2-normalized softmax fixes this in theory, but existing architectures are tuned around the L1 version, so swapping the norm breaks their hyperparameters.

There’s a simpler fix: normalize the output.

av = attn @ v  # or F.scaled_dot_product_attention(q, k, v)
av = av.transpose(1, 2).reshape(B, T, dim)  # merge heads: (B, heads, T, head_dim) -> (B, T, dim)
av = F.group_norm(av.view(B * T, dim), num_groups=heads).view(B, T, dim)  # per-head LayerNorm, no affine params

By normalizing each head’s output independently, AVnorm works with existing architectures without retuning any other component.
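
For concreteness, here is a minimal sketch of where the norm sits inside a nanogpt-style attention block. It is illustrative rather than the exact implementation; the layer names (c_attn, c_proj) follow nanogpt conventions.

import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.dim, self.heads = dim, heads
        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)  # fused q, k, v projection
        self.c_proj = nn.Linear(dim, dim, bias=False)       # output projection

    def forward(self, x):
        B, T, dim = x.shape
        q, k, v = self.c_attn(x).split(dim, dim=2)
        q, k, v = (t.view(B, T, self.heads, dim // self.heads).transpose(1, 2) for t in (q, k, v))
        av = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, heads, T, head_dim)
        av = av.transpose(1, 2).reshape(B, T, dim)                    # merge heads -> (B, T, dim)
        # AVnorm: per-head LayerNorm on the attention output, before the output projection
        av = F.group_norm(av.view(B * T, dim), num_groups=self.heads).view(B, T, dim)
        return self.c_proj(av)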

Evidence

Two nanogpt models were trained on FineWeb-Edu for 4000 steps: a baseline and an AVnorm variant.

At the training length, AVnorm matches the baseline; beyond it, the gap widens. At 2x context, the baseline’s loss jumps from 3.3 to 7.5, while AVnorm’s rises only to 4.2.

extrapolation.png

When extrapolating, the baseline model fails abruptly, while AVnorm degrades smoothly.
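
For anyone reproducing the curve, a minimal evaluation sketch. It assumes a nanogpt-style model that returns (logits, loss) and a flat 1-D tensor of validation tokens; both are assumptions about the harness, not the post’s actual script.

import torch

@torch.no_grad()
def loss_at_context(model, tokens, ctx_len, n_windows=32):
    # average next-token loss over random windows of length ctx_len
    model.eval()
    total = 0.0
    for _ in range(n_windows):
        i = torch.randint(0, tokens.numel() - ctx_len - 1, (1,)).item()
        x = tokens[i:i + ctx_len].unsqueeze(0)
        y = tokens[i + 1:i + ctx_len + 1].unsqueeze(0)
        _, loss = model(x, y)  # nanogpt-style (logits, loss) signature (assumed)
        total += loss.item()
    return total / n_windows

# e.g. evaluate at the training length and at multiples of it:
# for ctx in (1024, 2048, 4096):
#     print(ctx, loss_at_context(baseline, val_tokens, ctx), loss_at_context(avnorm, val_tokens, ctx))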

Retrofitting

The L2 softmax’s key mistake was requiring retraining and retuning; AVnorm reuses the existing hyperparameters.

Let’s push further: reuse not just the hyperparameters but the model itself. The earlier experiments trained from scratch; this time, pretrain for 3000 steps, then finetune for 1000, comparing plain L1 continuation against switching to AVnorm.
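
Since the per-head norm has no learnable parameters, retrofitting never touches the checkpoint’s weights. One way to bolt it onto a pretrained model without editing its code is a pre-forward hook on the output projection; a sketch, assuming a nanogpt-style layout where the merged-head attention output of shape (B, T, dim) feeds attn.c_proj (attribute names are illustrative).

import torch.nn.functional as F

def add_avnorm(attn_module, heads):
    # Normalize the input of the output projection, i.e. the merged-head
    # attention output of shape (B, T, dim), with one group per head.
    def hook(module, inputs):
        av, = inputs
        B, T, dim = av.shape
        av = F.group_norm(av.reshape(B * T, dim), num_groups=heads).view(B, T, dim)
        return (av,)
    return attn_module.c_proj.register_forward_pre_hook(hook)

# for block in model.transformer.h:          # nanogpt-style block list (assumed)
#     add_avnorm(block.attn, heads=n_head)
# ...then finetune for 1000 steps with the original hyperparameters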

retrofit.png

By 500 finetuning steps, AVnorm matches the L1 baseline.

Understanding AVnorm

Softmax normalizes the L1 norm (the weights sum to one), but output variance depends on the L2 norm (the sum of squared weights): for independent value vectors, the variance of the weighted sum scales with that sum of squares, which shrinks as attention spreads over a longer context. AVnorm sidesteps the independence assumption by normalizing the outputs themselves.
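
A quick simulation (illustrative, not one of the post’s measurements) makes this concrete: with i.i.d. unit-variance value vectors, the output variance tracks the squared L2 norm of the weights and shrinks roughly as 1/n.

import torch

torch.manual_seed(0)
batch, d = 256, 32                                    # illustrative sizes
for n in (64, 256, 1024, 4096):                       # context lengths
    v = torch.randn(batch, n, d)                      # i.i.d. unit-variance value vectors
    w = torch.softmax(torch.randn(batch, n), dim=-1)  # L1-normalized attention weights
    out = torch.einsum("bn,bnd->bd", w, v)            # attention output
    # output variance matches the mean squared L2 norm of the weights, not the L1 norm
    print(n, out.var().item(), w.pow(2).sum(-1).mean().item())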

A per-head normalization delivers what L2 softmax aimed for: constant variance per head, regardless of attention pattern or context length. If that is the mechanism, per-head RMSNorm should perform about as well as per-head LayerNorm, while a global normalization should underperform both, since one head’s attention pattern then affects another head’s scale.
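
For reference, the normalization variants compared below might look like the following sketch; shapes assume the merged-head layout (B, T, dim) with dim = heads * head_dim, and all variants are parameter-free like the GroupNorm above.

import torch
import torch.nn.functional as F

def per_head_layernorm(av, heads):
    # zero mean, unit variance within each head's slice (GroupNorm, one group per head)
    B, T, dim = av.shape
    return F.group_norm(av.reshape(B * T, dim), num_groups=heads).view(B, T, dim)

def per_head_rmsnorm(av, heads, eps=1e-6):
    # unit RMS within each head's slice, no mean subtraction
    B, T, dim = av.shape
    h = av.reshape(B, T, heads, dim // heads)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return h.reshape(B, T, dim)

def global_layernorm(av):
    # one normalization across the full dim: heads share a scale, so one head's
    # attention pattern shifts another head's effective magnitude
    return F.layer_norm(av, av.shape[-1:])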

To test this, the same pretrained L1 checkpoint is finetuned with different normalizations:

norm_comparison.png

In practice, per-head RMSNorm and per-head LayerNorm are indistinguishable (both reach a loss of 5.4 at 65K context), while global LayerNorm shows no advantage over unnormalized attention.

Removing position embeddings tests whether variance really is the bottleneck. Without PE, unnormalized attention collapses at long context, while per-head LayerNorm stays close to the baseline that keeps its PE. So the causal mask implicitly encodes position, and normalization is what reveals it.

nope.png


Code