Linear RoPE vs NTK vs YaRN vs CoPE

Zain ul Abideen
July 13, 2024

Comparison of various positional embeddings.

When processing sequences such as text, ordering information is clearly critical. To incorporate it, rather than treating a sequence as a set, the position of each token has to be encoded. Positional encoding achieves this by assigning an embedding vector to each position and adding it to the corresponding token representation. Many positional encoding techniques have been introduced: absolute PE counts tokens from the start of the sequence, while relative PE counts backward starting at the current token. We will discuss some of the more advanced position encoding methods: RoPE and its variants (Linear, NTK, YaRN), and CoPE.
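As a point of reference before the more advanced methods, here is a minimal sketch of the simplest variant, a learned absolute PE; the module name and the max_len/d_model arguments are illustrative, not taken from any particular library:

import torch
import torch.nn as nn

class LearnedAbsolutePE(nn.Module):
    # Assigns a learnable embedding to every position index and adds it to the token embeddings.
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pos_emb(positions)  # broadcasts over the batch dimension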


Rotary Position Embedding (RoPE)

Rotary Position Embedding (RoPE) injects positional information into the training of pre-trained models by encoding the absolute position with a rotation matrix while incorporating the explicit relative position dependency directly in the self-attention formulation. The key idea is to encode relative position by multiplying the context representations with a rotation matrix. The inter-token dependency encoded by RoPE also decays as the relative distance increases.
The intuition behind RoPE is simple: rotate the affine-transformed word embedding vector by an angle that is a multiple of its position index, so that relative position information is incorporated through the rotation.

RoPE Implementation

Now let's implement the RoPE formulation:
import torch


def reshape_for_broadcast(freqs, x):
    # minimal helper: reshape freqs from (T, Cc) to (1, T, 1, Cc) so it broadcasts against x: (B, T, nhead, Cc)
    return freqs.view(1, x.shape[1], 1, x.shape[-1])


def apply_rope(k, q, cis):
    # Idea: suppose a vector v = [x, y, x1, y1, ...] with v.shape = dim.
    # View the vector as complex numbers: [x, y, x1, y1, ...] -> x + iy, x1 + iy1, ...
    # Multiplying by a unit complex number rotates the vector:
    #   (x + iy) * (cos + i*sin) -> x' + iy'
    # Restack: x' + iy' -> [x', y', x1', y1', ...]
    # i.e. the vector is rotated in chunks of two.
    _, seq_len, _, _ = q.shape
    freqs_cos, freqs_sin = cis
    freqs_cos, freqs_sin = freqs_cos[:seq_len], freqs_sin[:seq_len]
    # reshape (..., C) -> (..., Cc, 2) with Cc = C // 2
    q_cis = q.float().reshape(q.shape[:-1] + (-1, 2))  # (B,T,nhead,C) -> (B,T,nhead,Cc,2)
    k_cis = k.float().reshape(k.shape[:-1] + (-1, 2))  # (B,T,nhead,C) -> (B,T,nhead,Cc,2)
    xq_r, xq_i = q_cis.unbind(-1)  # (B,T,nhead,Cc,2) -> two tensors of shape (B,T,nhead,Cc)
    xk_r, xk_i = k_cis.unbind(-1)  # (B,T,nhead,Cc,2) -> two tensors of shape (B,T,nhead,Cc)
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)  # (T,Cc) -> (1,T,1,Cc)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
    # complex multiplication: (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin  # real part, shape (B,T,nhead,Cc)
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos  # imaginary part
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
    # stack real and imaginary parts back: [r, i, r2, i2, ...]
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1)  # (B,T,nhead,Cc,2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1)  # (B,T,nhead,Cc,2)
    # flatten the last two dimensions
    xq_out = xq_out.flatten(3)  # (B,T,nhead,C)
    xk_out = xk_out.flatten(3)  # (B,T,nhead,C)
    return xq_out.type_as(q), xk_out.type_as(q)
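A quick usage sketch with toy tensor shapes (the sizes are illustrative; precompute_freqs_cis_linear is the helper defined in the next section):

B, T, n_heads, head_dim = 2, 16, 4, 64
q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, n_heads, head_dim)
cis = precompute_freqs_cis_linear(head_dim, end=2048)  # (cos, sin) tables, each of shape (2048, head_dim // 2)
q_rot, k_rot = apply_rope(k, q, cis)  # rotated q and k, same shapes as the inputs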

RoPE Variants

The RoPE variants differ in the rotation matrix/rotation angles they use, i.e. in how they precompute the cos and sin frequencies. Furthermore, to extend the model's context length beyond the pretrained limit, each variant introduces its own method-dependent functions, as discussed below.

Linear RoPE

Linear (position) interpolation modifies the RoPE equation above through the method-dependent functions g(m) = m/s and h(θ_d) = θ_d: the position index is divided by the scale factor s while the per-dimension frequencies are left unchanged. The unscaled baseline precomputation of the cos and sin frequencies is:
def precompute_freqs_cis_linear(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()  # outer product gives a different angle for every (position, dim) pair
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin
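The function above is the unscaled baseline. To actually stretch the context with linear position interpolation, the position indices would be divided by the scale factor s before the outer product; a minimal sketch, assuming a hypothetical scale argument:

def precompute_freqs_cis_linear_scaled(dim: int, end: int, theta: float = 10000.0, scale: float = 4.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # h(theta_d) unchanged
    t = torch.arange(end, device=freqs.device).float() / scale  # g(m) = m / s
    freqs = torch.outer(t, freqs).float()
    return torch.cos(freqs), torch.sin(freqs)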

Neural Tangent Kernel (NTK)

NTK-aware interpolation addresses the loss of high-frequency information that occurs when the RoPE embeddings are interpolated: instead of scaling every RoPE dimension equally by a factor s, it scales high frequencies less and low frequencies more. This is done simply by performing a base change on the value of θ.
def precompute_freqs_cis_ntk(dim: int, end: int, theta: float = 10000.0, alpha: int = 16):
    theta = theta * alpha ** (dim / (dim - 2))  # NTK-aware base change
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin
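To see the effect of the base change, one can compare the per-dimension wavelengths before and after it; a small illustrative check (not part of the original code):

import math
import torch

dim, theta, alpha = 128, 10000.0, 16
idx = torch.arange(0, dim, 2).float() / dim
theta_ntk = theta * alpha ** (dim / (dim - 2))

wavelen_base = 2 * math.pi * theta ** idx  # wavelength of each RoPE dimension
wavelen_ntk = 2 * math.pi * theta_ntk ** idx
print(wavelen_ntk / wavelen_base)  # ~1 for the highest-frequency dims, ~alpha for the lowest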

Yet Another RoPE Extension (YaRN)

YaRN introduces a ramp function that blends interpolation and extrapolation per RoPE dimension, together with an attention temperature scaling; both are incorporated into the method-dependent functions.
import math

def precompute_freqs_cis_yarn(
    dim: int,
    original_max_position_embeddings: int,
    theta: float = 10000.0,
    scale: int = 16,
    beta_fast: int = 32,
    beta_slow: int = 1,
    mscale: float = 0.707,
    max_position_embeddings: int = 2048,
):
    pos_freqs = theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs            # plain RoPE frequencies
    inv_freq_interpolation = 1.0 / (scale * pos_freqs)  # linearly interpolated frequencies
    # boundary dimensions of the ramp, derived from the beta_fast and beta_slow rotation counts
    low = max(math.floor(dim * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(theta))), 0)
    high = min(math.ceil(dim * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(theta))), dim - 1)
    # ramp from 0 to 1 between the low and high dimensions
    linear_func = (torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1).float().to(device=pos_freqs.device)
    inv_freq_mask = 1 - ramp_func
    # high-frequency dims are extrapolated (left unscaled), low-frequency dims are interpolated, blended by the ramp
    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
    # attention temperature scaling folded into the cos/sin tables
    _mscale = float((0.1 * math.log(scale) + 1.0) * mscale)
    t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)
    freqs_cos = freqs.cos() * _mscale
    freqs_sin = freqs.sin() * _mscale
    return freqs_cos, freqs_sin
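A hypothetical call extending a model pretrained on a 2,048-token context by a factor of 16 (the tensor shapes and parameter values are illustrative):

B, T, n_heads, head_dim = 2, 4096, 8, 64
q = torch.randn(B, T, n_heads, head_dim)
k = torch.randn(B, T, n_heads, head_dim)
cis = precompute_freqs_cis_yarn(
    dim=head_dim,
    original_max_position_embeddings=2048,  # pretraining context length
    scale=16,                               # context stretch factor s
    max_position_embeddings=32768,          # extended context length, 2048 * 16
)
q_rot, k_rot = apply_rope(k, q, cis)  # reuses apply_rope from the RoPE section above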

Contextual Position Encoding (CoPE)

Contextual Position Encoding (CoPE) is a context-dependent PE: positions are conditioned on the context of the preceding tokens and are incremented only on certain tokens (for example, the period ending a sentence). This makes it possible to count different levels of position abstraction, such as words or sentences.

CoPE Implementation

Here's how the CoPE method works:
import torch
import torch.nn as nn

class CoPE(nn.Module):
    def __init__(self, npos_max, head_dim):
        super().__init__()
        self.npos_max = npos_max
        self.pos_emb = nn.Parameter(torch.zeros(1, head_dim, npos_max))

    def forward(self, query, attn_logits):
        # gate values in [0, 1] for every query-key pair
        gates = torch.sigmoid(attn_logits)
        # backward cumulative sum of the gates gives the fractional position of each key
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        pos = pos.clamp(max=self.npos_max - 1)
        # interpolate between the two nearest integer position embeddings
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        logits_int = torch.matmul(query, self.pos_emb)  # q . e[p] for every integer position p
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)
        w = pos - pos_floor
        return logits_ceil * w + logits_floor * (1 - w)

CoPE Gating Mechanism

Gating decides which tokens are included when counting positions: using the context vectors, a gate value is computed for every query-key pair. A gate value of 1 means the token is considered in the position counting, while a value of 0 means it is ignored.

CoPE - Computing Positions

To compute positions, the gate values between the current token and all tokens preceding it are added up. Each position may then represent a token, word, or sentence number in the given sequence. If the gates are only sparsely activated (as when counting sentences), the whole context of sequence length T can be covered with far fewer positions, so each position is clamped to the maximum possible position. Since the positions are sums of sigmoid outputs in [0, 1], the i-th position value is a float in [0, i]. Such fractional positions cannot be looked up directly in an embedding layer.
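As a quick sanity check of the flip/cumsum/flip trick used in the forward method above, here is one query row with made-up gate values (future positions assumed masked to zero):

import torch

gates = torch.tensor([0.9, 0.1, 0.8, 0.0, 0.7])  # hypothetical gate values g_ik for keys k = 0..4
# backward cumulative sum: the position of key j is the sum of the gates from j up to the current token
pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
print(pos)  # tensor([2.5000, 1.6000, 1.5000, 0.7000, 0.7000])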

CoPE - Interpolating Position Embeddings

To overcome this limitation, a learnable position embedding e[p] is assigned to each integer position in the sequence; the positional embedding of the ij-th element is then a simple interpolation between the two closest integer embeddings, weighted by the fractional part of the position value computed above. Finally, the attention can be computed by adding these position embeddings to the key vectors.
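The interpolation itself is a one-liner; a toy check with a hypothetical embedding table and a single fractional position:

import torch

e = torch.randn(8, 16)   # hypothetical embeddings for integer positions 0..7, head_dim = 16
p = torch.tensor(2.3)    # a fractional position obtained by summing gate values
w = p - p.floor()        # fractional weight, here 0.3
e_p = w * e[p.ceil().long()] + (1 - w) * e[p.floor().long()]  # interpolated position embedding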

Implementing CoPE

To save memory and compute, the q·e[p] logit matrices are precomputed for every integer position and only then interpolated before being added to the context, exactly as in the forward method of the CoPE class shown above.

CoPE in Attention Class

The CoPE embeddings are added to the attention logits before the softmax: attn_mtx += self.cope(q, attn_mtx).
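For context, here is a minimal sketch of how the class above could be dropped into a single attention head; the SelfAttn name, tensor shapes, and masking convention are assumptions for illustration, not the definitive implementation:

import math
import torch
import torch.nn as nn

class SelfAttn(nn.Module):
    # A single attention head that adds CoPE logits to the attention matrix before the softmax.
    def __init__(self, npos_max, head_dim):
        super().__init__()
        self.cope = CoPE(npos_max, head_dim)
        self.head_dim = head_dim

    def forward(self, query, key, val, mask):
        # query, key, val: (batch, seq_len, head_dim); mask: (seq_len, seq_len), 1 = visible, 0 = masked
        attn_logits = torch.bmm(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn_logits = attn_logits + mask.log()  # log(0) = -inf blocks masked positions
        attn_logits = attn_logits + self.cope(query, attn_logits)  # the attn_mtx += self.cope(q, attn_mtx) step
        attn = torch.softmax(attn_logits, dim=-1)
        return torch.bmm(attn, val)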