People have tweaked the Transformer architecture enough in 7 years that we’re apparently now calling the current recipe “Transformer++”.

The changes between Vaswani et al. 2017 and Radford et al. 2019 are well-known (see Figure 3 here): removing the encoder, clipping gradients, not using dropout, and changing the activation function to something weighted or gated (ReLU -> GeLU). So I focus on post-GPT-2 tweaks.

I also focus on architecture changes, rather than data changes (curation or generation), training changes (infrastructure and hyperparameters; e.g. the various kinds of data and model parallelism across multiple devices arose largely after 2019), post-training, or inference optimisations. I won’t get into multimodal architectures.

Inclusion criterion: at least three of five strong open 2024 architectures use the tweak (LLaMA 3, Gemma 2, Qwen2.5, DeepSeek-V2, Hunyuan-Large).

(The term “Transformer++” was coined in this sense by Gu and Dao 2023, who also noted several of the main tweaks.)



The Transformer++

Let the “Transformer++” be a Transformer with the following tweaks. (A minimal code sketch of each follows the list.)

  • A fused attention implementation (naive scaled dot-product backend -> FlashAttention). Subquadratic memory complexity in input sequence length, since the full attention matrix is never materialised. Practically: can double GPU utilization and so halve training time. Also enables longer contexts and speeds up inference on long-context input.
  • Rotary position embedding (sinusoidal -> learned APE -> RoPE)
  • Removing attention’s redundant key heads and value heads (vanilla MHA -> MQA -> GQA)
  • Regularized / preconditioned optimizer (Adam -> AdamW -> SOAP)
  • Normalise before each layer (post LayerNorm -> pre LayerNorm)
  • When doing layer normalization: just rescale, don’t centre (LayerNorm -> RMSNorm)
  • Divine activation function for the MLP (GeLU -> … -> SwiGLU or GeGLU)
  • Tied embeddings. An oldie but goodie.
  • Fixing attention-logit drift (query/key normalization)
  • Fixing that one softmax off-by-one (fixed in some places around 2021)
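
To make the fused-attention tweak concrete, here’s a minimal PyTorch sketch using the built-in scaled_dot_product_attention, which dispatches to a FlashAttention kernel on supported GPUs. Shapes and sizes are illustrative, not from any particular model.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# The fused kernel computes softmax(QK^T / sqrt(d)) V tile by tile and never
# materialises the (seq_len x seq_len) score matrix, so activation memory is
# subquadratic in seq_len.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```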
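
RoPE, minimally. Implementations differ in how they pair up dimensions (adjacent pairs here; some codebases split the head dimension in halves instead), and base 10000 is just the conventional default:

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate each adjacent pair of dims of x by a position-dependent angle.

    x: (batch, heads, seq_len, head_dim), head_dim even.
    Relative offsets between positions then show up as phase differences
    in the q.k dot products, which is the point.
    """
    _, _, t, d = x.shape
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    angles = torch.arange(t, dtype=torch.float32)[:, None] * inv_freq  # (t, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
```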
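
GQA, minimally: several query heads share each KV head, shrinking the KV cache by the group size. Head counts here are illustrative:

```python
import torch
import torch.nn.functional as F

n_q_heads, n_kv_heads, seq, head_dim = 8, 2, 1024, 64
q = torch.randn(1, n_q_heads, seq, head_dim)
k = torch.randn(1, n_kv_heads, seq, head_dim)  # KV cache is 4x smaller
v = torch.randn(1, n_kv_heads, seq, head_dim)

# Each group of 4 query heads attends using the same K/V head.
group = n_q_heads // n_kv_heads
out = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(group, dim=1),
    v.repeat_interleave(group, dim=1),
    is_causal=True,
)
```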
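
The optimizer change is one line in PyTorch: AdamW applies weight decay to the weights directly rather than folding an L2 penalty into the gradient, as plain Adam setups did. (SOAP or Muon would swap in a preconditioned update; hyperparameters here are illustrative.)

```python
import torch

w = torch.nn.Parameter(torch.randn(4096, 4096))
# Decoupled weight decay: decay shrinks w directly each step, independent
# of the adaptive gradient moments.
opt = torch.optim.AdamW([w], lr=3e-4, weight_decay=0.1)
```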
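
And a sketch of the resulting block, covering the next three tweaks together: pre-LN (normalise the input to each sublayer, leave the residual stream alone), RMSNorm (rescale, don’t centre), and a SwiGLU MLP. Loosely LLaMA-style; dimensions and details are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    # Rescale by the root-mean-square only: no mean subtraction, no bias.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()

class SwiGLU(nn.Module):
    # Gated MLP: silu(x W1) elementwise-gates (x W3), then project down.
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # gate
        self.w3 = nn.Linear(dim, hidden, bias=False)  # value
        self.w2 = nn.Linear(hidden, dim, bias=False)  # down-projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

def transformer_block(x, attn, mlp, norm1, norm2):
    # Pre-LN: each sublayer sees a normalised input; post-LN would instead
    # normalise after the residual addition, which trains less stably.
    x = x + attn(norm1(x))
    x = x + mlp(norm2(x))
    return x
```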
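
Tying is also a one-liner: use the same matrix for the input embedding and the output unembedding, saving vocab_size × d_model parameters.

```python
import torch.nn as nn

vocab, dim = 32000, 4096  # illustrative sizes
embed = nn.Embedding(vocab, dim)
lm_head = nn.Linear(dim, vocab, bias=False)
lm_head.weight = embed.weight  # one matrix, read twice
```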
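
QK normalization, minimally. Implementations vary (some RMS-normalise queries and keys per head with learned gains); here I just L2-normalise with a single learned scale, which bounds the attention logits so they can’t drift upward over training:

```python
import torch
import torch.nn.functional as F

def qk_norm_attention(q, k, v, scale: torch.Tensor):
    # With unit-norm q and k, each logit is bounded by |scale| / sqrt(head_dim),
    # however large the projection weights grow.
    q = F.normalize(q, dim=-1) * scale
    k = F.normalize(k, dim=-1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```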
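
And the off-by-one fix: give the softmax an implicit extra zero logit, so a head can attend to (almost) nothing rather than being forced to hand out exactly 1 of attention:

```python
import torch

def softmax1(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # exp(x_i) / (1 + sum_j exp(x_j)), computed stably: the extra 1 in the
    # denominator is exp(0) for a ghost token that absorbs leftover attention.
    m = x.max(dim=dim, keepdim=True).values.clamp(min=0.0)
    ex = torch.exp(x - m)
    return ex / (torch.exp(-m) + ex.sum(dim=dim, keepdim=True))
```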



Less well-established tweaks

  • Sparsification. The obvious mention is the turn to sparse Mixtures of Experts, but that was more of a cost-saving measure than an architectural improvement, and LLaMA is still dense.
  • BPE -> tiktoken / SentencePiece. Still basically byte-pair encoding, but with larger vocabularies giving about 25% better compression.
  • Quantization. FP16 -> BF16 for training, and int8 post-training.
  • No biases on the QKV projections or layernorms. Though I’m not sure about this one: various people (e.g. Qwen) put the attention biases back in.
  • Sliding Window Attention, e.g. with a Rolling Buffer Cache (sketch after this list)
  • Cross-Layer Attention shrinks the KV cache
  • WARP
  • Regularizing outputs (“soft-capping logits”; sketch after this list)
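
A minimal sliding-window mask (window size illustrative); a rolling buffer cache then keeps only the last `window` KV entries instead of growing with the sequence:

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Position i may attend to positions j with i - window < j <= i:
    # causal, but only over the most recent `window` tokens.
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)  # bool, True = may attend

mask = sliding_window_mask(seq_len=8, window=3)
```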
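
And soft-capping, which squashes logits smoothly into (-cap, cap) rather than hard-clipping them. Gemma 2 does this for its attention and final logits; the cap value below is illustrative:

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # tanh saturates at +/-1, so outputs stay strictly inside (-cap, cap)
    # while remaining differentiable everywhere.
    return cap * torch.tanh(logits / cap)
```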



Occurrence in top open architectures

| Component | Tweak | LLaMA 3 | Gemma 2 | Qwen2.5 | DeepSeek-V2 | Hunyuan-Large |
|---|---|---|---|---|---|---|
| Attention | Attention kernel | FlashAttention-2? | Eager attention | FlashAttention-2 | FlashAttention-2 | FlashAttention-2 |
| Attention | Sliding window attention | No? | Local-global SWA | Both | No? | No? |
| Attention | Removing KV heads | GQA | GQA | GQA | MLA | GQA |
| Attention | Cross-Layer Attention | No | No | No | No | CLA |
| Attention | Prefill KV cache | Yes | ? | ? | No? | ? |
| Attention | Low-rank KV cache compression | No | No | No | Yes | No |
| Attention | Biases in QKV projection | No | ? | QKV biases | ? | No |
| Attention | QK Normalization | No? | No? | ? | No? | No |
| Block sequence | Parallel layers | No | No? | No? | No? | No? |
| Embedding | Position encoding | RoPE | RoPE | RoPE | Decoupled RoPE | DynamicNTK RoPE (*6) |
| Embedding | Tied embeddings | “Shared” (*1) | Tied | Tied (*4) | ? | Tied (*7) |
| Optimizer | Regularized / preconditioned | AdamW | AdamW (*3) | ? | AdamW | AdamW |
| Activation normalization | Post or pre layernorm | Pre | Both | ? | Pre | ? |
| Activation normalization | Don’t center | RMSNorm (*2) | RMSNorm | RMSNorm | RMSNorm | RMSNorm |
| Output normalization | Soft-capped logits | No? | Soft-capped logits | No? | No? | No? |
| Activation function | Gated linear unit | SwiGLU | GeGLU | SwiGLU | SwiGLU | SwiGLU (*5) |
| Sparsification | Sparse? | Dense | Dense? | Dense | MoE | MoE |
| Weights quantization | BF16 training | Yes | No, FP32 | Yes | Yes | Yes |
| Weights quantization | 8-bit post-training | In one version | No | No | No | In one version |


(*1) In the 3.2 models anyway.
(*2) Llama 2 uses RMSNorm anyway.
(*3) They recommend AdamW for fine-tuning; unsure for pretraining.
(*4) Only the smaller models.
(*5) Code says “silu”.
(*6) “Credits to the Reddit users /u/bloc97 and /u/emozilla”
(*7) https://huggingface.co/tencent/Tencent-Hunyuan-Large/blob/main/Hunyuan-A52B-Pretrain/modeling_hunyuan.py#L1419



Caveats

  • The above ignores the much more important changes since 2017 to data “collection” (curation and synthesis), cluster infrastructure, post-training, and scaffolding.
  • The public tokenizers still use byte-pair encoding
  • Some models have absurdly high embedding-parameter counts. Previously we conjectured that this is a PR tradeoff: it lets them underreport the Transformer-parameter count and so enter a lesser model class (“7B”). There is actually a good reason on top of this PR reason: TPUs have dedicated hardware (SparseCores) which accelerates embedding lookups relative to NVIDIA hardware, which lets models trained on them afford much larger vocabularies, and hence huge embedding tables. The only fudge is then their use of the standard parameter-reporting convention (counting only Transformer parameters) despite having a good reason to explode the embedding.
  • A lot of this doesn’t improve absolute performance that much, but it does make it a lot cheaper to run.
  • And this is just the public architecture. Probably there are also public methods which we haven’t yet realised are improvements.



I thank Kushal Thaman for helpful comments.



See also

  • https://arxiv.org/html/2410.16682v1
  • https://openreview.net/forum?id=d8w0pmvXbZ






Tags: AI


