MHA vs MQA vs GQA vs MLA

Zain ul Abideen
July 13, 2024

A comparison of DeepSeek's new Multi-Head Latent Attention with MHA, MQA, and GQA.

In Transformer decoders, each token attends only to the preceding tokens, so instead of recomputing the keys and values of the previous context at every decoding step, they are cached. This significantly speeds up inference, but the cache imposes an expensive memory overhead as the sequence length and model dimensions grow. In this context, several attention mechanisms have been introduced: Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA).
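As a minimal sketch of the caching idea (the tensor shapes below are hypothetical and chosen only for illustration), each decoding step appends the new token's key and value to a running per-layer cache instead of reprocessing the whole prefix:

import torch

# Minimal sketch of a per-layer KV cache during autoregressive decoding.
# Hypothetical shapes: batch=1, num_heads=8, head_dim=64.
num_heads, head_dim = 8, 64
k_cache = torch.empty(1, num_heads, 0, head_dim)  # grows along the sequence axis
v_cache = torch.empty(1, num_heads, 0, head_dim)

for step in range(16):  # pretend we decode 16 tokens
    # In a real model these come from the key/value projections of the new token.
    k_new = torch.randn(1, num_heads, 1, head_dim)
    v_new = torch.randn(1, num_heads, 1, head_dim)
    k_cache = torch.cat([k_cache, k_new], dim=2)
    v_cache = torch.cat([v_cache, v_new], dim=2)
    # Attention for the new token uses the full cache; the prefix is never recomputed.

print(k_cache.shape)  # torch.Size([1, 8, 16, 64])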


Multi-Head Attention — (MHA)

The standard Multi-Head Attention (MHA) computes separate query, key, and value matrices for every attention head. During inference, all keys and values are cached to avoid recomputation, but this heavy KV-cache is a major bottleneck that can limit the maximum sequence length and batch size.
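To make the overhead concrete, here is a rough estimate of the MHA KV-cache size; the helper and the model dimensions below are hypothetical, chosen only for illustration:

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    """Rough KV-cache size: two tensors (K and V) per layer, cached for every token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Hypothetical 32-layer model with 32 heads of dimension 128, fp16 cache:
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096, batch=8)
print(f"MHA KV-cache: {mha / 1e9:.1f} GB")  # ~17.2 GB for this configuration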

Multi-Query Attention — (MQA)

To reduce the KV-cache bottleneck of MHA, Shazeer (2019) introduced Multi-Query Attention (MQA), where the keys and values are shared across all attention heads, i.e., it is identical to MHA except that every query head attends to a single shared set of keys and values. This requires a very light KV-cache and hence drastically speeds up decoder inference. However, MQA can lead to quality degradation and training instability.
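The mechanism is simple to sketch: one key head and one value head are broadcast across all query heads. The shapes below are hypothetical, and broadcasting via expand is just one way to express the sharing:

import torch
import torch.nn.functional as F

batch, seq_len, num_heads, head_dim = 1, 10, 8, 64   # hypothetical shapes
q = torch.randn(batch, num_heads, seq_len, head_dim)  # one query per head, as in MHA
k = torch.randn(batch, 1, seq_len, head_dim)          # a single shared key head
v = torch.randn(batch, 1, seq_len, head_dim)          # a single shared value head

# expand() broadcasts the shared K/V view across all query heads without copying memory.
out = F.scaled_dot_product_attention(
    q, k.expand(-1, num_heads, -1, -1), v.expand(-1, num_heads, -1, -1), is_causal=True
)
print(out.shape)  # torch.Size([1, 8, 10, 64])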

Grouped Query Attention — (GQA)

Grouped-Query Attention (GQA) interpolates between MHA and MQA by dividing the query heads into a number of groups (fewer than the total number of attention heads), each of which shares a single key head and value head. In contrast to MQA, GQA keeps the same proportional reduction in memory bandwidth and capacity as the model size (and head count) increases. An intermediate number of groups yields a model of higher quality than MQA while remaining faster than MHA. Note that GQA with a single group is equivalent to MQA, and GQA with as many groups as query heads is equivalent to MHA.
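Reusing the kv_cache_bytes helper sketched above (same hypothetical 32-layer configuration), the cache shrinks in direct proportion to the number of KV heads:

# 32 query heads throughout; only the number of KV heads changes.
for name, n_kv in [("MHA", 32), ("GQA (8 groups)", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers=32, n_kv_heads=n_kv, head_dim=128, seq_len=4096, batch=8)
    print(f"{name:>15}: {size / 1e9:5.1f} GB")
# MHA ~17.2 GB, GQA ~4.3 GB, MQA ~0.5 GB for this configuration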

Implementing MHA, MQA, and GQA

The following Attention class implements all three mechanisms; the variant is selected by the relationship between self.num_heads and self.num_kv_heads (equal for MHA, a single KV head for MQA, and an intermediate number for GQA).
import torch
import torch.nn as nn
import torch.nn.functional as F

# MOEConfig and apply_rope are assumed to be defined elsewhere in the codebase.

class Attention(nn.Module):
    def __init__(self, model_args: MOEConfig):
        super().__init__()
        d_model = model_args.d_model
        self.num_heads = model_args.num_heads
        self.head_dim = model_args.d_model // model_args.num_heads
        # num_kv_heads == num_heads -> MHA, == 1 -> MQA, in between -> GQA.
        self.num_kv_heads = (
            model_args.num_heads if model_args.num_kv_heads == 0 else model_args.num_kv_heads
        )
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Queries keep one projection per head; keys and values are projected
        # only for the (possibly smaller) number of KV heads.
        self.query = nn.Linear(d_model, self.head_dim * self.num_heads)
        self.key = nn.Linear(d_model, self.head_dim * self.num_kv_heads)
        self.value = nn.Linear(d_model, self.head_dim * self.num_kv_heads)
        self.proj = nn.Linear(d_model, d_model, bias=model_args.bias)
        self.attn_dropout = nn.Dropout(model_args.dropout)
        self.res_dropout = nn.Dropout(model_args.dropout)
        self.flash_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")

    def forward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        q = self.query(x)  # (batch, seq_len, num_heads * head_dim)
        k = self.key(x)    # (batch, seq_len, num_kv_heads * head_dim)
        v = self.value(x)  # (batch, seq_len, num_kv_heads * head_dim)

        q = q.view(batch, seq_len, -1, self.head_dim)
        k = k.view(batch, seq_len, -1, self.head_dim)
        v = v.view(batch, seq_len, -1, self.head_dim)

        q, k = apply_rope(q, k, freqs_cis)

        # MQA / GQA: replicate each KV head so every query head has a matching K/V.
        if self.num_kv_heads != self.num_heads:
            k = torch.repeat_interleave(k, self.num_queries_per_kv, dim=2)
            v = torch.repeat_interleave(v, self.num_queries_per_kv, dim=2)

        q = q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=True,
        )

        output = output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        output = self.proj(output)
        output = self.res_dropout(output)
        return output
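As a quick sanity check, here is a hedged usage sketch. MOEConfig below is a minimal stand-in containing only the fields the class reads (d_model, num_heads, num_kv_heads, bias, dropout); in the real codebase the actual config class would be used. The key-projection shape shows how the same class becomes MHA, GQA, or MQA purely through num_kv_heads:

from dataclasses import dataclass

@dataclass
class MOEConfig:  # minimal stand-in; define it before the Attention class if running standalone
    d_model: int = 512
    num_heads: int = 8
    num_kv_heads: int = 0  # 0 falls back to num_heads (MHA)
    bias: bool = False
    dropout: float = 0.0

for n_kv, name in [(0, "MHA"), (2, "GQA"), (1, "MQA")]:
    attn = Attention(MOEConfig(num_kv_heads=n_kv))
    print(f"{name}: key projection weight {tuple(attn.key.weight.shape)}")
# MHA: (512, 512), GQA: (128, 512), MQA: (64, 512)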

Multi-Head Latent Attention — (MLA)

Multi-Head Latent Attention (MLA) achieves better performance than MHA while significantly reducing the KV-cache, which boosts inference efficiency. Instead of reducing the number of KV heads as in MQA and GQA, MLA jointly compresses the keys and values into a latent vector.

Low-Rank Key-Value Joint Compression

Instead of caching the full key and value matrices, MLA jointly compresses them into a low-rank latent vector. Only this latent (plus a small decoupled rotary-position key) needs to be cached, and since the compression dimension is much smaller than the combined dimension of all key and value heads in MHA, far fewer values are stored per token. The per-head keys and values are reconstructed from the latent by up-projection matrices at attention time.
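A shape-level sketch of the idea, with made-up dimensions rather than DeepSeek's actual sizes: the per-token cache holds only the d_c-dimensional latent, and the per-head keys and values are reconstructed from it by up-projections (the decoupled RoPE key is omitted here for brevity; the implementation below includes it):

import torch
import torch.nn as nn

d_model, n_heads, head_dim, d_c = 1024, 8, 128, 128  # hypothetical dimensions

W_dkv = nn.Linear(d_model, d_c, bias=False)            # down-projection (compression)
W_uk = nn.Linear(d_c, n_heads * head_dim, bias=False)  # up-projection to keys
W_uv = nn.Linear(d_c, n_heads * head_dim, bias=False)  # up-projection to values

h = torch.randn(1, 16, d_model)  # hidden states for 16 tokens
c_kv = W_dkv(h)                  # (1, 16, d_c)  <- this latent is what gets cached
k = W_uk(c_kv).view(1, 16, n_heads, head_dim)  # reconstructed keys
v = W_uv(c_kv).view(1, 16, n_heads, head_dim)  # reconstructed values

print("cached per token:", d_c)                    # 128 values
print("MHA would cache:", 2 * n_heads * head_dim)  # 2048 values (K and V for all heads)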

Implementing MLA

The following Attention class implements MLA.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# MOEConfig, RMSNorm, and apply_rope are assumed to be defined elsewhere in the codebase.

class Attention(nn.Module):
    def __init__(self, model_args: MOEConfig):
        super().__init__()
        d_model = model_args.d_model
        self.num_heads = model_args.num_heads
        self.head_dim = model_args.d_model // model_args.num_heads
        self.attn_dropout = nn.Dropout(model_args.dropout)
        self.res_dropout = nn.Dropout(model_args.dropout)
        self.flash_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")

        self.q_lora_rank = model_args.q_lora_rank            # rank of the query compression
        self.qk_rope_head_dim = model_args.qk_rope_head_dim  # per-head dims that carry RoPE
        self.kv_lora_rank = model_args.kv_lora_rank          # rank of the joint KV compression
        self.v_head_dim = model_args.v_head_dim
        self.qk_nope_head_dim = model_args.qk_nope_head_dim  # per-head dims without RoPE
        self.q_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim

        # Query path: down-project, normalize, then up-project to all heads.
        self.q_a_proj = nn.Linear(d_model, model_args.q_lora_rank, bias=False)
        self.q_a_layernorm = RMSNorm(model_args.q_lora_rank)
        self.q_b_proj = nn.Linear(model_args.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)

        # KV path: a single down-projection produces the compressed latent plus the
        # decoupled RoPE key; only these need to be cached.
        self.kv_a_proj_with_mqa = nn.Linear(
            d_model, model_args.kv_lora_rank + model_args.qk_rope_head_dim, bias=False
        )
        self.kv_a_layernorm = RMSNorm(model_args.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            model_args.kv_lora_rank,
            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) -> torch.Tensor:
        batch, seq_len, d_model = x.shape

        # Queries: low-rank down/up projection, then split into no-RoPE and RoPE parts.
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
        q = q.view(batch, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Keys/values: compress into a latent vector plus a single shared RoPE key.
        compressed_kv = self.kv_a_proj_with_mqa(x)
        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(batch, seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)

        # Reconstruct per-head keys and values from the compressed latent.
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(batch, seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Apply rotary embeddings only to the decoupled RoPE parts.
        q_pe, k_pe = apply_rope(q_pe, k_pe, freqs_cis)
        k_pe = k_pe.transpose(2, 1)
        q_pe = q_pe.transpose(2, 1)

        # Concatenate the no-RoPE and RoPE parts back into full query/key heads.
        query_states = k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
        key_states = k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe  # broadcast across heads

        # Standard scaled dot-product attention with a causal mask.
        attn_mtx = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.q_head_dim)
        attn_mtx = attn_mtx + mask[:, :, :seq_len, :seq_len]
        attn_mtx = F.softmax(attn_mtx.float(), dim=-1).type_as(key_states)
        attn_mtx = self.attn_dropout(attn_mtx)
        output = torch.matmul(attn_mtx, value_states)  # (batch, num_heads, seq_len, v_head_dim)

        output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads * self.v_head_dim)
        output = self.o_proj(output)
        output = self.res_dropout(output)
        return output
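A quick back-of-the-envelope with hypothetical values for the config fields above (kv_lora_rank, qk_rope_head_dim, and so on): per token and per layer, only compressed_kv and the shared k_pe would need to be cached, compared with full keys and values for every head in MHA:

# Hypothetical values for the config fields used above.
num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank = 8, 64, 32, 64, 256

mla_cache_per_token = kv_lora_rank + qk_rope_head_dim  # compressed_kv + shared k_pe
mha_cache_per_token = num_heads * (qk_nope_head_dim + qk_rope_head_dim) + num_heads * v_head_dim

print("MLA caches per token:", mla_cache_per_token)  # 288 values
print("MHA caches per token:", mha_cache_per_token)  # 1280 values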

Conclusion

MHA provides strong quality, but its KV-cache overhead makes it difficult to scale to long sequences and large batch sizes. MQA significantly reduces the KV-cache but degrades in quality as the model size increases. GQA strikes a balance between the two in terms of KV-cache size and memory bandwidth. MLA requires a significantly smaller KV-cache yet outperforms MHA in output quality.
[Figure: per-token KV-cache comparison across the attention variants], where n_h is the number of heads, d_h is the dimension per head, l is the number of layers, n_g is the number of subgroups in GQA, and d_c is the KV compression dimension.
Tags: AI, LLM, NLP, Machine Learning, Deep Learning