Pre-training Mini Versions of LLMs — GPT and Llama3
This blog goes over how to pre-train small versions of the leading open-source Large Language Models (LLMs). Three models are covered: two versions of GPT2 and one Llama3. The code, along with the preprocessed/tokenized data, is available at https://github.com/whatdhack/mini_llms . Pre-training and generation are straightforward, requiring no special arguments other than --train for training.
The first few sections go over the Transformer architecture and its evolution to the current GPT2 form. The following sections go over the mini LLM models in detail. The final section shows the validation-loss trend chart and sample generation outputs.
Transformer (June 2017):
The following pictures from the Attention Is All You Need (2017) paper show the architecture of the full Transformer model, which includes both the encoder and the decoder. The Transformer addressed the sequential bottleneck of the RNNs used until then for sequence modelling (auto-completion, sequence classification, ...) and transduction (translation, summarization, image captioning, speech-to-text, ...).
Most competitive neural sequence transduction models of circa 2017 had an encoder-decoder structure. The encoder maps an input sequence of symbol representations (x1, ..., xn) to a sequence of continuous representations z = (z1, ..., zn). The decoder then generates an output sequence (y1, ..., ym) of symbols from z one element at a time. At each step the model is auto-regressive, consuming the previously generated symbols as additional input when generating the next.
Each encoder layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. A residual connection is employed around each of the two sub-layers, followed by layer normalization. The output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. Residual connections improve accuracy, as shown in the paper Deep Residual Learning for Image Recognition. The other issue, vanishing/exploding gradients in deeper networks, is addressed by normalized initialization and intermediate normalization layers such as LayerNorm. To facilitate residual connections, all sub-layers and embedding layers produce outputs of dimension dmodel = 512.
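As a minimal PyTorch sketch of this post-norm pattern (the class and argument names here are illustrative, not from any particular implementation):
import torch.nn as nn

class PostNormSubLayer(nn.Module):
    # wraps any sub-layer (attention or feed-forward) with a residual connection
    # followed by layer normalization, i.e. LayerNorm(x + Sublayer(x))
    def __init__(self, sublayer, d_model=512, dropout=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        return self.norm(x + self.drop(self.sublayer(x)))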
In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. As in the encoder, residual connections around each of the sub-layers are employed, followed by layer normalization. The self-attention sub-layer in the decoder stack employs masking to prevent positions from attending to subsequent positions. The masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i.
An attention function maps a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values.
Attention(Q, K, V) = softmax(Q·Kᵀ / sqrt(dk))·V
The authors suspected that for large values of dk the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this vanishing-gradient effect, they scaled down the dot products by sqrt(dk). The intuition offered is as follows: if the components of q and k are independent random variables with mean 0 and variance 1, their dot product q·k = Σᵢ qᵢkᵢ has mean 0 and variance dk.
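A minimal PyTorch sketch of scaled dot-product attention, including an optional causal mask of the kind the decoder uses (the function name and arguments are my own, not from the paper):
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, causal=False):
    # q, k: (..., T, dk), v: (..., T, dv)
    dk = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(dk)   # (..., T, T)
    if causal:
        T = q.size(-2)
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))  # block attention to future positions
    return F.softmax(scores, dim=-1) @ v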
With multi-head attention, instead of performing a single attention function with dmodel-dimensional keys, values and queries, the authors linearly project the queries, keys and values h times with different learned linear projections to dk, dk and dv dimensions, respectively. Attention is performed in parallel on each of these projected queries, keys and values, yielding dv-dimensional output values. The h outputs are concatenated and once again projected to produce the final values of dimension dmodel.
MultiHead(Q, K, V ) = Concat(head1, …, headh).WO
headi = Attention(Q.WQi , K.WKi , V.WVi )
The shapes of the projection matrices are as follows:
WQi: dmodel×dk, WKi: dmodel×dk, WVi: dmodel×dv, WO: h·dv×dmodel
The per-head key, query and value dimensions are related as follows:
dk = dv = dmodel/h = dhead
The benefit of multi-head attention is that it allows the model to jointly attend to information from different representation subspaces at different positions (preserved by the order of the concatenation). This split is not possible with a single attention head, as averaging inhibits the position dependency.
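A minimal multi-head attention sketch that reuses the scaled_dot_product_attention helper above (an illustration of the equations, not the paper's reference implementation; packing all the per-head projections into single WQ/WK/WV matrices is an assumption):
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_head=8):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head, self.d_head = n_head, d_model // n_head   # dk = dv = dmodel / h
        self.wq = nn.Linear(d_model, d_model, bias=False)      # all WQi packed side by side
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)      # WO

    def forward(self, x, causal=False):
        B, T, C = x.shape
        # project, split the last dimension into (n_head, d_head), move heads to a batch-like dim
        q, k, v = (w(x).view(B, T, self.n_head, self.d_head).transpose(1, 2)
                   for w in (self.wq, self.wk, self.wv))
        y = scaled_dot_product_attention(q, k, v, causal=causal)  # (B, n_head, T, d_head)
        y = y.transpose(1, 2).contiguous().view(B, T, C)          # concatenate the heads
        return self.wo(y)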
The feed-forward network is composed of two linear layers with shapes dmodel×dff and dff×dmodel, with a non-linearity in between.
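The corresponding position-wise feed-forward sketch, with ReLU and dff = 2048 as in the original paper:
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),   # dmodel x dff
            nn.ReLU(),
            nn.Linear(d_ff, d_model),   # dff x dmodel
        )

    def forward(self, x):
        return self.net(x)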
Transformer-Decoder (T-D) (January 2018):
In the paper Generating Wikipedia by Summarizing Long Sequences (2018), the authors simplified the Transformer architecture by dropping the encoder. This resulted in a 50% reduction in the size of the network. To accommodate the new architecture, they also needed to rearrange the output part. For a sequence-transduction example (m1, ..., mn) → (y1, ..., yη), they concatenated input and output into the sentence (w1, ..., wn+η+1) = (m1, ..., mn, δ, y1, ..., yη), where δ is a special separator token. They then trained the model to predict the next word given the previous ones.
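A small illustrative sketch of this rearrangement (the token ids and the separator id are made up):
SEP = 0  # hypothetical id of the separator token δ

def make_example(input_ids, output_ids):
    # (m1..mn, δ, y1..yη): the model is trained to predict token t+1 from tokens <= t
    tokens = input_ids + [SEP] + output_ids
    x = tokens[:-1]   # context
    y = tokens[1:]    # next-token targets, offset by one position
    return x, y

x, y = make_example([5, 7, 9], [11, 13])
# x = [5, 7, 9, 0, 11],  y = [7, 9, 0, 11, 13]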
Generative Pre-Training (GPT 1 and 2) (June 2018, February 2019):
The GPT1 architecture was introduced in the paper Improving Language Understanding by Generative Pre-Training (2018). It uses the T-D architecture described above. The GPT1 model was trained in two steps: unsupervised pre-training and supervised fine-tuning. The model applies a multi-headed self-attention operation over the input context tokens, followed by position-wise feed-forward layers, to produce an output distribution over target tokens.
h0 = U·We + Wp
hi = Transformer_Decoder_block(h(i−1)), ∀ i ∈ [1, n]
P(u) = softmax(hn·Weᵀ)
where
U = (u−k, ..., u−1) is the context vector of tokens
n is the number of layers
We is the token embedding matrix
Wp is the position embedding matrix
They used the SGD optimizer during training to maximize the likelihood of the next token over a context window of size k, as per the equation below.
L1(U) = Σi log P(ui | u(i−k), ..., u(i−1); Θ)
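In code, maximizing this log-likelihood amounts to minimizing the next-token cross-entropy, which is exactly what the loss computation at the end of the nano_gpt forward pass later in this post does. A sketch with assumed tensor shapes:
import torch.nn.functional as F

def next_token_loss(logits, targets):
    # logits: (batch, context, vocab) pre-softmax scores from the final linear layer
    # targets: (batch, context) token ids, i.e. the input sequence shifted left by one
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))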
The GPT2 architecture and model were published by OpenAI in February 2019. The architecture is essentially the same as GPT1, except for the following differences.
- Layer normalization was moved to the input of each sub-block (similar to a pre-activation residual network in ResNet)
- An additional layer normalization was added after the final self-attention block.
This results in the following block-diagram view of the architecture.
In addition, a modified initialization is used that accounts for the accumulation along the residual path with model depth: the authors scaled the weights of residual layers at initialization by a factor of 1/sqrt(N), where N is the number of residual layers. The vocabulary is expanded to 50,257 tokens, the context size is increased from 512 to 1024 tokens, and a larger batch size of 512 is used.
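A sketch of such an initialization, following the nanoGPT convention of applying the scaling to the residual output projections named c_proj (the module names are assumptions based on the nano_gpt parameter listing later in this post):
import math
import torch.nn as nn

def init_gpt2_weights(model, n_layer):
    # baseline: normal(0, 0.02) for linear and embedding weights
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    # residual projections get scaled by 1/sqrt(N), with N = 2 residual layers per block
    for name, param in model.named_parameters():
        if name.endswith('c_proj.weight'):
            nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * n_layer))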
The total number of parameters can be estimated as the sum of the embedding, projection, MLP and final linear layers as follows.
embedding parameters = vocab_size*dmodel + sequence_length*dmodel
projection parameters
= nlayers * (size(WQ) + size(WK) + size(WV) + size(WO))
= nlayers * nhead * (4*dmodel*dhead), since nhead*dhead = dmodel
FFN/MLP parameters = nlayers*(dmodel*dff + dff*dmodel)
Final linear = vocab_size*dmodel
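A quick sketch that evaluates this estimate for the GPT2 small configuration (vocab 50,257, context 1024, dmodel 768, 12 layers, 12 heads, dff = 4*dmodel). The estimate ignores biases, LayerNorm weights and the usual weight tying of the final linear layer with the token embedding:
def estimate_params(vocab, seq_len, d_model, n_layer, n_head, d_ff):
    d_head = d_model // n_head
    embedding = vocab * d_model + seq_len * d_model
    projection = n_layer * n_head * (4 * d_model * d_head)
    mlp = n_layer * (2 * d_model * d_ff)
    final_linear = vocab * d_model   # often tied with the token embedding
    return embedding + projection + mlp + final_linear

print(estimate_params(50257, 1024, 768, 12, 12, 4 * 768))   # ~163M, or ~124M with weight tying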
Large Language Model Meta AI (LLaMA) 1, 2 and 3 (February 2023, July 2023, and April 2024):
Meta introduced the three models at a rapid clip in 2023 and 2024.
To improve training stability, the Llama1 authors normalize the input of each transformer sub-layer instead of normalizing the output. They used the RMSNorm normalizing function, which is computationally less intensive than LayerNorm. They replaced ReLU with SwiGLU. They removed the absolute positional embeddings and instead added rotary positional embeddings (RoPE). They used the AdamW optimizer with a cosine learning rate schedule.
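Minimal sketches of these two building blocks, written from the published formulas rather than Meta's exact code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    # x * rsqrt(mean(x^2) + eps) * g  -- no mean subtraction and no bias, unlike LayerNorm
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class SwiGLUFeedForward(nn.Module):
    # w2( silu(w1(x)) * w3(x) ) -- matches the w1/w2/w3 layers in the mini_llama3 listing below
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))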
Llama 2 is an updated version of Llama 1, trained on a new mix of publicly available data. The developers increased the size of the pre-training corpus by 40%, doubled the context length of the model (from 2048 to 4096 tokens), and adopted grouped-query attention to improve inference scalability for larger models. It uses the same tokenizer as Llama 1, which employs a byte-pair encoding (BPE) algorithm.
Compared to Llama 2, Llama 3 has a few key improvements. Llama 3 uses a tokenizer with a vocabulary of 128K tokens that encodes language much more efficiently, which leads to substantially improved model performance. To improve inference efficiency, Llama 3 uses grouped-query attention (GQA) across all model sizes. Llama 3 models were trained on sequences of 8,192 tokens, using a mask to ensure self-attention does not cross document boundaries.
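The core idea of GQA is that several query heads share one key/value head, so the key/value heads are simply repeated before attention. A minimal sketch (the helper name repeat_kv mirrors Meta's reference code, but the snippet itself is an illustration):
import torch

def repeat_kv(x, n_rep):
    # x: (batch, seq, n_kv_heads, head_dim) -> (batch, seq, n_kv_heads * n_rep, head_dim)
    if n_rep == 1:
        return x
    b, t, n_kv, hd = x.shape
    return x[:, :, :, None, :].expand(b, t, n_kv, n_rep, hd).reshape(b, t, n_kv * n_rep, hd)

# e.g. 8 query heads sharing 2 kv heads: each kv head is repeated 4 times
k = torch.randn(1, 16, 2, 64)
k = repeat_kv(k, n_rep=4)   # (1, 16, 8, 64)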
Little Versions of LLMs:
The little versions follow the smallest NanoGPT setup, with about 10M parameters. These small models can be easily pre-trained on a consumer-grade GPU. The dataset is the one described in the nanoGPT repo as shakespeare_char. The prepared data, train.bin and val.bin, are in the data directory of the repository.
Each model is fully contained in one file. The tokenizer is the same 65-token character-level one (from the shakespeare_char preparation) for all three models. The embedding parts of nano_gpt and torch_gpt are exactly the same; mini_llama3 uses the Llama VocabParallelEmbedding.
nano_gpt:
The nano_gpt model is a single-file version of the NanoGPT model. The parameter list below shows the shape and size of each trainable parameter. The model listing just shows which layers are in the model; it does not show the data flow through the model. The data flow is coded in the forward() call.
transformer.wte.weight torch.Size([65, 384]) 24960
transformer.wpe.weight torch.Size([256, 384]) 98304
transformer.h.0.ln_1.weight torch.Size([384]) 384
transformer.h.0.attn.c_attn.weight torch.Size([1152, 384]) 442368
transformer.h.0.attn.c_proj.weight torch.Size([384, 384]) 147456
transformer.h.0.ln_2.weight torch.Size([384]) 384
transformer.h.0.mlp.c_fc.weight torch.Size([1536, 384]) 589824
transformer.h.0.mlp.c_proj.weight torch.Size([384, 1536]) 589824
transformer.h.1.ln_1.weight torch.Size([384]) 384
transformer.h.1.attn.c_attn.weight torch.Size([1152, 384]) 442368
transformer.h.1.attn.c_proj.weight torch.Size([384, 384]) 147456
transformer.h.1.ln_2.weight torch.Size([384]) 384
transformer.h.1.mlp.c_fc.weight torch.Size([1536, 384]) 589824
transformer.h.1.mlp.c_proj.weight torch.Size([384, 1536]) 589824
transformer.h.2.ln_1.weight torch.Size([384]) 384
transformer.h.2.attn.c_attn.weight torch.Size([1152, 384]) 442368
transformer.h.2.attn.c_proj.weight torch.Size([384, 384]) 147456
transformer.h.2.ln_2.weight torch.Size([384]) 384
transformer.h.2.mlp.c_fc.weight torch.Size([1536, 384]) 589824
transformer.h.2.mlp.c_proj.weight torch.Size([384, 1536]) 589824
transformer.h.3.ln_1.weight torch.Size([384]) 384
transformer.h.3.attn.c_attn.weight torch.Size([1152, 384]) 442368
transformer.h.3.attn.c_proj.weight torch.Size([384, 384]) 147456
transformer.h.3.ln_2.weight torch.Size([384]) 384
transformer.h.3.mlp.c_fc.weight torch.Size([1536, 384]) 589824
transformer.h.3.mlp.c_proj.weight torch.Size([384, 1536]) 589824
transformer.h.4.ln_1.weight torch.Size([384]) 384
transformer.h.4.attn.c_attn.weight torch.Size([1152, 384]) 442368
transformer.h.4.attn.c_proj.weight torch.Size([384, 384]) 147456
transformer.h.4.ln_2.weight torch.Size([384]) 384
transformer.h.4.mlp.c_fc.weight torch.Size([1536, 384]) 589824
transformer.h.4.mlp.c_proj.weight torch.Size([384, 1536]) 589824
transformer.h.5.ln_1.weight torch.Size([384]) 384
transformer.h.5.attn.c_attn.weight torch.Size([1152, 384]) 442368
transformer.h.5.attn.c_proj.weight torch.Size([384, 384]) 147456
transformer.h.5.ln_2.weight torch.Size([384]) 384
transformer.h.5.mlp.c_fc.weight torch.Size([1536, 384]) 589824
transformer.h.5.mlp.c_proj.weight torch.Size([384, 1536]) 589824
transformer.ln_f.weight torch.Size([384]) 384
number of parameters: 10,646,784
modelnano : NanoGPT(
(transformer): ModuleDict(
(wte): Embedding(65, 384)
(wpe): Embedding(256, 384)
(drop): Dropout(p=0.2, inplace=False)
(h): ModuleList(
(0-5): 6 x Block(
(ln_1): LayerNorm()
(attn): CausalSelfAttention(
(c_attn): Linear(in_features=384, out_features=1152, bias=False)
(c_proj): Linear(in_features=384, out_features=384, bias=False)
(attn_dropout): Dropout(p=0.2, inplace=False)
(resid_dropout): Dropout(p=0.2, inplace=False)
)
(ln_2): LayerNorm()
(mlp): MLP(
(c_fc): Linear(in_features=384, out_features=1536, bias=False)
(gelu): GELU(approximate='none')
(c_proj): Linear(in_features=1536, out_features=384, bias=False)
(dropout): Dropout(p=0.2, inplace=False)
)
)
)
(ln_f): LayerNorm()
)
(lm_head): Linear(in_features=384, out_features=65, bias=False)
)
The forward path of nano_gpt looks as follows.
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
    x = block(x)
    # block's forward call
    #x = x + self.attn(self.ln_1(x))
    # attn's forward call
    #B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
    # calculate query, key, values for all heads in batch and move head forward to be the batch dim
    #q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
    #k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    #q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    #v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    #if self.flash:
    #    # efficient attention using Flash Attention CUDA kernels
    #    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
    #else:
    #    # manual implementation of attention
    #    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    #    att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    #    att = F.softmax(att, dim=-1)
    #    att = self.attn_dropout(att)
    #    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    #y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
    # output projection
    #y = self.resid_dropout(self.c_proj(y))
    #x = x + self.mlp(self.ln_2(x))
    # mlp's forward call
    #x = self.c_fc(x)
    #x = self.gelu(x)
    #x = self.c_proj(x)
    #x = self.dropout(x)
x = self.transformer.ln_f(x) # last layer norm
if targets is not None:
    # if we are given some desired targets also calculate the loss
    logits = self.lm_head(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
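For generation, this forward pass is run in an autoregressive loop. A minimal sketch in the style of nanoGPT's generate() (the cropping to block_size, the temperature, and the (logits, loss) return convention are assumptions about this repo's code):
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, idx, max_new_tokens, block_size=256, temperature=1.0):
    # idx: (batch, t) tensor of token ids forming the prompt
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                      # crop context to the block size
        logits, _ = model(idx_cond)                          # forward pass as above
        logits = logits[:, -1, :] / temperature              # logits for the last position
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)   # sample the next token
        idx = torch.cat((idx, idx_next), dim=1)
    return idx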
torch_gpt:
The torch_gpt model is based on PyTorch's TransformerDecoder and TransformerDecoderLayer classes. The class GPTDecoderLayer is derived from TransformerDecoderLayer; it removes the memory input coming from the Transformer's encoder part. In the model, tgt_is_causal is also set to True along with a mask. It is similar to the Block class in NanoGPT. The torch_gpt model is as follows.
gptnet.wte.weight torch.Size([65, 384]) 24960
gptnet.wpe.weight torch.Size([256, 384]) 98304
gptnet.h.layers.0.norm1.weight torch.Size([384]) 384
gptnet.h.layers.0.self_attn.in_proj_weight torch.Size([1152, 384]) 442368
gptnet.h.layers.0.self_attn.out_proj.weight torch.Size([384, 384]) 147456
gptnet.h.layers.0.ff_norm3.weight torch.Size([384]) 384
gptnet.h.layers.0.ff_linear1.weight torch.Size([1536, 384]) 589824
gptnet.h.layers.0.ff_linear2.weight torch.Size([384, 1536]) 589824
gptnet.h.layers.1.norm1.weight torch.Size([384]) 384
gptnet.h.layers.1.self_attn.in_proj_weight torch.Size([1152, 384]) 442368
gptnet.h.layers.1.self_attn.out_proj.weight torch.Size([384, 384]) 147456
gptnet.h.layers.1.ff_norm3.weight torch.Size([384]) 384
gptnet.h.layers.1.ff_linear1.weight torch.Size([1536, 384]) 589824
gptnet.h.layers.1.ff_linear2.weight torch.Size([384, 1536]) 589824
gptnet.h.layers.2.norm1.weight torch.Size([384]) 384
gptnet.h.layers.2.self_attn.in_proj_weight torch.Size([1152, 384]) 442368
gptnet.h.layers.2.self_attn.out_proj.weight torch.Size([384, 384]) 147456
gptnet.h.layers.2.ff_norm3.weight torch.Size([384]) 384
gptnet.h.layers.2.ff_linear1.weight torch.Size([1536, 384]) 589824
gptnet.h.layers.2.ff_linear2.weight torch.Size([384, 1536]) 589824
gptnet.h.layers.3.norm1.weight torch.Size([384]) 384
gptnet.h.layers.3.self_attn.in_proj_weight torch.Size([1152, 384]) 442368
gptnet.h.layers.3.self_attn.out_proj.weight torch.Size([384, 384]) 147456
gptnet.h.layers.3.ff_norm3.weight torch.Size([384]) 384
gptnet.h.layers.3.ff_linear1.weight torch.Size([1536, 384]) 589824
gptnet.h.layers.3.ff_linear2.weight torch.Size([384, 1536]) 589824
gptnet.h.layers.4.norm1.weight torch.Size([384]) 384
gptnet.h.layers.4.self_attn.in_proj_weight torch.Size([1152, 384]) 442368
gptnet.h.layers.4.self_attn.out_proj.weight torch.Size([384, 384]) 147456
gptnet.h.layers.4.ff_norm3.weight torch.Size([384]) 384
gptnet.h.layers.4.ff_linear1.weight torch.Size([1536, 384]) 589824
gptnet.h.layers.4.ff_linear2.weight torch.Size([384, 1536]) 589824
gptnet.h.layers.5.norm1.weight torch.Size([384]) 384
gptnet.h.layers.5.self_attn.in_proj_weight torch.Size([1152, 384]) 442368
gptnet.h.layers.5.self_attn.out_proj.weight torch.Size([384, 384]) 147456
gptnet.h.layers.5.ff_norm3.weight torch.Size([384]) 384
gptnet.h.layers.5.ff_linear1.weight torch.Size([1536, 384]) 589824
gptnet.h.layers.5.ff_linear2.weight torch.Size([384, 1536]) 589824
gptnet.ln_f.weight torch.Size([384]) 384
number of parameters: 10,646,784
modeltorch : TorchGPT(
(gptnet): ModuleDict(
(wte): Embedding(65, 384)
(wpe): Embedding(256, 384)
(drop): Dropout(p=0.2, inplace=False)
(h): GPTDecoder(
(layers): ModuleList(
(0-5): 6 x GPTDecoderLayer(
(norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=False)
)
(dropout1): Dropout(p=0.2, inplace=False)
(ff_norm3): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(ff_linear1): Linear(in_features=384, out_features=1536, bias=False)
(ff_dropout2): Dropout(p=0.2, inplace=False)
(ff_linear2): Linear(in_features=1536, out_features=384, bias=False)
(ff_dropout3): Dropout(p=0.2, inplace=False)
)
)
)
(ln_f): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
(lm_head): Linear(in_features=384, out_features=65, bias=False)
)
)
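A hedged sketch of what such an encoder-free decoder layer could look like. This is not the repo's exact code (the printed module names such as ff_norm3 and ff_linear1 differ); it reuses the sub-modules inherited from nn.TransformerDecoderLayer and simply skips the cross-attention on encoder memory:
import torch
import torch.nn as nn

class GPTDecoderLayer(nn.TransformerDecoderLayer):
    # decoder-only layer: pre-norm self-attention + feed-forward, no encoder memory
    def forward(self, tgt, memory=None, tgt_mask=None, tgt_is_causal=True, **kwargs):
        x = tgt
        h = self.norm1(x)
        h, _ = self.self_attn(h, h, h, attn_mask=tgt_mask, need_weights=False)  # causal mask passed in
        x = x + self.dropout1(h)
        h = self.norm3(x)
        h = self.linear2(self.dropout(self.activation(self.linear1(h))))
        x = x + self.dropout3(h)
        return x

# usage sketch: GPT-style settings plus an additive causal mask over a 256-token block
layer = GPTDecoderLayer(d_model=384, nhead=6, dim_feedforward=1536, dropout=0.2,
                        activation='gelu', batch_first=True)
mask = torch.triu(torch.full((256, 256), float('-inf')), diagonal=1)
y = layer(torch.randn(8, 256, 384), tgt_mask=mask)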
mini_llama3:
The mini_llama3 parameters and model are listed below. The same tokenizer as in the GPTs is used. The token embedder and linear layers are different: ColumnParallelLinear, RowParallelLinear, and VocabParallelEmbedding. These come from Meta's fairscale package. In our case the xParallelLinear layers behave as simple linear layers, since we are not using distributed training (with FSDP, Megatron, etc.).
tok_embeddings.weight torch.Size([65, 384]) 24960
layers.0.attention.wq.weight torch.Size([384, 384]) 147456
layers.0.attention.wk.weight torch.Size([384, 384]) 147456
layers.0.attention.wv.weight torch.Size([384, 384]) 147456
layers.0.attention.wo.weight torch.Size([384, 384]) 147456
layers.0.feed_forward.w1.weight torch.Size([1408, 384]) 540672
layers.0.feed_forward.w2.weight torch.Size([384, 1408]) 540672
layers.0.feed_forward.w3.weight torch.Size([1408, 384]) 540672
layers.0.attention_norm.weight torch.Size([384]) 384
layers.0.ffn_norm.weight torch.Size([384]) 384
layers.1.attention.wq.weight torch.Size([384, 384]) 147456
layers.1.attention.wk.weight torch.Size([384, 384]) 147456
layers.1.attention.wv.weight torch.Size([384, 384]) 147456
layers.1.attention.wo.weight torch.Size([384, 384]) 147456
layers.1.feed_forward.w1.weight torch.Size([1408, 384]) 540672
layers.1.feed_forward.w2.weight torch.Size([384, 1408]) 540672
layers.1.feed_forward.w3.weight torch.Size([1408, 384]) 540672
layers.1.attention_norm.weight torch.Size([384]) 384
layers.1.ffn_norm.weight torch.Size([384]) 384
layers.2.attention.wq.weight torch.Size([384, 384]) 147456
layers.2.attention.wk.weight torch.Size([384, 384]) 147456
layers.2.attention.wv.weight torch.Size([384, 384]) 147456
layers.2.attention.wo.weight torch.Size([384, 384]) 147456
layers.2.feed_forward.w1.weight torch.Size([1408, 384]) 540672
layers.2.feed_forward.w2.weight torch.Size([384, 1408]) 540672
layers.2.feed_forward.w3.weight torch.Size([1408, 384]) 540672
layers.2.attention_norm.weight torch.Size([384]) 384
layers.2.ffn_norm.weight torch.Size([384]) 384
layers.3.attention.wq.weight torch.Size([384, 384]) 147456
layers.3.attention.wk.weight torch.Size([384, 384]) 147456
layers.3.attention.wv.weight torch.Size([384, 384]) 147456
layers.3.attention.wo.weight torch.Size([384, 384]) 147456
layers.3.feed_forward.w1.weight torch.Size([1408, 384]) 540672
layers.3.feed_forward.w2.weight torch.Size([384, 1408]) 540672
layers.3.feed_forward.w3.weight torch.Size([1408, 384]) 540672
layers.3.attention_norm.weight torch.Size([384]) 384
layers.3.ffn_norm.weight torch.Size([384]) 384
layers.4.attention.wq.weight torch.Size([384, 384]) 147456
layers.4.attention.wk.weight torch.Size([384, 384]) 147456
layers.4.attention.wv.weight torch.Size([384, 384]) 147456
layers.4.attention.wo.weight torch.Size([384, 384]) 147456
layers.4.feed_forward.w1.weight torch.Size([1408, 384]) 540672
layers.4.feed_forward.w2.weight torch.Size([384, 1408]) 540672
layers.4.feed_forward.w3.weight torch.Size([1408, 384]) 540672
layers.4.attention_norm.weight torch.Size([384]) 384
layers.4.ffn_norm.weight torch.Size([384]) 384
layers.5.attention.wq.weight torch.Size([384, 384]) 147456
layers.5.attention.wk.weight torch.Size([384, 384]) 147456
layers.5.attention.wv.weight torch.Size([384, 384]) 147456
layers.5.attention.wo.weight torch.Size([384, 384]) 147456
layers.5.feed_forward.w1.weight torch.Size([1408, 384]) 540672
layers.5.feed_forward.w2.weight torch.Size([384, 1408]) 540672
layers.5.feed_forward.w3.weight torch.Size([1408, 384]) 540672
layers.5.attention_norm.weight torch.Size([384]) 384
layers.5.ffn_norm.weight torch.Size([384]) 384
norm.weight torch.Size([384]) 384
output.weight torch.Size([65, 384]) 24960
number of parameters: 13,325,952
modelllama3 : Transformer(
(tok_embeddings): VocabParallelEmbedding()
(layers): ModuleList(
(0-5): 6 x TransformerBlock(
(attention): Attention(
(wq): ColumnParallelLinear()
(wk): ColumnParallelLinear()
(wv): ColumnParallelLinear()
(wo): RowParallelLinear()
)
(feed_forward): FeedForward(
(w1): ColumnParallelLinear()
(w2): RowParallelLinear()
(w3): ColumnParallelLinear()
)
(attention_norm): RMSNorm()
(ffn_norm): RMSNorm()
)
)
(norm): RMSNorm()
(output): ColumnParallelLinear()
)
The training code is essentially the same as in nano_gpt and torch_gpt. The parameter initialization is also similar to that of the two GPTs, except for different nn.Module names. Without this initialization, training gets stuck at a non-optimal minimum with mostly zero gradients.
Training the 3 models:
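A sketch of what that initialization might look like for the Llama module names printed above, reusing the 1/sqrt(2*nlayers) residual scaling from the GPT2-style init described earlier; the exact rule used in the repo may differ:
import math
import torch.nn as nn

def init_llama_weights(model, n_layer, std=0.02):
    for name, param in model.named_parameters():
        if param.dim() < 2:        # skip RMSNorm gains
            continue
        nn.init.normal_(param, mean=0.0, std=std)
        # scale the residual output projections, analogous to c_proj in the GPT models
        if name.endswith('attention.wo.weight') or name.endswith('feed_forward.w2.weight'):
            nn.init.normal_(param, mean=0.0, std=std / math.sqrt(2 * n_layer))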
The three models are small and the dataset is also small, so they can be trained together. The code to train all three models at the same time is in the file compare_models.py (e.g. with the --train option). The following chart shows the validation loss versus iteration number for the three models. All three models use the same AdamW parameters that were used in the original NanoGPT model. The mini_llama3 model has the fastest convergence, but also appears to overfit rapidly. Parameter search and tuning is expected to resolve that issue and reach a better validation loss (TBD).
The mini_llama3 model uses bf16, and hence the size of its checkpoint is almost half that of the other two.
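The shared training loop is the standard next-token-prediction loop. A minimal sketch, where model and get_batch are placeholders and the AdamW hyper-parameters and iteration counts are assumed to follow nanoGPT's shakespeare_char defaults:
import torch

max_iters, eval_interval = 5000, 250   # assumed nanoGPT-style values; the repo's may differ
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=0.1)

for it in range(max_iters):
    x, y = get_batch('train')            # (batch, block_size) token-id tensors from train.bin
    logits, loss = model(x, y)           # forward pass returns logits and the cross-entropy loss
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if it % eval_interval == 0:
        print(it, loss.item())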
The following are example generations from the 3 models.
References:
- Transformer — Attention Is All You Need, https://arxiv.org/pdf/1706.03762
- Deep Residual Learning for Image Recognition, https://arxiv.org/pdf/1512.03385
- T-D, Generating Wikipedia by Summarizing Long Sequences, ICLR 2018, https://arxiv.org/pdf/1801.10198
- GPT-1, Improving Language Understanding by Generative Pre-Training, https://www.cs.princeton.edu/courses/archive/spring20/cos598C/lectures/lec4-pretraining.pdf
- GPT-1, Improving Language Understanding by Generative Pre-Training, https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf
- GPT-1 : https://en.wikipedia.org/wiki/GPT-1
- Generative pre-trained transformer, https://en.wikipedia.org/wiki/Generative_pre-trained_transformer
- GPT2 — Language Models are Unsupervised Multitask Learners, https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
- GPT-2: 1.5B release, https://openai.com/research/gpt-2-1-5b-release
- GPT2, Better language models and their implications, https://openai.com/index/better-language-models/
- Transformer (deep learning architecture), https://en.wikipedia.org/wiki/Transformer_(deep_learning_architecture)
- LLaMA, https://en.wikipedia.org/wiki/LLaMA
- Llama 2, https://llama.meta.com/llama2/
- LLaMA: Open and Efficient Foundation Language Models, https://arxiv.org/pdf/2302.13971
- Llama 2: Open Foundation and Fine-Tuned Chat Models, https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/
- Llama 2: Open Foundation and Fine-Tuned Chat Models, https://scontent-sjc3-1.xx.fbcdn.net/v/t39.2365-6/10000000_662098952474184_2584067087619170692_n.pdf
- Introducing Meta Llama 3: The most capable openly available LLM to date — https://ai.meta.com/blog/meta-llama-3/
- fairscale, https://github.com/facebookresearch/fairscale
- GPT-2 code, https://github.com/openai/gpt-2
- nanoGPT, https://github.com/karpathy/nanoGPT
- llama2.c, https://github.com/karpathy/llama2.c