MHA vs MQA vs GQA vs MLA
Zain ul Abideen
July 13, 2024
17 min read
Comparison of DeepSeek's new Multi-head Latent Attention with MHA, MQA, and GQA.
In Transformer decoders, each token attends to all preceding tokens, so instead of recomputing the previous context at every step, its keys and values are cached. This can significantly speed up inference, but it imposes an expensive memory overhead as the sequence length and model dimensions grow. To address this, several attention mechanisms have been introduced: Multi-Head Attention, Multi-Query Attention, Grouped-Query Attention, and Multi-Head Latent Attention.
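The idea can be sketched in a few lines of PyTorch: at each decoding step only the new token's key and value are computed, and they are appended to a cache that keeps growing with the sequence. This is a toy single-head example; the dimension d_h = 64 is an arbitrary choice for illustration.

```python
import torch

def decode_step(q, k_new, v_new, k_cache, v_cache):
    """One decoding step: append this token's key/value to the cache and
    attend over everything cached so far instead of recomputing it."""
    k_cache = torch.cat([k_cache, k_new], dim=0)          # (t, d_h)
    v_cache = torch.cat([v_cache, v_new], dim=0)          # (t, d_h)
    scores = (q @ k_cache.T) / k_cache.shape[-1] ** 0.5   # (1, t)
    out = torch.softmax(scores, dim=-1) @ v_cache         # (1, d_h)
    return out, k_cache, v_cache

d_h = 64  # hypothetical per-head dimension, just for illustration
k_cache = torch.empty(0, d_h)
v_cache = torch.empty(0, d_h)
for _ in range(5):  # generate five tokens; earlier keys/values are reused, not recomputed
    q, k, v = torch.randn(1, d_h), torch.randn(1, d_h), torch.randn(1, d_h)
    out, k_cache, v_cache = decode_step(q, k, v, k_cache, v_cache)
```

The cache grows by one key and one value per token per head per layer, which is exactly the memory cost the mechanisms below try to reduce.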

Multi-Head Attention — (MHA)
Standard Multi-Head Attention (MHA) computes separate query, key, and value matrices for each attention head. During inference, all the keys and values are cached to avoid recomputation, but this heavy KV cache is a major bottleneck that limits the maximum sequence length and batch size.
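To put a number on it: counting both keys and values across all layers, MHA caches about 2 · n_h · d_h · l elements per token (using the same symbols as in the conclusion below: n_h heads of dimension d_h each, over l layers).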


Multi-Query Attention — (MQA)
To reduce the KV-cache bottleneck of MHA, Shazeer (2019) introduced Multi-Query Attention (MQA), in which the keys and values are shared across all attention heads; that is, it is identical to MHA except that every head reads the same single set of keys and values. This requires a very light KV cache and hence drastically speeds up decoder inference. However, MQA can lead to quality degradation and training instability.
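Because only one key head and one value head are kept, the per-token cache drops from about 2 · n_h · d_h · l elements to about 2 · d_h · l.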

Grouped Query Attention — (GQA)
Grouped-Query Attention (GQA) interpolates between MHA and MQA by dividing the query heads into a number of groups (fewer than the total number of attention heads), each of which shares a single key and value head. In contrast to MQA, GQA keeps the same proportional decrease in memory bandwidth and capacity as model size increases. An intermediate number of groups yields an interpolated model that is higher quality than MQA but faster than MHA. Note that GQA with a single group is equivalent to MQA, and GQA with as many groups as attention heads is equivalent to MHA.
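With n_g groups, the per-token cache is about 2 · n_g · d_h · l elements, sitting between the MQA and MHA extremes.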

Implementing MHA, MQA, and GQA
This Attention class dynamically implements all three attention mechanisms based on self.num_kv_heads and self.num_heads.
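A minimal PyTorch-style sketch is shown below. The projection names (wq, wk, wv, wo) and the use of F.scaled_dot_product_attention are illustrative choices rather than an exact listing, but the selection logic is the same: num_kv_heads == num_heads gives MHA, num_kv_heads == 1 gives MQA, and anything in between gives GQA with num_kv_heads groups.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """MHA, MQA, or GQA depending on num_kv_heads:
    num_kv_heads == num_heads -> MHA, num_kv_heads == 1 -> MQA, otherwise GQA."""

    def __init__(self, dim: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.wq = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)  # fewer K heads
        self.wv = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)  # fewer V heads
        self.wo = nn.Linear(num_heads * self.head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        q = self.wq(x).view(bsz, seqlen, self.num_heads, self.head_dim)
        k = self.wk(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim)
        v = self.wv(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim)

        # Each K/V head serves num_heads // num_kv_heads query heads.
        group = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(group, dim=2)   # (bsz, seqlen, num_heads, head_dim)
        v = v.repeat_interleave(group, dim=2)

        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (bsz, num_heads, seqlen, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(bsz, seqlen, -1)
        return self.wo(out)

# Attention(dim=512, num_heads=8, num_kv_heads=8) -> MHA
# Attention(dim=512, num_heads=8, num_kv_heads=1) -> MQA
# Attention(dim=512, num_heads=8, num_kv_heads=2) -> GQA with two groups
```

In a real decoder, only k and v (before repetition) would be written to the KV cache, which is where the memory savings of MQA and GQA come from.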
Multi-Head Latent Attention — (MLA)
Multi-Head Latent Attention (MLA) achieves better performance than MHA while significantly reducing the KV cache, which boosts inference efficiency. Instead of reducing the number of KV heads as in MQA and GQA, MLA jointly compresses the keys and values into a latent vector.

Low-Rank Key-Value Joint Compression
Instead of caching the full key and value matrices, MLA jointly compresses them into a low-rank latent vector, so far fewer elements have to be cached: the compression dimension d_c is much smaller than the total per-token key/value dimension n_h · d_h in MHA.
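Concretely, in DeepSeek-V2 the hidden state h_t is first down-projected into a latent vector c_t^KV of dimension d_c, and the keys and values are reconstructed from it when attention is computed:

c_t^KV = W^DKV · h_t
k_t^C  = W^UK · c_t^KV
v_t^C  = W^UV · c_t^KV

Only c_t^KV has to be cached. The paper also notes that W^UK can be absorbed into the query projection and W^UV into the output projection, so the full keys and values never need to be materialized during inference.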

Implementing MLA
The following Attention class implements MLA.
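A simplified PyTorch sketch is given below. It caches conceptually only the latent c_kv and omits the decoupled RoPE keys and the query compression used in DeepSeek-V2; the names (w_dkv, w_uk, w_uv, d_c) follow the paper's notation, but the class itself is an illustration rather than the production implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLAttention(nn.Module):
    """Simplified Multi-head Latent Attention: keys and values are reconstructed
    from a shared low-rank latent c_kv of dimension d_c, which is all that needs caching."""

    def __init__(self, dim: int, num_heads: int, d_c: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.wq = nn.Linear(dim, dim, bias=False)      # query projection
        self.w_dkv = nn.Linear(dim, d_c, bias=False)   # down-projection to the latent (this is cached)
        self.w_uk = nn.Linear(d_c, dim, bias=False)    # up-projection: latent -> keys
        self.w_uv = nn.Linear(d_c, dim, bias=False)    # up-projection: latent -> values
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        c_kv = self.w_dkv(x)          # (bsz, seqlen, d_c) -- the only tensor that needs caching
        q = self.wq(x)
        k = self.w_uk(c_kv)           # keys reconstructed at attention time
        v = self.w_uv(c_kv)           # values reconstructed at attention time

        def split(t):  # (bsz, seqlen, dim) -> (bsz, num_heads, seqlen, head_dim)
            return t.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal=True)
        out = out.transpose(1, 2).reshape(bsz, seqlen, -1)
        return self.wo(out)
```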


Conclusion
With caching, MHA inference can be fast, but its KV-cache overhead makes it hard to scale to larger models, longer sequences, and bigger batch sizes. MQA significantly reduces the KV cache but degrades in quality as the model size increases. GQA strikes a balance between the two in terms of KV-cache size and memory bandwidth. MLA requires a significantly smaller KV cache and still outperforms MHA in output quality.
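Putting the numbers side by side, the per-token KV-cache sizes are roughly:

MHA: 2 · n_h · d_h · l
MQA: 2 · d_h · l
GQA: 2 · n_g · d_h · l
MLA: about d_c · l (DeepSeek-V2 additionally caches a small decoupled key for RoPE)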

Where n_h is the number of heads, d_h is the dimension per head, l is the number of layers, n_g is the number of subgroups in GQA and d_c is the compression dimension.
Tags: AI, LLM, NLP, Machine Learning, Deep Learning