Under the Hood of Llama 3.1 70B Distributed Inference

Subrata Goswami
Sep 3, 2024


The following are some notes on how the Llama 3.1 70B model works in a distributed environment. I will focus only on the pure PyTorch reference implementation, which uses the fairscale library; fairscale's model-parallel layers are adapted from NVIDIA's Megatron-LM. Llama 3 and 3.1 share the same architecture (3.1 mainly extends the context length using scaled RoPE frequencies), hence in the following I will refer to both as Llama 3.

Transformer and Attention Recap:

At the block-diagram level, shown below, the Transformer consists of 2 parts: Attention and the Feed Forward Network (FFN), also called the Multi Layer Perceptron (MLP).

Fig 1: Transformer architecture (from the Megatron-LM paper)

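An attention function maps a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight of each value is computed from the compatibility of the query with the corresponding key.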

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The dot products are scaled by 1/sqrt(dk) because, for large dk, they grow large in magnitude and push the softmax into regions where it has extremely small gradients.

Fig 2: Attention (from Attention Is All You Need)
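A minimal NumPy sketch of the attention formula above (NumPy stands in for PyTorch so the snippet runs anywhere; the sizes and random inputs are arbitrary). The last lines illustrate the reason for the 1/sqrt(dk) factor: with unit-variance entries, a raw dot product over dk dimensions has a standard deviation of about sqrt(dk).

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtracting the row maximum avoids overflow and does not change the result.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n_q, n_k) query-key compatibilities
    weights = softmax(scores, axis=-1)   # each row is a probability distribution
    return weights @ V, weights          # (n_q, d_v) weighted sums of the values


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 16))
K = rng.standard_normal((6, 16))
V = rng.standard_normal((6, 8))
out, weights = attention(Q, K, V)
print(out.shape)  # (4, 8)

# Why scale: with unit-variance entries, a raw dot product over d_k = 512
# dimensions has a standard deviation of about sqrt(512) ~ 22.6.
q = rng.standard_normal((1000, 512))
k = rng.standard_normal((1000, 512))
raw = (q * k).sum(axis=1)
print(round(raw.std(), 1), round((raw / np.sqrt(512)).std(), 2))
```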

In multi-head attention, instead of performing a single attention function with dim-dimensional keys, values and queries, each head linearly projects the queries, keys and values with different learned linear projections to dk, dk and dv dimensions, respectively. Attention is performed in parallel on each of these projected queries, keys and values, yielding dv-dimensional outputs. The h outputs are concatenated and once again projected, producing the final output with dimension dim.

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O$$

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

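The projection matrices have the following shapes:

$$W_i^Q \in \mathbb{R}^{\mathrm{dim} \times d_k}, \quad W_i^K \in \mathbb{R}^{\mathrm{dim} \times d_k}, \quad W_i^V \in \mathbb{R}^{\mathrm{dim} \times d_v}, \quad W^O \in \mathbb{R}^{h d_v \times \mathrm{dim}}$$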

The per-head key, query and value dimensions are related as follows:

$$d_k = d_v = \mathrm{dim}/h = d_{\mathrm{head}}$$

The benefit of multi-head attention is that it allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.
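The following NumPy sketch puts the multi-head formulas together. As an assumption for illustration, the h per-head projections W_i^Q, W_i^K, W_i^V are stored side by side in one (dim, h·d_head) matrix, which is also how the Llama code later uses wq, wk and wv (one projection, then a view into heads). With h = 1 it reduces to the single-head formula.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def multi_head_attention(X, WQ, WK, WV, WO, h):
    """Multi-head self-attention over X: (seq, dim).

    WQ, WK, WV: (dim, h * d_head) with the h per-head projections side by side,
    so head i uses columns i*d_head:(i+1)*d_head. WO: (h * d_head, dim).
    """
    d_head = WQ.shape[1] // h                 # d_k = d_v = dim / h
    Q, K, V = X @ WQ, X @ WK, X @ WV          # all heads projected at once
    heads = []
    for i in range(h):
        cols = slice(i * d_head, (i + 1) * d_head)
        scores = Q[:, cols] @ K[:, cols].T / np.sqrt(d_head)
        heads.append(softmax(scores) @ V[:, cols])     # (seq, d_head)
    return np.concatenate(heads, axis=-1) @ WO         # Concat(...) W^O -> (seq, dim)


rng = np.random.default_rng(0)
seq, dim, h = 5, 64, 8
X = rng.standard_normal((seq, dim))
WQ, WK, WV, WO = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(4))
Y = multi_head_attention(X, WQ, WK, WV, WO, h)
print(Y.shape)  # (5, 64)
```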

The second part of the Transformer, the Feed Forward Network, is composed of 2 linear layers with shapes dim×dff and dff×dim, with a nonlinearity in between.

Llama 3 Model — Single Process:

The top-level Llama 3 reference code implements generate, chat_completion and text_completion, as shown in the following figure. chat_completion and text_completion tokenize their inputs and return ChatPrediction and CompletionPrediction objects, respectively. Both use the same tokenizer, but chat_completion formats the dialog with a ChatFormat object built on top of the tokenizer. Both methods call generate, which in turn calls the model's forward method.

Fig 3: Top level Llama 3 code.

In the code, the Llama 3 70B model is a class called Transformer, although it contains more than a single Transformer. It is composed of 3 parts: a VocabParallelEmbedding, 80 TransformerBlock layers, and a final normalization plus linear layer. TransformerBlock corresponds to what is commonly called a Transformer layer.

Transformer(
  (tok_embeddings): VocabParallelEmbedding()
  (layers): ModuleList(
    (0-79): 80 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear()
        (wk): ColumnParallelLinear()
        (wv): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear()
)

First, strings are tokenized by a tokenizer in an application (say, example_chat_completion) and then embedded into a latent space by VocabParallelEmbedding.

# VocabParallelEmbedding
def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
    # Build the mask.
    input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
    # Mask the input.
    masked_input = input_.clone() - self.vocab_start_index
    masked_input[input_mask] = 0
    # Get the embeddings.
    output_parallel = F.embedding(
        masked_input,
        self.weight,
        self.padding_idx,
        self.max_norm,
        self.norm_type,
        self.scale_grad_by_freq,
        self.sparse,
    )
    # Mask the output embedding.
    output_parallel[input_mask, :] = 0.0
    # Reduce across all the model parallel GPUs.
    output = reduce_from_model_parallel_region(output_parallel)
    return output
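To see why masking plus an all-reduce (sum) reproduces an ordinary embedding lookup, here is a single-process NumPy simulation of the forward above. The function name vocab_parallel_embedding and the toy sizes are hypothetical; the loop over ranks stands in for separate GPUs, and np.sum stands in for torch.distributed.all_reduce.

```python
import numpy as np


def vocab_parallel_embedding(tokens, full_weight, world_size):
    """Single-process simulation of VocabParallelEmbedding.forward on `world_size` ranks."""
    vocab_size = full_weight.shape[0]
    per_rank = vocab_size // world_size            # assumes an even split
    partials = []
    for rank in range(world_size):
        start, end = rank * per_rank, (rank + 1) * per_rank
        shard = full_weight[start:end]             # this rank's rows of the table
        input_mask = (tokens < start) | (tokens >= end)
        masked_input = tokens - start
        masked_input[input_mask] = 0               # any in-range index; result is zeroed below
        out = shard[masked_input]                  # local lookup, shape (*tokens.shape, dim)
        out[input_mask, :] = 0.0                   # tokens owned by other ranks contribute 0
        partials.append(out)
    return np.sum(partials, axis=0)                # stands in for the all-reduce (sum)


rng = np.random.default_rng(0)
weight = rng.standard_normal((32, 6))              # toy vocabulary of 32 tokens, dim 6
tokens = rng.integers(0, 32, size=(2, 5))          # (bsz, seqlen)
emb = vocab_parallel_embedding(tokens, weight, world_size=4)
print(np.allclose(emb, weight[tokens]))            # True
```

Each token id falls in exactly one rank's range, so exactly one partial result is non-zero for it, and the sum recovers the full row.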

Then the embedded values enter the TransformerBlock, which is composed of Attention and MLP/FeedForward operations. In the Attention constructor below, the local (per-GPU) head counts are the global counts divided by model_parallel_size. On a single GPU or in non-distributed execution, model_parallel_size is 1, which makes the local and global sizes the same.


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

The input is projected by wq, wk and wv, and the heads are then processed in parallel. The shapes of x, wq, wk, wv, wo, xq, xk and xv are shown below.

# Shapes for the Llama 3.1 70B model (dim=8192, n_heads=64, n_kv_heads=8, head_dim=128)
# with model_parallel_size=8. Weights are listed as stored per GPU: (out_features, in_features).
x      : (bsz, seqlen, args.dim)                              # (1, 30, 8192) for the 30-token input used

wq     : (args.n_heads * self.head_dim // 8, args.dim)        # (1024, 8192) per GPU
wk, wv : (self.n_kv_heads * self.head_dim // 8, args.dim)     # (128, 8192) per GPU
wo     : (args.dim, args.n_heads * self.head_dim // 8)        # (8192, 1024) per GPU

xq     : (bsz, seqlen, self.n_local_heads * self.head_dim)    # (1, 30, 1024)
xk, xv : (bsz, seqlen, self.n_local_kv_heads * self.head_dim) # (1, 30, 128)
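The per-GPU numbers follow from a few integer divisions. This plain-Python sketch assumes the Llama 3.1 70B hyperparameters from its released params.json (dim=8192, n_heads=64, n_kv_heads=8) and model_parallel_size=8; substitute your own config if it differs.

```python
# Assumed Llama 3.1 70B hyperparameters (params.json) and 8-GPU tensor parallelism
dim, n_heads, n_kv_heads = 8192, 64, 8
model_parallel_size = 8

head_dim = dim // n_heads                                 # 128
n_local_heads = n_heads // model_parallel_size            # 8 query heads per GPU
n_local_kv_heads = n_kv_heads // model_parallel_size      # 1 KV head per GPU
n_rep = n_local_heads // n_local_kv_heads                 # 8 query heads share it

# Per-GPU weight shapes, stored as (out_features, in_features)
wq = (n_heads * head_dim // model_parallel_size, dim)     # ColumnParallelLinear
wk = (n_kv_heads * head_dim // model_parallel_size, dim)  # ColumnParallelLinear (wv is the same)
wo = (dim, n_heads * head_dim // model_parallel_size)     # RowParallelLinear

bsz, seqlen = 1, 30
xq = (bsz, seqlen, n_local_heads * head_dim)
xk = (bsz, seqlen, n_local_kv_heads * head_dim)

print(head_dim, n_local_heads, n_local_kv_heads, n_rep)   # 128 8 1 8
print(wq, wk, wo)                                         # (1024, 8192) (128, 8192) (8192, 1024)
print(xq, xk)                                             # (1, 30, 1024) (1, 30, 128)
```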

xq, xk and xv are then reshaped so that their last dimension is split into (local heads, head_dim). Rotary embeddings (RoPE) are then applied to xq and xk to encode positional information.

# Attention
def forward(
    self,
    x: torch.Tensor,
    start_pos: int,
    freqs_cis: torch.Tensor,
    mask: Optional[torch.Tensor],
):
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
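apply_rotary_emb is only named above, so here is a hedged NumPy sketch of rotary embeddings in the same spirit: consecutive pairs of each head's dimensions are treated as complex numbers and rotated by position-dependent angles. It assumes rope_theta=500000 and omits the frequency scaling that Llama 3.1 enables with use_scaled_rope, so it is not a drop-in replacement for the reference code. The check at the end shows the property that makes RoPE positional: the query-key score depends only on the relative offset between positions.

```python
import numpy as np


def precompute_freqs_cis(head_dim, max_pos, theta=500000.0):
    # One rotation frequency per pair of dimensions: theta^(-2i / head_dim)
    freqs = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)   # (head_dim/2,)
    angles = np.outer(np.arange(max_pos), freqs)                     # (max_pos, head_dim/2)
    return np.exp(1j * angles)                                       # unit complex numbers


def apply_rotary(x, freqs_cis):
    """x: (seqlen, n_heads, head_dim), real. Rotates the consecutive pairs
    (x0, x1), (x2, x3), ... of each head by a position-dependent angle."""
    xc = x[..., 0::2] + 1j * x[..., 1::2]     # view pairs as complex numbers
    xc = xc * freqs_cis[:, None, :]           # same rotation for every head at a position
    out = np.empty_like(x)
    out[..., 0::2], out[..., 1::2] = xc.real, xc.imag
    return out


head_dim = 8
freqs_cis = precompute_freqs_cis(head_dim, max_pos=16)
rng = np.random.default_rng(0)
q, k = rng.standard_normal(head_dim), rng.standard_normal(head_dim)


def at(vec, pos):
    # Rotate a single vector as if it sat at position `pos`.
    return apply_rotary(vec[None, None, :], freqs_cis[pos:pos + 1])[0, 0]


# The q.k score depends only on the offset between positions (2 in both cases).
print(np.isclose(at(q, 3) @ at(k, 1), at(q, 7) @ at(k, 5)))  # True
```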

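Llama 3 uses Grouped-Query Attention (GQA), where one KV head is shared by several query heads. After keys and values are read back from the KV cache (described next), repeat_kv duplicates each local KV head n_rep times so that every query head has a matching key/value head.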

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
    keys, self.n_rep
)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
    values, self.n_rep
)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

The resulting shapes are shown below. xq's third dimension (local query heads) is 8x that of xk and xv (local KV heads), reflecting 1 KV head feeding 8 query heads on each GPU.

xq           : (bsz, seqlen, self.n_local_heads, self.head_dim)     # (1, 30, 8, 128)
xk, xv       : (bsz, seqlen, self.n_local_kv_heads, self.head_dim)  # (1, 30, 1, 128)

keys, values : (bsz, cache_len + seqlen, self.n_local_heads, self.head_dim)  # e.g. (1, 98, 8, 128) after repeat_kv
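A NumPy analogue of repeat_kv, written in the expand-then-reshape style that, as far as I know, the reference function uses (the NumPy port itself is an illustration, not the original code). The toy input uses 2 KV heads for 8 query heads so the grouping is visible; on each 70B GPU there is 1 local KV head and n_rep = 8.

```python
import numpy as np


def repeat_kv(x, n_rep):
    """(bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim),
    repeating each KV head n_rep times in a row."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    expanded = np.broadcast_to(x[:, :, :, None, :], (bs, slen, n_kv_heads, n_rep, head_dim))
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # copies into a real array


rng = np.random.default_rng(0)
kv = rng.standard_normal((1, 98, 2, 4))   # toy: 2 KV heads, head_dim 4
keys = repeat_kv(kv, n_rep=4)             # 8 query heads
print(keys.shape)  # (1, 98, 8, 4)
# Query head j reads KV head j // n_rep.
print(all(np.array_equal(keys[:, :, j], kv[:, :, j // 4]) for j in range(8)))  # True
```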

Then the KV cache is populated. The KV cache is a performance optimization that improves inference latency: the new xk and xv are written into cache_k and cache_v, and keys and values are read back as all cached entries up to the current position. Because the keys and values of earlier tokens are never recomputed, latency drops. The sequence dimension of keys and values (cache_len + seqlen) grows by 1 with every generated token and can reach the context length args.max_seq_len.

# Attention
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
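A toy, single-layer NumPy version of the cache logic above shows keys and values growing by one position per decode step after the prompt pass. The class name ToyKVCache is hypothetical. The last lines estimate the per-GPU cache size for 70B with 8-way MP, assuming 2-byte (bf16) entries and 80 layers, each with 1 local KV head of size 128.

```python
import numpy as np


class ToyKVCache:
    """Single-layer cache that mirrors the cache_k / cache_v slicing above."""

    def __init__(self, max_batch_size, max_seq_len, n_kv_heads, head_dim):
        shape = (max_batch_size, max_seq_len, n_kv_heads, head_dim)
        self.cache_k = np.zeros(shape, dtype=np.float32)
        self.cache_v = np.zeros(shape, dtype=np.float32)

    def update(self, xk, xv, start_pos):
        bsz, seqlen = xk.shape[:2]
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv
        # keys/values cover every position processed so far
        return (self.cache_k[:bsz, :start_pos + seqlen],
                self.cache_v[:bsz, :start_pos + seqlen])


rng = np.random.default_rng(0)
cache = ToyKVCache(max_batch_size=1, max_seq_len=128, n_kv_heads=1, head_dim=4)
keys, values = cache.update(rng.standard_normal((1, 30, 1, 4)),
                            rng.standard_normal((1, 30, 1, 4)), start_pos=0)  # prompt pass
print(keys.shape)  # (1, 30, 1, 4)
for start_pos in range(30, 33):  # three decode steps, one new token each
    keys, values = cache.update(rng.standard_normal((1, 1, 1, 4)),
                                rng.standard_normal((1, 1, 1, 4)), start_pos=start_pos)
print(keys.shape)  # (1, 33, 1, 4)

# KV cache size per GPU for 70B with 8-way MP (bf16 entries assumed)
n_layers, n_local_kv_heads, head_dim, bytes_per_elem = 80, 1, 128, 2
per_token = 2 * n_layers * n_local_kv_heads * head_dim * bytes_per_elem  # K and V
print(per_token)                   # 40960 bytes per token per GPU
print(per_token * 8192 / 2**20)    # 320.0 MiB for one 8192-token sequence
```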

Then, attention is applied.

# Attention
xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
    scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
# concatenation of the local heads: (bs, seqlen, head_dim * n_local_heads)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)

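After that, the MLP/FeedForward block is applied. TransformerBlock.forward chains the two sub-blocks; each is preceded by an RMSNorm and wrapped in a residual connection.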

# TransformerBlock
def forward(
    self,
    x: torch.Tensor,
    start_pos: int,
    freqs_cis: torch.Tensor,
    mask: Optional[torch.Tensor],
):
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    out = h + self.feed_forward(self.ffn_norm(h))
    return out
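The RMSNorm and residual wiring above can be sketched in NumPy as follows. rms_norm follows the usual RMSNorm definition (x divided by the root-mean-square of its last dimension, then scaled by a learned weight); transformer_block and its placeholder callables are hypothetical names used only for illustration.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-5):
    # Divide each vector by its root-mean-square over the last dim, then scale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def transformer_block(x, attention, feed_forward, attention_norm_w, ffn_norm_w):
    """Pre-norm residual wiring of TransformerBlock.forward; `attention` and
    `feed_forward` are placeholder callables in this sketch."""
    h = x + attention(rms_norm(x, attention_norm_w))
    return h + feed_forward(rms_norm(h, ffn_norm_w))


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 30, 16)) * 5.0
w = np.ones(16)
y = rms_norm(x, w)
print(np.sqrt(np.mean(y * y, axis=-1)).round(3).max())  # 1.0
# With sub-blocks that output zeros, the residual path returns the input unchanged.
print(np.allclose(transformer_block(x, np.zeros_like, np.zeros_like, w, w), x))  # True
```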

FeedForward is a SwiGLU MLP composed of 3 linear layers (w1, w2, w3), a SiLU activation, and an element-wise (*) product.

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
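The hidden size of w1/w2/w3 is not stated explicitly; it is computed in FeedForward.__init__. This sketch replays that arithmetic, assuming (as in the reference TransformerBlock, as far as I know) that hidden_dim is passed in as 4 * dim, and assuming the 70B values multiple_of=4096 and ffn_dim_multiplier=1.3 from params.json.

```python
def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    """Replays the hidden-size arithmetic of FeedForward.__init__,
    assuming the caller passes hidden_dim = 4 * dim."""
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)          # 3 SwiGLU matrices ~ params of a 2-matrix 4x MLP
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


hidden = ffn_hidden_dim(8192, multiple_of=4096, ffn_dim_multiplier=1.3)
print(hidden)       # 28672
print(hidden // 8)  # 3584 columns of w1/w3 (rows of w2) per GPU with 8-way MP
```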

An RMSNorm layer follows the 80 TransformerBlock layers, and the final ColumnParallelLinear layer projects the hidden states back to the token (vocabulary) space as logits.

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size  # 128256
        self.n_layers = params.n_layers  # 80

        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
            params.use_scaled_rope,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)  # (bsz, seqlen, 8192) for 70B
        output = self.output(h).float()  # self.output maps dim 8192 -> vocab 128256
        return output  # (bsz, seqlen, 128256)
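The mask logic in Transformer.forward is easiest to check on a small case. This NumPy re-implementation (the function name causal_mask is hypothetical) builds the same (seqlen, start_pos + seqlen) additive mask: cached positions are all visible, and within the new chunk each token sees only itself and earlier tokens. During single-token decoding (seqlen == 1) the reference code skips the mask entirely, since the one new token may attend to every cached position.

```python
import numpy as np


def causal_mask(seqlen, start_pos):
    """Additive mask of shape (seqlen, start_pos + seqlen), as built in Transformer.forward."""
    mask = np.triu(np.full((seqlen, seqlen), -np.inf), k=1)   # hide future tokens of the new chunk
    return np.hstack([np.zeros((seqlen, start_pos)), mask])   # cached tokens are all visible


m = causal_mask(seqlen=3, start_pos=4)
print(m.shape)  # (3, 7)
# Row i is token start_pos + i; it may attend to column j iff j <= start_pos + i.
print(all((m[i, j] == 0) == (j <= 4 + i) for i in range(3) for j in range(7)))  # True
```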

Llama 3 70B Model Parallel Inference Collectives:

The 70B model uses 8-GPU tensor parallelism, which the model code calls model parallelism. The model uses ColumnParallelLinear, RowParallelLinear and VocabParallelEmbedding, and these 3 layer types involve collective calls. The following docstrings, taken from the fairscale source code, explain their distributed behavior.

ColumnParallelLinear splits the weight along out_features, so each GPU stores an (out_features / number of GPUs, in_features) slice. Similarly, RowParallelLinear splits the weight along in_features, so each GPU stores an (out_features, in_features / number of GPUs) slice.

class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        in_features: first dimension of matrix A.
        out_features: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
    """
    # excerpt from __init__: each GPU stores an (out_features / p, in_features) slice
    self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        in_features: first dimension of matrix A.
        out_features: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
    """
    # excerpt from __init__: each GPU stores an (out_features, in_features / p) slice
    self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
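Applying these two splitting rules to every matrix of Llama 3.1 70B gives the per-GPU shapes. The sketch below assumes the 70B sizes from params.json (dim=8192, FFN hidden size 28672, vocabulary 128256) and 8-way MP; the helper names are hypothetical. Summing the sharded weights over the 8 GPUs recovers roughly 70.55B parameters, a useful sanity check on the shapes (the replicated RMSNorm weights are excluded).

```python
from math import prod


def column_parallel_shape(in_features, out_features, world_size):
    # ColumnParallelLinear splits out_features: weight is (out / world_size, in)
    assert out_features % world_size == 0
    return (out_features // world_size, in_features)


def row_parallel_shape(in_features, out_features, world_size):
    # RowParallelLinear splits in_features: weight is (out, in / world_size)
    assert in_features % world_size == 0
    return (out_features, in_features // world_size)


# Assumed Llama 3.1 70B sizes (params.json); 8-way model parallelism
dim, hidden, vocab, n_layers = 8192, 28672, 128256, 80
n_heads, n_kv_heads, head_dim, mp = 64, 8, 128, 8

shapes = {
    "tok_embeddings": (vocab // mp, dim),  # VocabParallelEmbedding splits the vocabulary rows
    "wq": column_parallel_shape(dim, n_heads * head_dim, mp),
    "wk": column_parallel_shape(dim, n_kv_heads * head_dim, mp),
    "wv": column_parallel_shape(dim, n_kv_heads * head_dim, mp),
    "wo": row_parallel_shape(n_heads * head_dim, dim, mp),
    "w1": column_parallel_shape(dim, hidden, mp),
    "w3": column_parallel_shape(dim, hidden, mp),
    "w2": row_parallel_shape(hidden, dim, mp),
    "output": column_parallel_shape(dim, vocab, mp),
}
for name, shape in shapes.items():
    print(f"{name:<15}{shape}")

per_layer = sum(prod(shapes[n]) for n in ("wq", "wk", "wv", "wo", "w1", "w2", "w3"))
per_gpu = n_layers * per_layer + prod(shapes["tok_embeddings"]) + prod(shapes["output"])
print(per_gpu * mp)  # 70552387584 sharded weights in total (RMSNorm weights not counted)
```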

Llama 3 70B Model Parallel Inference Algebra:

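The following figure pictorially shows the tensor-parallel versions of Attention and FFN. In the Megatron-LM scheme, f is an identity in the forward pass and an all-reduce in the backward pass, while g is an all-reduce in the forward pass and an identity in the backward pass. Inference needs only the forward pass, so f is just an identity (copy) operator and g is an all-reduce.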

Fig 4: Tensor Parallel Attention and MLP (from the Megatron-LM paper)

The linear algebraic representation of the distributed execution is as follows.

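2-layer FFN, where nl is an element-wise nonlinearity:

$$Y = \mathrm{nl}(XA), \qquad Z = YB$$

Partition A column-wise and B row-wise across the GPUs:

$$A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}, \qquad Y = \begin{bmatrix} Y_1 & Y_2 \end{bmatrix} = \begin{bmatrix} \mathrm{nl}(XA_1) & \mathrm{nl}(XA_2) \end{bmatrix}$$

$$B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \qquad Z = \begin{bmatrix} Y_1 & Y_2 \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \end{bmatrix} = Y_1 B_1 + Y_2 B_2 \quad \text{(requires an all-reduce)}$$

1-layer linear, with A partitioned column-wise:

$$Y = \mathrm{nl}(X)A = \mathrm{nl}(X)\begin{bmatrix} A_1 & A_2 \end{bmatrix} = \begin{bmatrix} \mathrm{nl}(X)A_1 & \mathrm{nl}(X)A_2 \end{bmatrix} \quad \text{(requires an all-gather)}$$

Attention, ignoring softmax and scaling (softmax acts within each query head and scaling is element-wise), so each GPU runs its query heads independently. Let X^{(j)} be the concatenation of the head outputs computed on GPU j:

$$X = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h) = \begin{bmatrix} X^{(1)} & X^{(2)} \end{bmatrix}$$

$$Y = XW^O = \begin{bmatrix} X^{(1)} & X^{(2)} \end{bmatrix} \begin{bmatrix} W^O_1 \\ W^O_2 \end{bmatrix} = X^{(1)} W^O_1 + X^{(2)} W^O_2 \quad \text{(requires an all-reduce)}$$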
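These three identities can be checked numerically. This NumPy sketch uses ReLU as a stand-in for nl and splits across 2 simulated GPUs; a Python sum plays the role of the all-reduce and np.concatenate the role of the all-gather. Note that the column split must come first in the FFN: it is valid only because nl is applied element-wise.

```python
import numpy as np


def nl(t):
    # stand-in element-wise nonlinearity (ReLU); any element-wise function works
    return np.maximum(t, 0.0)


rng = np.random.default_rng(0)
p = 2                                   # simulated GPUs
X = rng.standard_normal((3, 8))
A = rng.standard_normal((8, 16))
B = rng.standard_normal((16, 8))
A_parts = np.split(A, p, axis=1)        # column-wise partition [A1 A2]
B_parts = np.split(B, p, axis=0)        # row-wise partition [B1; B2]

# 2-layer FFN: Z = Y1 B1 + Y2 B2, where the sum is the all-reduce.
Z_full = nl(X @ A) @ B
Z_tp = sum(nl(X @ Ai) @ Bi for Ai, Bi in zip(A_parts, B_parts))
print(np.allclose(Z_full, Z_tp))        # True

# 1-layer linear: concatenating the column blocks is the all-gather.
Y_full = nl(X) @ A
Y_tp = np.concatenate([nl(X) @ Ai for Ai in A_parts], axis=1)
print(np.allclose(Y_full, Y_tp))        # True

# Attention output projection: each GPU holds some heads and the matching rows of W^O.
heads = rng.standard_normal((3, 16))    # Concat(head_1, ..., head_h), h * d_v = 16
WO = rng.standard_normal((16, 8))
X_parts = np.split(heads, p, axis=1)
WO_parts = np.split(WO, p, axis=0)
Y_attn_tp = sum(Xj @ WOj for Xj, WOj in zip(X_parts, WO_parts))
print(np.allclose(heads @ WO, Y_attn_tp))  # True
```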

Llama 3 70B Model Parallel Inference:

Pictorially, the Llama 3 model-parallel execution looks roughly like the following figure.

Fig 5: Llama 3 Model Parallel execution

The input string is tokenized and fed to VocabParallelEmbedding on each of the GPUs independently. VocabParallelEmbedding parallelizes the embedding along the vocabulary dimension: the vocabulary of 128256 tokens is divided equally among the 8 GPUs, so each GPU holds 16032 rows. The embedding dimension, params.dim, is 8192 for 70B, hence the weight tensor on each GPU is (16032, 8192). Each GPU looks up only the tokens that fall in its own vocabulary range and outputs zeros for the others, so an all-reduce (sum) of the outputs yields the full embeddings on every GPU. The pseudo call stack is shown below.

class VocabParallelEmbedding(torch.nn.Module).forward()
    reduce_from_model_parallel_region(output_parallel)
        return _reduce(ctx, input_)
            torch.distributed.all_reduce(input_, group=group)

The embedded tensors then enter the first (id = 0) TransformerBlock, which is composed of Attention and FeedForward blocks. All of its operations are local except the 2 RowParallelLinear layers, wo in the attention block and w2 in the MLP block, which both perform an all-reduce. The pseudo call stack is shown below.

class RowParallelLinear(torch.nn.Module).forward()
    reduce_from_model_parallel_region(output_parallel)
        return _reduce(ctx, input_)
            torch.distributed.all_reduce(input_, group=group)

79 more TransformerBlocks with the same structure (but their own weights) follow.

In multi-head attention, the heads are divided among the GPUs: with 64 query heads, 8 KV heads and 8 GPUs, each GPU holds 8 query heads and the 1 KV head they share. The overall attention output is the matrix product of the concatenated head outputs and the linear layer WO (see the Attention Recap section above). WO is partitioned row-wise among the GPUs, so each GPU computes a partial product from its own heads, and summing these partial products requires an all-reduce.

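The MLP block's all-reduce is necessary because each GPU i holds only the partial result nl(X A_i) B_i (in Llama, (silu(X W1_i) * X W3_i) W2_i), and the full MLP output is the sum of these partial results over all GPUs.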
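For Llama's SwiGLU FeedForward specifically, w1 and w3 must be split on the same columns so that silu(x w1) * (x w3) stays local to each GPU, and w2 is split on the matching rows. A NumPy check with toy sizes (the weights are stored as (in, out) here for readability, unlike PyTorch's (out, in)):

```python
import numpy as np


def silu(t):
    # SiLU / swish: t * sigmoid(t)
    return t / (1.0 + np.exp(-t))


rng = np.random.default_rng(0)
dim, hidden, p = 8, 24, 4                 # toy sizes; p simulated GPUs
x = rng.standard_normal((2, 5, dim))
w1 = rng.standard_normal((dim, hidden))   # (in, out) layout in this sketch
w3 = rng.standard_normal((dim, hidden))
w2 = rng.standard_normal((hidden, dim))

full = (silu(x @ w1) * (x @ w3)) @ w2     # single-GPU FeedForward

partials = []
for i in range(p):
    cols = slice(i * hidden // p, (i + 1) * hidden // p)
    # GPU i: columns `cols` of w1 and w3 (column parallel), rows `cols` of w2 (row parallel)
    partials.append((silu(x @ w1[:, cols]) * (x @ w3[:, cols])) @ w2[cols, :])
tp = np.sum(partials, axis=0)             # stands in for the all-reduce in w2
print(np.allclose(full, tp))              # True
```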

The final ColumnParallelLinear layer, which follows the TransformerBlock layers, performs an all-gather. Each GPU computes the logits for its 16032-token slice of the vocabulary, and the all-gather concatenates these slices into the full 128256-entry logits on every GPU. The pseudo call stack is shown below.

class ColumnParallelLinear.forward()
    gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor
        class _GatherFromModelParallelRegion(torch.autograd.Function).forward()
            _gather(input_: torch.Tensor) -> torch.Tensor
                torch.distributed.all_gather(tensor_list, input_, group=group)
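Putting the call stacks together gives the communication count for one forward call. This back-of-the-envelope sketch assumes 80 layers, dim=8192 and bf16 activations (2 bytes); it counts one all-reduce for the embedding plus two per TransformerBlock. The byte figures are the size of the tensor being reduced; the bytes actually sent over the wire depend on the all-reduce algorithm.

```python
# Collectives in one Transformer.forward call on 8 GPUs (following the call stacks above)
n_layers = 80
all_reduces = 1 + 2 * n_layers   # VocabParallelEmbedding + (attention wo, FFN w2) per block
all_gathers = 1                  # final ColumnParallelLinear output layer
print(all_reduces, all_gathers)  # 161 1

# Each all-reduce operates on one (bsz, seqlen, dim) activation tensor.
bsz, dim, bytes_per_elem = 1, 8192, 2              # bf16 activations assumed
prefill_bytes = bsz * 30 * dim * bytes_per_elem    # 30-token prompt pass
decode_bytes = bsz * 1 * dim * bytes_per_elem      # one new token per decode step
print(prefill_bytes, decode_bytes)                 # 491520 16384
```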

Llama 3 Distributed Training:

Llama 3 models were pre-trained with three types of parallelization: data parallelization (DP), model/tensor parallelization (MP), and pipeline parallelization (PP).

MP applies collectives to activations. DP applies collectives to the gradients of the model parameters (and, with sharded data parallelism, to the parameters themselves). PP involves point-to-point communication between GPUs.

If there is ample memory available, DP and MP alone can be used. Suppose 32 GPUs are divided into 4 MP groups of 8 GPUs, i.e. 4 model replicas. The GPUs that hold the same model shard, one from each of the 4 MP groups, form their own DP all-reduce group, which synchronizes gradients at every iteration. Hence there are 8 all-reduce groups of 4 GPUs each in total. This combines DP with MP.

For further memory savings, PP can be added. PP partitions the layers among a number of GPUs, and the parallelization hierarchy is then pipeline, model, then data. Suppose the same 32 GPUs use 4-way MP and 2-stage PP, so each model replica spans 4 × 2 = 8 GPUs and there are 4 replicas. The GPUs at the same (pipeline stage, MP rank) position in each replica form a DP all-reduce group, which again gives 8 all-reduce groups of 4 GPUs each. This combines DP, MP and PP.
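The grouping arithmetic can be sketched in plain Python. The rank layout here is an assumption (MP ranks consecutive, then pipeline stages, then data-parallel replicas); it is similar in spirit to how fairscale's initialize_model_parallel arranges ranks, but treat the exact ordering as illustrative. parallel_groups is a hypothetical helper, not a library function.

```python
def parallel_groups(world_size, mp, pp=1):
    """Return (mp_groups, pp_groups, dp_groups) as lists of global ranks.

    Assumed layout: rank = (d * pp + s) * mp + m for data replica d,
    pipeline stage s and model-parallel rank m.
    """
    dp = world_size // (mp * pp)
    assert dp * pp * mp == world_size

    def rank(d, s, m):
        return (d * pp + s) * mp + m

    mp_groups = [[rank(d, s, m) for m in range(mp)] for d in range(dp) for s in range(pp)]
    pp_groups = [[rank(d, s, m) for s in range(pp)] for d in range(dp) for m in range(mp)]
    dp_groups = [[rank(d, s, m) for d in range(dp)] for s in range(pp) for m in range(mp)]
    return mp_groups, pp_groups, dp_groups


# DP + MP: 32 GPUs, 8-way MP -> 4 replicas, 8 DP groups of 4
mp_g, _, dp_g = parallel_groups(32, mp=8)
print(len(mp_g), len(dp_g), dp_g[0])             # 4 8 [0, 8, 16, 24]

# DP + MP + PP: 32 GPUs, 4-way MP, 2 pipeline stages -> 4 replicas of 8 GPUs
mp_g, pp_g, dp_g = parallel_groups(32, mp=4, pp=2)
print(len(mp_g), len(pp_g), len(dp_g), dp_g[1])  # 8 16 8 [1, 9, 17, 25]
```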

Continual pre-training resumes from a previous checkpoint and performs full-parameter updates on new data with the same autoregressive (next-token prediction) objective. It involves the same machinery as pre-training from scratch, but with less data.

References:

  1. Llama 2: Open Foundation and Fine-Tuned Chat Models, https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/
  2. Introducing Meta Llama 3: The most capable openly available LLM to date, https://ai.meta.com/blog/meta-llama-3/
  3. Introducing Llama 3.1: Our most capable models to date, https://ai.meta.com/blog/meta-llama-3-1/
  4. (Deprecated) Meta Llama 3, https://github.com/meta-llama/llama3
  5. Llama Models, https://github.com/meta-llama/llama-models
  6. Llama Recipes: Examples to get started using the Llama models from Meta, https://github.com/meta-llama/llama-recipes
  7. fairscale, https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/layers.py
  8. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism, https://arxiv.org/pdf/1909.08053
  9. class Llama, https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/generation.py#L62
  10. class Transformer(nn.Module), https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L276
  11. Continual Learning with Pretrained Checkpoints, https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/allmodels/continuetraining.html
