Lecture 21 – CS 189, Fall 2025
In [3]:
# !pip install -U plotly
In [4]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pickle
In [5]:
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
device = "cuda" # Colab
# device = "mps" # Mac with M1/M2
# device = "cpu" # Local CPU
In [6]:
import plotly.io as pio
pio.renderers.default = "vscode"
Sinusoidal Embeddings
In [7]:
D = 6
n = 16
L = 1000
torch.arange(0, D, 2, dtype=torch.float)
Out[7]:
tensor([0., 2., 4.])
In [8]:
def sinusoidal_pe(N, D, L=1000):
pe = torch.zeros(N, D)
div_term = L ** (2 * torch.arange(0, D, 2, dtype=torch.float) / D)
pe[:, 0::2] = torch.sin(torch.arange(N, dtype=torch.float).unsqueeze(1) / div_term)
pe[:, 1::2] = torch.cos(torch.arange(N, dtype=torch.float).unsqueeze(1) / div_term)
return pe
In [9]:
n = 64; D = 128; L = 10
pe = sinusoidal_pe(n, D, L)
px.imshow(pe.cpu().numpy().T, aspect='auto', color_continuous_scale='RdBu_r',
width = 1100, height=500)
In [10]:
dist = pe @ pe.T
fig = px.imshow(dist.cpu().numpy(), color_continuous_scale='Viridis',
width=700, height=700,
title='Dot Product of Positional Encodings')
fig.show()
px.line(x=np.arange(0, n), y=dist[10].cpu().numpy(), width=800, height=400,
title='Dot Product of Positional Encodings for Position 200')
Tokenization
In [11]:
#!pip install transformers
In [12]:
from transformers import AutoTokenizer
# Load the Qwen tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")
tokenizer.vocab_size
/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'. You are not authenticated with the Hugging Face Hub in this notebook. If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).
tokenizer_config.json: 0.00B [00:00, ?B/s]
vocab.json: 0.00B [00:00, ?B/s]
merges.txt: 0.00B [00:00, ?B/s]
tokenizer.json: 0%| | 0.00/11.4M [00:00<?, ?B/s]
Out[12]:
151643
In [13]:
tokenizer.encode("Hello, how are you?")
Out[13]:
[9707, 11, 1246, 525, 498, 30]
Byte Pair Encoding
The Byte Pair Encoding is fairly simple to implement. We start by splitting each word into its characters, appending a special end-of-word symbol </w> to the end of each word. Then, we repeatedly find the most common adjacent pair of symbols across all words and merge them into a new symbol. This process is repeated for a specified number of merges.
Once we have learned the merges, we can use them to tokenize new words by applying the merges in order until no more merges can be applied.
In [14]:
from typing import Dict, List, Tuple
from collections import Counter
EOW = "</w>"
In [15]:
def build_initial_vocab(corpus: str) -> Dict[str, int]:
"""
Create an initial vocabulary consisting of all single characters
observed in the corpus PLUS the EOW marker. IDs are assigned
deterministically (sorted).
Returns:
vocab: dict mapping token string -> integer id
"""
import string
ALL_ASCII = set(string.ascii_letters + string.digits + string.punctuation)
ALL_ASCII.add(EOW) # include the end-of-word symbol
corpus_set = set(corpus) - set(' \n\t\r') # exclude whitespace
return {tok: i for i, tok in enumerate(corpus_set | ALL_ASCII)}
vocab = build_initial_vocab("CS189 is CS.")
print(vocab)
{'3': 0, 'r': 1, '4': 2, 'Z': 3, 'H': 4, '!': 5, ';': 6, 'W': 7, 'w': 8, 'f': 9, 'M': 10, 'C': 11, '/': 12, '0': 13, ',': 14, 'd': 15, 'L': 16, 'y': 17, '~': 18, '[': 19, '.': 20, 'K': 21, 'G': 22, 'U': 23, '{': 24, 'h': 25, 'm': 26, 'Y': 27, '(': 28, '8': 29, '#': 30, 'B': 31, 'A': 32, 'c': 33, 'N': 34, '7': 35, "'": 36, '6': 37, 'T': 38, 'g': 39, '&': 40, 'F': 41, 'R': 42, 'n': 43, '+': 44, 'q': 45, 's': 46, '1': 47, ']': 48, '-': 49, '2': 50, 'P': 51, '9': 52, 'z': 53, 'a': 54, 'o': 55, '"': 56, 'k': 57, 'b': 58, 'X': 59, '%': 60, 'V': 61, '=': 62, '$': 63, '>': 64, 'v': 65, ':': 66, 'I': 67, '@': 68, '`': 69, '</w>': 70, 'J': 71, 't': 72, 'S': 73, 'e': 74, 'D': 75, '}': 76, '*': 77, 'u': 78, '?': 79, 'j': 80, 'Q': 81, 'i': 82, '<': 83, 'x': 84, 'E': 85, '_': 86, 'l': 87, 'p': 88, 'O': 89, '\\': 90, '^': 91, ')': 92, '5': 93, '|': 94}
In [16]:
def corpus_to_char_seq_with_eow(corpus: str) -> List[str]:
"""
Convert the whole corpus into a single flat sequence of symbols.
Each word contributes its characters followed by the end-of-word marker `</w>`.
Example:
corpus = "low lower"
returns: ['l','o','w','</w>','l','o','w','e','r','</w>']
Why this matters:
- We treat the corpus as *one long list* (not a list of per-word lists),
which is sometimes more convenient for teaching and for demonstrating
the role of `</w>` in preventing merges across word boundaries.
"""
seq: List[str] = []
for word in corpus.split():
seq.extend(list(word))
seq.append(EOW)
return seq
print(corpus_to_char_seq_with_eow("CS189 is great!"))
['C', 'S', '1', '8', '9', '</w>', 'i', 's', '</w>', 'g', 'r', 'e', 'a', 't', '!', '</w>']
In [17]:
def count_pair_frequencies(seq: List[str]) -> Counter:
"""
Count frequencies of adjacent symbol pairs over the *flat* sequence.
Boundary rule:
- We *disallow* pairs that START with `</w>` because that would cross a
word boundary on merge (i.e., merging `</w>` with the next word's first
character). We DO allow pairs that END with `</w>` (e.g., ('w','</w>')),
which forms tokens like 'w</w>' and is standard in BPE.
Returns:
A Counter mapping (left_symbol, right_symbol) -> count.
"""
pair_counts = Counter()
for i in range(len(seq) - 1):
left, right = seq[i], seq[i + 1]
if left.endswith(EOW): # This pair would cross a word boundary; skip it.
continue
pair_counts[(left, right)] += 1
return pair_counts
corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
print(pair_freqs)
Counter({('C', 'S'): 2, ('S', '1'): 1, ('1', '8'): 1, ('8', '9'): 1, ('9', '</w>'): 1, ('i', 's'): 1, ('s', '</w>'): 1, ('S', '.'): 1, ('.', '</w>'): 1})
In [18]:
def merge_pair_in_sequence(seq: List[str], pair: Tuple[str, str]) -> List[str]:
"""
Perform a single merge of the given pair across the flat sequence.
Invariant:
- Never merge if the left symbol ends with `</w>`
(prevents crossing word boundaries).
- Scans left-to-right and uses a simple skip mechanic to avoid overlapping merges.
"""
a, b = pair
merged_token = a + b
new_seq: List[str] = []
i = 0
n = len(seq)
while i < n:
if i < n - 1 and seq[i] == a and seq[i + 1] == b and seq[i] != EOW:
new_seq.append(merged_token)
i += 2 # skip the merged pair
else:
new_seq.append(seq[i])
i += 1
return new_seq
corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
pair, freq = pair_freqs.most_common(1)[0]
print("Merging pair:", pair, "with frequency:", freq)
new_seq = merge_pair_in_sequence(seq, pair)
print(new_seq)
Merging pair: ('C', 'S') with frequency: 2
['CS', '1', '8', '9', '</w>', 'i', 's', '</w>', 'CS', '.', '</w>']
In [19]:
from tqdm import tqdm # pip install tqdm
def learn_bpe_merges(corpus: str, num_merges: int = 1000, min_frequency: int = 2) -> Tuple[List[Tuple[str, str]], dict]:
"""
Learn BPE merge rules from the corpus by repeatedly finding the most frequent
adjacent pair and merging it, subject to the boundary rule.
Args:
corpus: Raw text (spaces separate words).
num_merges: Maximum number of merges to learn.
min_frequency: Stop when the most frequent pair occurs fewer than this.
Returns:
merges: A list of (left_symbol, right_symbol) in the order they were learned.
vocab: Final vocabulary mapping token -> id
"""
seq = corpus_to_char_seq_with_eow(corpus)
merges: List[Tuple[str, str]] = []
vocab = build_initial_vocab(corpus)
next_id = max(vocab.values()) + 1
# Wrap the merge loop in a tqdm progress bar
progress = tqdm(range(num_merges), desc="Learning BPE merges", ncols=80)
for step in progress:
pair_counts = count_pair_frequencies(seq)
if not pair_counts:
progress.set_postfix_str("done (no pairs left)")
break
(best_pair, best_count) = pair_counts.most_common(1)[0]
if best_count < min_frequency:
progress.set_postfix_str(f"stopped (min freq < {min_frequency})")
break
# Merge and update structures
seq = merge_pair_in_sequence(seq, best_pair)
merges.append(best_pair)
new_token = best_pair[0] + best_pair[1]
if new_token not in vocab:
vocab[new_token] = next_id
next_id += 1
# Update the tqdm progress bar info
progress.set_postfix_str(f"merge {best_pair} ({best_count})")
progress.close()
return merges, vocab
corpus = "This is the best CS class. This is CS 189."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
print("Learned merges:", merges)
print("Vocabulary:", vocab)
Learning BPE merges: 7%| | 7/100 [00:00<00:00, 1604.73it/s, stopped (min freq
Learned merges: [('i', 's'), ('is', '</w>'), ('T', 'h'), ('Th', 'is</w>'), ('C', 'S'), ('CS', '</w>'), ('.', '</w>')]
Vocabulary: {'3': 0, 'r': 1, '4': 2, 'Z': 3, 'H': 4, '!': 5, ';': 6, 'W': 7, 'w': 8, 'f': 9, 'M': 10, 'C': 11, '/': 12, '0': 13, ',': 14, 'd': 15, 'L': 16, 'y': 17, '~': 18, '[': 19, '.': 20, 'K': 21, 'G': 22, 'U': 23, '{': 24, 'h': 25, 'm': 26, 'Y': 27, '(': 28, '8': 29, '#': 30, 'B': 31, 'A': 32, 'c': 33, 'N': 34, '7': 35, "'": 36, '6': 37, 'T': 38, 'g': 39, '&': 40, 'F': 41, 'R': 42, 'n': 43, '+': 44, 'q': 45, 's': 46, '1': 47, ']': 48, '-': 49, '2': 50, 'P': 51, '9': 52, 'z': 53, 'a': 54, 'o': 55, '"': 56, 'k': 57, 'b': 58, 'X': 59, '%': 60, 'V': 61, '=': 62, '$': 63, '>': 64, 'v': 65, ':': 66, 'I': 67, '@': 68, '`': 69, '</w>': 70, 't': 71, 'J': 72, 'e': 73, 'S': 74, 'D': 75, '}': 76, '*': 77, 'u': 78, '?': 79, 'j': 80, 'Q': 81, 'i': 82, '<': 83, 'x': 84, 'E': 85, '_': 86, 'l': 87, 'p': 88, 'O': 89, '\\': 90, '^': 91, ')': 92, '5': 93, '|': 94, 'is': 95, 'is</w>': 96, 'Th': 97, 'This</w>': 98, 'CS': 99, 'CS</w>': 100, '.</w>': 101}
In [20]:
def bpe_encode(text: str, merges: List[Tuple[str, str]], vocab: Dict[str, int]) -> List[int]:
"""
Encode a string into token IDs:
1) Convert text -> flat char+EOW sequence
2) Apply learned merges in order
3) Map final tokens to IDs via vocab
Note:
- This simple teaching encoder applies merges globally; it assumes the
learned merges were derived from a similar distribution (your corpus).
- For speed, production systems use a 'rank' map and greedy longest-match;
here we stick to the clearest didactic approach.
"""
progress = tqdm(range(len(merges)), desc="Applying BPE merges", ncols=80)
seq = corpus_to_char_seq_with_eow(text)
for a, b in merges:
seq = merge_pair_in_sequence(seq, (a, b))
progress.update(1)
progress.close()
return seq, [vocab[tok] for tok in seq]
corpus = "This is the best CS class. This is CS 189 the best class."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
encoded_seq, token_ids = bpe_encode("CS 189 is the best class.", merges, vocab)
print("Encoded sequence:", encoded_seq)
print("Token IDs:", token_ids)
Learning BPE merges: 19%|▏| 19/100 [00:00<00:00, 1672.65it/s, stopped (min freq Applying BPE merges: 100%|██████████████████| 19/19 [00:00<00:00, 165335.63it/s]
Encoded sequence: ['CS</w>', '1', '8', '9', '</w>', 'is</w>', 'the</w>', 'best</w>', 'class.</w>'] Token IDs: [107, 47, 29, 52, 70, 96, 101, 105, 113]
In [21]:
def bpe_decode(token_ids: List[int], vocab: Dict[str, int]) -> str:
"""
Decode token IDs back to text by inverting the vocab and then
removing EOW markers to re-insert spaces.
Rules:
- Tokens that END with EOW represent end-of-word units.
We strip the trailing `</w>` and insert a space.
- Other tokens are just literal substrings inside a word.
Caveat:
- Because we concatenated strings to form merged tokens, decoding simply
concatenates their surfaces; then we rely on `</w>` to restore spaces.
"""
inv_vocab = {i: t.replace(EOW, " ") for t, i in vocab.items()}
out_words: List[str] = []
buf = [inv_vocab[tid] for tid in token_ids]
return "".join(buf).strip()
In [22]:
decoded_text = bpe_decode(token_ids, vocab)
print(f"Decoded text: \"{decoded_text}\"")
Decoded text: "CS 189 is the best class."
Implementing the Decoder Transformer for Generative Pre-training
Getting the Data
In [23]:
import os
if not os.path.exists("shakespeare.txt"):
print('downloading corpus...')
import requests
url = "https://www.gutenberg.org/cache/epub/100/pg100.txt"
response = requests.get(url)
shakespeare_corpus = response.text
with open("shakespeare.txt", "w") as f:
f.write(shakespeare_corpus)
else:
print('loading cached file...')
with open("shakespeare.txt", "r") as f:
shakespeare_corpus = f.read()
print(f"Corpus length: {len(shakespeare_corpus)} characters")
print(shakespeare_corpus[:1000])
downloading corpus...
Corpus length: 5575062 characters
The Project Gutenberg eBook of The Complete Works of William Shakespeare
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
of the Project Gutenberg License included with this ebook or online
at www.gutenberg.org. If you are not located in the United States,
you will have to check the laws of the country where you are located
before using this eBook.
Title: The Complete Works of William Shakespeare
Author: William Shakespeare
Release date: January 1, 1994 [eBook #100]
Most recently updated: August 24, 2025
Language: English
*** START OF THE PROJECT GUTENBERG EBOOK THE COMPLETE WORKS OF WILLIAM SHAKESPEARE ***
The Complete Works of William Shakespeare
by William Shakespeare
Contents
THE SONNETS
ALL’S WELL THAT ENDS WELL
Byte Pair Encoding
In [24]:
# if not os.path.exists("bpe_state.pkl"):
# print('learning BPE merges on Shakespeare corpus...')
# merges, vocab = learn_bpe_merges(shakespeare_corpus,
# num_merges=200,
# min_frequency=2)
# with open("bpe_state.pkl", "wb") as f:
# pickle.dump((merges, vocab), f)
# else:
# print("loading cached BPE state...")
# with open("bpe_state.pkl", "rb") as f:
# merges, vocab = pickle.load(f)
# print("Learned merges:", merges)
# print("Vocabulary:", vocab)
# vocab_size = len(vocab)
# print("Vocabulary size:", vocab_size)
In [25]:
# if not os.path.exists("encoded_text_ids.pkl"):
# print('encoding Shakespeare corpus...')
# encoded_seq, token_ids = encode(shakespeare_corpus, merges, vocab)
# with open("encoded_text_ids.pkl", "wb") as f:
# pickle.dump((encoded_seq,token_ids), f)
# else:
# print("loading cached encoded text...")
# with open("encoded_text_ids.pkl", "rb") as f:
# encoded_seq, token_ids = pickle.load(f)
# print("Encoded sequence length:", len(encoded_seq))
# corpus_tokens = torch.tensor(token_ids, dtype=torch.long, device=device)
# def encode(text: str) -> torch.Tensor:
# _, token_ids = bpe_encode(text, merges, vocab)
# return torch.tensor(token_ids, dtype=torch.long, device=device)
# def decode(token_ids: torch.Tensor) -> str:
# return bpe_decode(token_ids.tolist(), vocab)
# tok = encode("To be, or not to be, that is the question.")
# decode(tok)
Word Encoding
In [26]:
import re
from collections import Counter
def split_words(corpus: str) -> List[str]:
"""
Break the corpus into words using a regex that matches word characters.
"""
pattern = r'\b\w+\b'
corpus = re.sub(r'[._,!?;"\'`()\[\]{}<>]', '', corpus.lower())
return re.findall(pattern, corpus.lower())
words = split_words(shakespeare_corpus)
# counter = Counter(words)
# vocab_set = {tok for tok, cnt in counter.items() if cnt > 1}
vocab_set = set(words)
vocab = {word: i for i, word in enumerate(sorted(vocab_set), start = 1)}
vocab["<unknown>"] = 1
inv_vocab = {i: word for word, i in vocab.items()}
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)
def encode(text: str):
"""
Encode a string into token IDs using the provided vocabulary.
Unknown words are mapped to the ID for <unknown>.
"""
words = split_words(text)
return torch.tensor(
[vocab.get(word, 1) for word in words],
dtype=torch.long, device=device)
def decode(tokens: torch.Tensor) -> str:
"""
Decode token IDs back to text by inverting the vocabulary.
"""
words = [inv_vocab.get(t.item(), "<error>") for t in tokens]
return " ".join(words)
corpus_tokens = encode(shakespeare_corpus)
print("Length of token IDs:", len(corpus_tokens))
decode(encode("to be or not to be that is the question tokenizer"))
Vocabulary size: 25119 Length of token IDs: 992077
Out[26]:
'to be or not to be that is the question <unknown>'
Data Preparation
In [27]:
N = corpus_tokens.shape[0]
seq_length = 64
split_ratio = 0.90
seed = 189
x = corpus_tokens[:(N - (N % seq_length) + 1)]
y = x[1:].reshape(-1, seq_length)
x = x[:-1].reshape(-1, seq_length)
from torch.utils.data import random_split, TensorDataset
dataset = TensorDataset(x, y)
generator = torch.Generator().manual_seed(seed)
training_data, validation_data = random_split(
dataset, [split_ratio, 1 - split_ratio],
generator=generator)
print("training contexts", len(training_data))
print("validation contexts", len(validation_data))
training_data[0]
training contexts 13951 validation contexts 1550
Out[27]:
(tensor([14864, 21972, 19505, 20758, 1169, 19278, 10750, 3106, 8358, 3024,
7282, 14274, 9577, 22082, 24681, 1169, 12777, 11447, 12855, 11907,
21880, 21972, 19505, 8635, 23586, 21876, 21470, 15025, 449, 14747,
2166, 21869, 14556, 14572, 7882, 8134, 14794, 14572, 7882, 19499,
9993, 9970, 24472, 10466, 10524, 15491, 22710, 22178, 15092, 10588,
13003, 4518, 22178, 20425, 22178, 10750, 1169, 11250, 19549, 2112,
1538, 9743, 1538, 19549], device='cuda:0'),
tensor([21972, 19505, 20758, 1169, 19278, 10750, 3106, 8358, 3024, 7282,
14274, 9577, 22082, 24681, 1169, 12777, 11447, 12855, 11907, 21880,
21972, 19505, 8635, 23586, 21876, 21470, 15025, 449, 14747, 2166,
21869, 14556, 14572, 7882, 8134, 14794, 14572, 7882, 19499, 9993,
9970, 24472, 10466, 10524, 15491, 22710, 22178, 15092, 10588, 13003,
4518, 22178, 20425, 22178, 10750, 1169, 11250, 19549, 2112, 1538,
9743, 1538, 19549, 18805], device='cuda:0'))
In [28]:
print(decode(training_data[0][0]))
print(decode(training_data[0][1]))
now thou shalt stay and see her bright eyes break each morning gainst thy window and let in life into thee thou shalt feed upon the sweetness of a noble beauty that nature ne er exceeded nor ne er shall good gods what happiness has palamon twenty to one he ll come to speak to her and if she be as gentle as she thou shalt stay and see her bright eyes break each morning gainst thy window and let in life into thee thou shalt feed upon the sweetness of a noble beauty that nature ne er exceeded nor ne er shall good gods what happiness has palamon twenty to one he ll come to speak to her and if she be as gentle as she s
Attention
In [29]:
def scaled_dot_product_attention(Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
mask=None):
"""
Q: matrix of shape (B, N, d_k)
K: matrix of shape (B, N, d_k)
V: matrix of shape (B, N, d_v)
mask: boolean matrix of shape (1, N, N). Values where mask is True will be INCLUDED
"""
(B, N, d_k) = Q.shape
(_, _, d_v) = V.shape
K_T = K.transpose(-2, -1) # (B, d_k, N)
dot_product = (Q @ K_T) / (d_k ** 0.5) # (B, N, N)
if mask is not None:
dot_product = dot_product.masked_fill(mask.logical_not(), float('-inf'))
attention = F.softmax(dot_product, dim=-1) # (B, N, N)
return attention @ V # (B, N, N) * (B, N, d_v) = (B, N, d_v)
In [30]:
_mask_cache = {}
def get_mask_with_cache(N, device):
"""
Returns a lower triangular mask of shape (1, N, N) to be used for masked attention.
"""
if N not in _mask_cache:
_mask_cache[N] = torch.ones(
(N, N), dtype=torch.bool,
device=device).tril().unsqueeze(0)
return _mask_cache[N] # (1, N, N)
class MaskedAttentionHead(nn.Module):
def __init__(self, d_model=512, d_v=512, d_k=64):
super().__init__()
self.d_model = d_model
self.d_k = d_k
self.d_v = d_v
self.W_k = nn.Linear(self.d_model, self.d_k, bias = False)
self.W_q = nn.Linear(self.d_model, self.d_k, bias = False)
self.W_v = nn.Linear(self.d_model, self.d_v, bias = False)
def forward(self, x):
"""
x is the input to use for the queries, keys, and values
encoder_output is the output from the encoder (used for cross-attention)
mask is the mask to use for the attention
"""
(B, N, _) = x.shape
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
mask = get_mask_with_cache(N, device=x.device)
values = scaled_dot_product_attention(Q, K, V, mask=mask)
# ## more efficient implementation:
# ## Need to unsqueeze the head dimension for F.scaled_dot_product_attention
# values = F.scaled_dot_product_attention(
# query=Q, key=K, value=V,
# attn_mask=mask)
return values
In [31]:
class MaskedMultiHeadAttention(nn.Module):
def __init__(self, num_heads=8, d_model=512, d_k=64):
super().__init__()
self.num_heads = num_heads
self.d_model = d_model
self.d_k = d_k
self.d_v = d_model // num_heads
self.attention_heads = nn.ModuleList(
[
MaskedAttentionHead(d_model=self.d_model, d_v = self.d_v, d_k=self.d_k)
for _ in range(self.num_heads)
]
)
# Projection
self.W_out = nn.Linear(self.num_heads * self.d_v, self.d_model)
def forward(self, x):
(B, N, _) = x.shape
assert(x.shape == (B, N, self.d_model))
head_outputs = [head(x) for head in self.attention_heads]
for head_out in head_outputs:
assert(head_out.shape == (B, N, self.d_v))
concatenated = torch.cat(head_outputs, dim=-1)
assert(concatenated.shape == (B, N, self.num_heads * self.d_v))
out = self.W_out(concatenated)
return out
Testing the Layer with Masked Multi-Head Attention:
In [32]:
emb = nn.Embedding(vocab_size, 512).to(device)
layer = MaskedMultiHeadAttention(7, 512, 64).to(device)
In [33]:
batch_size = 7
x, y = training_data[:batch_size]
emb(x).shape
Out[33]:
torch.Size([7, 64, 512])
In [34]:
layer(emb(x)).shape
Out[34]:
torch.Size([7, 64, 512])
Decoder Architecture
In [35]:
class DecoderBlock(nn.Module):
def __init__(self, d_model=512, num_heads=8, d_k=64, dropout=0.1):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_ffn = 4 * self.d_model
self.d_k = d_k
self.dropout = nn.Dropout(dropout)
self.mh_attention = MaskedMultiHeadAttention(
num_heads=self.num_heads,
d_model=self.d_model,
d_k=self.d_k)
self.ffn = nn.Sequential(
nn.Linear(self.d_model, self.d_ffn),
nn.ReLU(),
self.dropout,
nn.Linear(self.d_ffn, self.d_model),
)
self.layernorm1 = nn.LayerNorm(self.d_model)
self.layernorm2 = nn.LayerNorm(self.d_model)
def forward(self, x):
mh = self.mh_attention(self.layernorm1(x)) # Prenorm
mh = self.dropout(mh)
x = x + mh
ffn = self.ffn(self.layernorm2(x)) # Prenorm
ffn = self.dropout(ffn)
return x
In [36]:
block = DecoderBlock(d_model=512, num_heads=8, d_k=64, dropout=0.1).to(device)
block(emb(x)).shape
Out[36]:
torch.Size([7, 64, 512])
In [37]:
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int = 512, max_len: int = 1024, L: float = 10000.0):
"""
Sinusoidal positional encoding as in 'Attention is All You Need'.
"""
super().__init__()
pos = torch.zeros(max_len, d_model, dtype=torch.float32)
positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
div_terms = L ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
quotient = positions / div_terms
pos[:, 0::2] = torch.sin(quotient) # even indices
pos[:, 1::2] = torch.cos(quotient) # odd indices
# Register as non-parameter buffer so it moves with the module
self.register_buffer("pos", pos)
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.size(1)
# pos: (1, seq_len, d_model) broadcasts along batch dimension
return x + self.pos[:seq_len].unsqueeze(0)
In [38]:
class TransformerDecoderOnly(nn.Module):
def __init__(self,
max_length=1024,
vocab_size=6000,
d_model=512,
d_k=64,
num_layers=6,
num_heads=8,
dropout=0.1):
super().__init__()
self.vocab_size = vocab_size
self.d_k = d_k
self.d_model = d_model
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.layers = nn.Sequential()
self.layers.append(
nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
)
self.layers.append(
PositionalEncoding(d_model=d_model, max_len=max_length, L=10000)
)
for _ in range(num_layers):
self.layers.append(
DecoderBlock(d_model=d_model, num_heads=num_heads,
d_k=self.d_k, dropout=self.dropout)
)
self.layers.append(
nn.Linear(in_features=d_model, out_features=vocab_size)
)
def num_parameters(self):
return sum(p.numel() for p in self.parameters())
def forward(self, x):
return self.layers(x)
def generate(self, x, max_new_tokens):
"""
x: (B, N) tensor of input token IDs
max_new_tokens: number of tokens to generate
"""
self.eval()
with torch.no_grad():
B, N = x.shape
for _ in range(max_new_tokens):
x = x[:, -seq_length:] # crop to last seq_length tokens
logits = self.forward(x) # (B, N, vocab_size)
next_token_logits = logits[:, -1, :] # (B, vocab_size)
next_token_probs = F.softmax(next_token_logits, dim=-1) # (B, vocab_size)
next_tokens = torch.multinomial(next_token_probs, num_samples=1) # (B, 1)
x = torch.cat([x, next_tokens], dim=1) # (B, N+1)
return x
In [39]:
model = TransformerDecoderOnly(
max_length=seq_length,
vocab_size=vocab_size,
d_model=256, d_k=16, num_layers=6, num_heads=8,
dropout=0.1).to(device)
print(model)
print("Number of parameters:", model.num_parameters()/1e6, "million")
TransformerDecoderOnly(
(layers): Sequential(
(0): Embedding(25119, 256)
(1): PositionalEncoding()
(2): DecoderBlock(
(dropout): Dropout(p=0.1, inplace=False)
(mh_attention): MaskedMultiHeadAttention(
(attention_heads): ModuleList(
(0-7): 8 x MaskedAttentionHead(
(W_k): Linear(in_features=256, out_features=16, bias=False)
(W_q): Linear(in_features=256, out_features=16, bias=False)
(W_v): Linear(in_features=256, out_features=32, bias=False)
)
)
(W_out): Linear(in_features=256, out_features=256, bias=True)
)
(ffn): Sequential(
(0): Linear(in_features=256, out_features=1024, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=1024, out_features=256, bias=True)
)
(layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(3): DecoderBlock(
(dropout): Dropout(p=0.1, inplace=False)
(mh_attention): MaskedMultiHeadAttention(
(attention_heads): ModuleList(
(0-7): 8 x MaskedAttentionHead(
(W_k): Linear(in_features=256, out_features=16, bias=False)
(W_q): Linear(in_features=256, out_features=16, bias=False)
(W_v): Linear(in_features=256, out_features=32, bias=False)
)
)
(W_out): Linear(in_features=256, out_features=256, bias=True)
)
(ffn): Sequential(
(0): Linear(in_features=256, out_features=1024, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=1024, out_features=256, bias=True)
)
(layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(4): DecoderBlock(
(dropout): Dropout(p=0.1, inplace=False)
(mh_attention): MaskedMultiHeadAttention(
(attention_heads): ModuleList(
(0-7): 8 x MaskedAttentionHead(
(W_k): Linear(in_features=256, out_features=16, bias=False)
(W_q): Linear(in_features=256, out_features=16, bias=False)
(W_v): Linear(in_features=256, out_features=32, bias=False)
)
)
(W_out): Linear(in_features=256, out_features=256, bias=True)
)
(ffn): Sequential(
(0): Linear(in_features=256, out_features=1024, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=1024, out_features=256, bias=True)
)
(layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(5): DecoderBlock(
(dropout): Dropout(p=0.1, inplace=False)
(mh_attention): MaskedMultiHeadAttention(
(attention_heads): ModuleList(
(0-7): 8 x MaskedAttentionHead(
(W_k): Linear(in_features=256, out_features=16, bias=False)
(W_q): Linear(in_features=256, out_features=16, bias=False)
(W_v): Linear(in_features=256, out_features=32, bias=False)
)
)
(W_out): Linear(in_features=256, out_features=256, bias=True)
)
(ffn): Sequential(
(0): Linear(in_features=256, out_features=1024, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=1024, out_features=256, bias=True)
)
(layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(6): DecoderBlock(
(dropout): Dropout(p=0.1, inplace=False)
(mh_attention): MaskedMultiHeadAttention(
(attention_heads): ModuleList(
(0-7): 8 x MaskedAttentionHead(
(W_k): Linear(in_features=256, out_features=16, bias=False)
(W_q): Linear(in_features=256, out_features=16, bias=False)
(W_v): Linear(in_features=256, out_features=32, bias=False)
)
)
(W_out): Linear(in_features=256, out_features=256, bias=True)
)
(ffn): Sequential(
(0): Linear(in_features=256, out_features=1024, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=1024, out_features=256, bias=True)
)
(layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(7): DecoderBlock(
(dropout): Dropout(p=0.1, inplace=False)
(mh_attention): MaskedMultiHeadAttention(
(attention_heads): ModuleList(
(0-7): 8 x MaskedAttentionHead(
(W_k): Linear(in_features=256, out_features=16, bias=False)
(W_q): Linear(in_features=256, out_features=16, bias=False)
(W_v): Linear(in_features=256, out_features=32, bias=False)
)
)
(W_out): Linear(in_features=256, out_features=256, bias=True)
)
(ffn): Sequential(
(0): Linear(in_features=256, out_features=1024, bias=True)
(1): ReLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=1024, out_features=256, bias=True)
)
(layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(8): Linear(in_features=256, out_features=25119, bias=True)
)
)
Number of parameters: 17.226783 million
In [40]:
decode(model.generate(encode("to be or not to be").unsqueeze(0), max_new_tokens=20)[0])
Out[40]:
'to be or not to be unparalleled yorkshire revenging types worshipp descension statesman addition reedy festinate camels unlink attain captainship agitation granted righteous sparrows grace appendix'
Training Loop
In [41]:
def batch_cross_entropy(pred, y):
# flatten the batch into a single dimension and
# compute cross-entropy
return F.cross_entropy(pred.view(-1, vocab_size), y.view(-1))
In [42]:
batch_cross_entropy(model(x), y)
Out[42]:
tensor(10.3724, device='cuda:0', grad_fn=<NllLossBackward0>)
In [43]:
def minibatch_gd(model, loss_fn,
training_data,
batch_size,
nsteps,
learning_rate,
visualizer=None,
weight_decay=1e-4):
generator = torch.Generator()
generator.manual_seed(189)
loader = DataLoader(training_data,
batch_size=batch_size,
shuffle=True, # shuffles each epoch
generator=generator)
# Define the optimizer (this is the update rule)
# Alternatively, you can use Adam optimizer
optimizer = torch.optim.AdamW(model.parameters(), learning_rate, weight_decay=weight_decay)
model.train() # set model to training mode (important for dropout/batchnorm)
step = 0
# Loop through the steps
iter_loader = iter(loader)
for step in tqdm(range(nsteps)):
# Get the next batch of data
try:
x, t = next(iter_loader)
except StopIteration:
iter_loader = iter(loader)
x, t = next(iter_loader)
# Zero the gradients to start the next step
optimizer.zero_grad()
# Compute prediction and loss
pred = model(x)
loss = loss_fn(pred, t)
tr_loss = loss.item()
# Backpropagation (compute the gradient)
loss.backward()
# Update the parameters using the optimizer's update rule
optimizer.step()
# Visualize the model (if a visualizer function is provided)
if visualizer is not None:
model.eval() # disable dropout/batchnorm
with torch.no_grad():
visualizer(step, model, loss_fn, tr_loss)
model.train()
In [44]:
class LossVisualizer:
def __init__(self, loss_fig, validation_data):
self.loss_fig = loss_fig
self.val_loader = DataLoader(validation_data,
batch_size=32,
shuffle=False)
self.epochs = []
self.losses_val = []
self.losses_tr = []
def reset(self):
self.epochs = []
self.losses_val = []
self.losses_tr = []
with self.loss_fig.batch_update():
self.loss_fig.data[0].x = []
self.loss_fig.data[0].y = []
self.loss_fig.data[1].x = []
self.loss_fig.data[1].y = []
def __call__(self, epoch, model, loss_fn, loss_tr):
model.eval()
with torch.no_grad():
losses = []
for x_val, t_val in self.val_loader:
loss_val = loss_fn(model(x_val), t_val).item()
losses.append(loss_val)
loss_val = np.mean(losses)
self.epochs.append(epoch)
self.losses_val.append(loss_val)
self.losses_tr.append(loss_tr)
print("training loss:", loss_tr, "validation loss:", loss_val)
# Visualization Code
with self.loss_fig.batch_update():
self.loss_fig.data[0].x = self.epochs
self.loss_fig.data[0].y = self.losses_val
self.loss_fig.data[1].x = self.epochs
self.loss_fig.data[1].y = self.losses_tr
model.train()
In [45]:
loss_fig = go.FigureWidget()
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Val. Loss'))
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Train. Loss'))
visualizer = LossVisualizer(loss_fig, validation_data)
display(loss_fig)
FigureWidget({
'data': [{'mode': 'lines',
'name': 'Val. Loss',
'type': 'scatter',
'uid': '5ec38f84-7c4d-435f-8fb7-56afe996d547',
'x': [0],
'y': [0]},
{'mode': 'lines',
'name': 'Train. Loss',
'type': 'scatter',
'uid': '1f185c4c-8c4e-49d1-a8b2-e305ae279857',
'x': [0],
'y': [0]}],
'layout': {'template': '...'}
})
In [48]:
visualizer.reset()
model = TransformerDecoderOnly(
max_length=seq_length,
vocab_size=vocab_size,
d_model=1024, d_k=32, num_layers=2, num_heads=8,
dropout=0.0).to(device)
#model = torch.compile(model)
# model = TransformerDecoderOnly(
# max_length=seq_length,
# vocab_size=vocab_size,
# d_model=128, d_k=32, num_layers=8, num_heads=8,
# dropout=0.1).to(device)
minibatch_gd(
model=model,
loss_fn=batch_cross_entropy,
training_data=training_data,
batch_size=128,
nsteps=200,
learning_rate=3e-4,
weight_decay=1e-7,
visualizer=visualizer
)
0%| | 1/200 [00:00<02:57, 1.12it/s]
training loss: 10.361719131469727 validation loss: 9.929148265293666
1%| | 2/200 [00:01<02:53, 1.14it/s]
training loss: 9.939773559570312 validation loss: 9.46023094410799
2%|▏ | 3/200 [00:02<02:50, 1.15it/s]
training loss: 9.456949234008789 validation loss: 8.856558429951571
2%|▏ | 4/200 [00:03<02:49, 1.15it/s]
training loss: 8.850252151489258 validation loss: 8.335775414291692
2%|▎ | 5/200 [00:04<02:48, 1.16it/s]
training loss: 8.29616641998291 validation loss: 7.993326848866988
3%|▎ | 6/200 [00:05<02:47, 1.16it/s]
training loss: 7.9955973625183105 validation loss: 7.675127457599251
4%|▎ | 7/200 [00:06<02:46, 1.16it/s]
training loss: 7.692054271697998 validation loss: 7.490796225411551
4%|▍ | 8/200 [00:06<02:45, 1.16it/s]
training loss: 7.510016441345215 validation loss: 7.413140131502735
4%|▍ | 9/200 [00:07<02:44, 1.16it/s]
training loss: 7.368573188781738 validation loss: 7.367588490855937
5%|▌ | 10/200 [00:08<02:43, 1.16it/s]
training loss: 7.339788436889648 validation loss: 7.327013706674381
6%|▌ | 11/200 [00:09<02:42, 1.16it/s]
training loss: 7.325085639953613 validation loss: 7.282361157086431
6%|▌ | 12/200 [00:10<02:41, 1.16it/s]
training loss: 7.287446022033691 validation loss: 7.238322219070123
6%|▋ | 13/200 [00:11<02:40, 1.16it/s]
training loss: 7.1515913009643555 validation loss: 7.2038816335249924
7%|▋ | 14/200 [00:12<02:39, 1.17it/s]
training loss: 7.1626057624816895 validation loss: 7.165999344417027
8%|▊ | 15/200 [00:12<02:38, 1.16it/s]
training loss: 7.131077766418457 validation loss: 7.128883332622294
8%|▊ | 16/200 [00:13<02:37, 1.16it/s]
training loss: 7.13786506652832 validation loss: 7.104381113636251
8%|▊ | 17/200 [00:14<02:36, 1.17it/s]
training loss: 7.147603511810303 validation loss: 7.085897280245411
9%|▉ | 18/200 [00:15<02:36, 1.17it/s]
training loss: 7.146273612976074 validation loss: 7.067882323751644
10%|▉ | 19/200 [00:16<02:35, 1.17it/s]
training loss: 7.116576194763184 validation loss: 7.046328213750099
10%|█ | 20/200 [00:17<02:34, 1.16it/s]
training loss: 7.081804275512695 validation loss: 7.0230551447187155
10%|█ | 21/200 [00:18<02:34, 1.16it/s]
training loss: 7.080620765686035 validation loss: 7.001611417653609
11%|█ | 22/200 [00:18<02:33, 1.16it/s]
training loss: 7.038182735443115 validation loss: 6.987391355086346
12%|█▏ | 23/200 [00:19<02:31, 1.16it/s]
training loss: 6.958187580108643 validation loss: 6.9772258875321365
12%|█▏ | 24/200 [00:20<02:31, 1.17it/s]
training loss: 7.008291244506836 validation loss: 6.961357710312824
12%|█▎ | 25/200 [00:21<02:30, 1.16it/s]
training loss: 6.9798808097839355 validation loss: 6.942648644350013
13%|█▎ | 26/200 [00:22<02:29, 1.16it/s]
training loss: 6.94419002532959 validation loss: 6.927445985832993
14%|█▎ | 27/200 [00:23<02:29, 1.16it/s]
training loss: 6.9823102951049805 validation loss: 6.912888059810716
14%|█▍ | 28/200 [00:24<02:28, 1.16it/s]
training loss: 6.891851902008057 validation loss: 6.893563085672807
14%|█▍ | 29/200 [00:24<02:27, 1.16it/s]
training loss: 6.838163375854492 validation loss: 6.876966379126724
15%|█▌ | 30/200 [00:25<02:26, 1.16it/s]
training loss: 6.8766961097717285 validation loss: 6.8654066299905585
16%|█▌ | 31/200 [00:26<02:25, 1.16it/s]
training loss: 6.861605167388916 validation loss: 6.852149145943778
16%|█▌ | 32/200 [00:27<02:24, 1.16it/s]
training loss: 6.7996416091918945 validation loss: 6.837655349653595
16%|█▋ | 33/200 [00:28<02:23, 1.16it/s]
training loss: 6.846062183380127 validation loss: 6.826757275328344
17%|█▋ | 34/200 [00:29<02:23, 1.16it/s]
training loss: 6.7678632736206055 validation loss: 6.815719273625588
18%|█▊ | 35/200 [00:30<02:22, 1.16it/s]
training loss: 6.793147087097168 validation loss: 6.804327663110227
18%|█▊ | 36/200 [00:31<02:21, 1.16it/s]
training loss: 6.738554000854492 validation loss: 6.795834551052171
18%|█▊ | 37/200 [00:31<02:20, 1.16it/s]
training loss: 6.736570358276367 validation loss: 6.788319529319296
19%|█▉ | 38/200 [00:32<02:19, 1.16it/s]
training loss: 6.753721237182617 validation loss: 6.776969753966039
20%|█▉ | 39/200 [00:33<02:18, 1.16it/s]
training loss: 6.767504692077637 validation loss: 6.764109650436713
20%|██ | 40/200 [00:34<02:18, 1.15it/s]
training loss: 6.893495082855225 validation loss: 6.758010192793243
20%|██ | 41/200 [00:35<02:17, 1.15it/s]
training loss: 6.77290678024292 validation loss: 6.749037012761953
21%|██ | 42/200 [00:36<02:17, 1.15it/s]
training loss: 6.768361568450928 validation loss: 6.737954120246732
22%|██▏ | 43/200 [00:37<02:16, 1.15it/s]
training loss: 6.690580368041992 validation loss: 6.730391210439254
22%|██▏ | 44/200 [00:37<02:15, 1.15it/s]
training loss: 6.812774658203125 validation loss: 6.720891534065713
22%|██▎ | 45/200 [00:38<02:15, 1.15it/s]
training loss: 6.717848777770996 validation loss: 6.709689646351094
23%|██▎ | 46/200 [00:39<02:13, 1.15it/s]
training loss: 6.644081115722656 validation loss: 6.701014956649469
24%|██▎ | 47/200 [00:40<02:13, 1.15it/s]
training loss: 6.717618465423584 validation loss: 6.692500007395842
24%|██▍ | 48/200 [00:41<02:12, 1.14it/s]
training loss: 6.673792839050293 validation loss: 6.6824401933319715
24%|██▍ | 49/200 [00:42<02:12, 1.14it/s]
training loss: 6.780778408050537 validation loss: 6.672941314930818
25%|██▌ | 50/200 [00:43<02:11, 1.14it/s]
training loss: 6.703185081481934 validation loss: 6.665250885243318
26%|██▌ | 51/200 [00:44<02:11, 1.14it/s]
training loss: 6.724969863891602 validation loss: 6.658390424689468
26%|██▌ | 52/200 [00:44<02:10, 1.14it/s]
training loss: 6.611624717712402 validation loss: 6.650597085758132
26%|██▋ | 53/200 [00:45<02:09, 1.13it/s]
training loss: 6.6266632080078125 validation loss: 6.642186018885399
27%|██▋ | 54/200 [00:46<02:08, 1.13it/s]
training loss: 6.738619327545166 validation loss: 6.635702522433534
28%|██▊ | 55/200 [00:47<02:07, 1.13it/s]
training loss: 6.580887794494629 validation loss: 6.6281601360866
28%|██▊ | 56/200 [00:48<02:07, 1.13it/s]
training loss: 6.584505081176758 validation loss: 6.619589659632469
28%|██▊ | 57/200 [00:49<02:06, 1.13it/s]
training loss: 6.646501541137695 validation loss: 6.612708743737668
29%|██▉ | 58/200 [00:50<02:05, 1.13it/s]
training loss: 6.598003387451172 validation loss: 6.60520042691912
30%|██▉ | 59/200 [00:51<02:03, 1.14it/s]
training loss: 6.551400184631348 validation loss: 6.597341722371627
30%|███ | 60/200 [00:52<02:02, 1.14it/s]
training loss: 6.596179962158203 validation loss: 6.59074356118027
30%|███ | 61/200 [00:52<02:02, 1.13it/s]
training loss: 6.584847450256348 validation loss: 6.583917608066481
31%|███ | 62/200 [00:53<02:01, 1.13it/s]
training loss: 6.4984588623046875 validation loss: 6.576665099786252
32%|███▏ | 63/200 [00:54<02:00, 1.14it/s]
training loss: 6.626906871795654 validation loss: 6.569135636699443
32%|███▏ | 64/200 [00:55<01:59, 1.14it/s]
training loss: 6.555837154388428 validation loss: 6.562342857827946
32%|███▎ | 65/200 [00:56<01:58, 1.14it/s]
training loss: 6.533874034881592 validation loss: 6.5565643602487995
33%|███▎ | 66/200 [00:57<01:56, 1.15it/s]
training loss: 6.563194751739502 validation loss: 6.550891428577657
34%|███▎ | 67/200 [00:58<01:56, 1.14it/s]
training loss: 6.503715515136719 validation loss: 6.543573175157819
34%|███▍ | 68/200 [00:59<01:55, 1.15it/s]
training loss: 6.514548301696777 validation loss: 6.535555382164157
34%|███▍ | 69/200 [00:59<01:54, 1.15it/s]
training loss: 6.554697036743164 validation loss: 6.5284194459720535
35%|███▌ | 70/200 [01:00<01:53, 1.15it/s]
training loss: 6.534619331359863 validation loss: 6.5230604385843085
36%|███▌ | 71/200 [01:01<01:52, 1.15it/s]
training loss: 6.466627597808838 validation loss: 6.517006436172797
36%|███▌ | 72/200 [01:02<01:51, 1.15it/s]
training loss: 6.492819786071777 validation loss: 6.510690523653614
36%|███▋ | 73/200 [01:03<01:50, 1.15it/s]
training loss: 6.507439136505127 validation loss: 6.503489834921701
37%|███▋ | 74/200 [01:04<01:49, 1.15it/s]
training loss: 6.539058208465576 validation loss: 6.497100265658632
38%|███▊ | 75/200 [01:05<01:48, 1.15it/s]
training loss: 6.465414047241211 validation loss: 6.490775984160754
38%|███▊ | 76/200 [01:05<01:47, 1.15it/s]
training loss: 6.4486589431762695 validation loss: 6.484331160175557
38%|███▊ | 77/200 [01:06<01:46, 1.15it/s]
training loss: 6.4612884521484375 validation loss: 6.478635593336456
39%|███▉ | 78/200 [01:07<01:45, 1.16it/s]
training loss: 6.446281433105469 validation loss: 6.473315920148577
40%|███▉ | 79/200 [01:08<01:44, 1.16it/s]
training loss: 6.434588432312012 validation loss: 6.467980958977524
40%|████ | 80/200 [01:09<01:43, 1.16it/s]
training loss: 6.499963760375977 validation loss: 6.463095120021275
40%|████ | 81/200 [01:10<01:42, 1.16it/s]
training loss: 6.494248867034912 validation loss: 6.459504156696553
41%|████ | 82/200 [01:11<01:42, 1.15it/s]
training loss: 6.364786148071289 validation loss: 6.454968384334019
42%|████▏ | 83/200 [01:12<01:41, 1.15it/s]
training loss: 6.4450578689575195 validation loss: 6.4510873191210685
42%|████▏ | 84/200 [01:12<01:41, 1.15it/s]
training loss: 6.448891639709473 validation loss: 6.445479120526995
42%|████▎ | 85/200 [01:13<01:40, 1.14it/s]
training loss: 6.4035844802856445 validation loss: 6.441236778181427
43%|████▎ | 86/200 [01:14<01:39, 1.14it/s]
training loss: 6.477339267730713 validation loss: 6.438299616988824
44%|████▎ | 87/200 [01:15<01:39, 1.14it/s]
training loss: 6.415297508239746 validation loss: 6.435391718027543
44%|████▍ | 88/200 [01:16<01:38, 1.14it/s]
training loss: 6.430819988250732 validation loss: 6.431323148766342
44%|████▍ | 89/200 [01:17<01:37, 1.14it/s]
training loss: 6.412506103515625 validation loss: 6.426331520080566
45%|████▌ | 90/200 [01:18<01:36, 1.14it/s]
training loss: 6.434972763061523 validation loss: 6.420756680624826
46%|████▌ | 91/200 [01:19<01:34, 1.15it/s]
training loss: 6.411623001098633 validation loss: 6.416116889642209
46%|████▌ | 92/200 [01:19<01:33, 1.15it/s]
training loss: 6.3786115646362305 validation loss: 6.411592911700813
46%|████▋ | 93/200 [01:20<01:32, 1.16it/s]
training loss: 6.417270660400391 validation loss: 6.407266937956518
47%|████▋ | 94/200 [01:21<01:31, 1.15it/s]
training loss: 6.399571895599365 validation loss: 6.403549680904466
48%|████▊ | 95/200 [01:22<01:30, 1.16it/s]
training loss: 6.377954959869385 validation loss: 6.399578201527498
48%|████▊ | 96/200 [01:23<01:29, 1.16it/s]
training loss: 6.492927551269531 validation loss: 6.395020825522287
48%|████▊ | 97/200 [01:24<01:28, 1.16it/s]
training loss: 6.409252166748047 validation loss: 6.390357270532725
49%|████▉ | 98/200 [01:25<01:27, 1.16it/s]
training loss: 6.410403728485107 validation loss: 6.385524010171696
50%|████▉ | 99/200 [01:25<01:27, 1.15it/s]
training loss: 6.511599540710449 validation loss: 6.381968644200539
50%|█████ | 100/200 [01:26<01:26, 1.15it/s]
training loss: 6.407318592071533 validation loss: 6.378166646373515
50%|█████ | 101/200 [01:27<01:25, 1.15it/s]
training loss: 6.364468097686768 validation loss: 6.374616914865922
51%|█████ | 102/200 [01:28<01:24, 1.16it/s]
training loss: 6.477817535400391 validation loss: 6.371596628305864
52%|█████▏ | 103/200 [01:29<01:23, 1.16it/s]
training loss: 6.411909103393555 validation loss: 6.368403979710171
52%|█████▏ | 104/200 [01:30<01:23, 1.16it/s]
training loss: 6.30910062789917 validation loss: 6.364959317810682
52%|█████▎ | 105/200 [01:31<01:22, 1.15it/s]
training loss: 6.369457721710205 validation loss: 6.3607834796516265
53%|█████▎ | 106/200 [01:32<01:21, 1.15it/s]
training loss: 6.462528228759766 validation loss: 6.355573245457241
54%|█████▎ | 107/200 [01:32<01:20, 1.16it/s]
training loss: 6.37962532043457 validation loss: 6.351249480734066
54%|█████▍ | 108/200 [01:33<01:19, 1.16it/s]
training loss: 6.378788948059082 validation loss: 6.347711748006392
55%|█████▍ | 109/200 [01:34<01:19, 1.15it/s]
training loss: 6.42195987701416 validation loss: 6.344782994717968
55%|█████▌ | 110/200 [01:35<01:18, 1.15it/s]
training loss: 6.095822334289551 validation loss: 6.342060916277827
56%|█████▌ | 111/200 [01:36<01:17, 1.15it/s]
training loss: 6.107515335083008 validation loss: 6.34109827936912
56%|█████▌ | 112/200 [01:37<01:16, 1.14it/s]
training loss: 6.106581211090088 validation loss: 6.338152398868483
56%|█████▋ | 113/200 [01:38<01:15, 1.15it/s]
training loss: 6.084316730499268 validation loss: 6.333792900552555
57%|█████▋ | 114/200 [01:38<01:15, 1.14it/s]
training loss: 6.142121315002441 validation loss: 6.333415294180111
57%|█████▊ | 115/200 [01:39<01:13, 1.15it/s]
training loss: 6.144359111785889 validation loss: 6.329886485119255
58%|█████▊ | 116/200 [01:40<01:12, 1.15it/s]
training loss: 6.037834167480469 validation loss: 6.326496591373366
58%|█████▊ | 117/200 [01:41<01:11, 1.15it/s]
training loss: 6.078401565551758 validation loss: 6.327183791569301
59%|█████▉ | 118/200 [01:42<01:11, 1.15it/s]
training loss: 6.101098537445068 validation loss: 6.323671058732636
60%|█████▉ | 119/200 [01:43<01:10, 1.15it/s]
training loss: 6.168275833129883 validation loss: 6.320113259918836
60%|██████ | 120/200 [01:44<01:09, 1.15it/s]
training loss: 6.099752902984619 validation loss: 6.3197658013324345
60%|██████ | 121/200 [01:45<01:08, 1.15it/s]
training loss: 6.145869255065918 validation loss: 6.314804826463972
61%|██████ | 122/200 [01:45<01:07, 1.15it/s]
training loss: 6.151786804199219 validation loss: 6.311663053473648
62%|██████▏ | 123/200 [01:46<01:07, 1.15it/s]
training loss: 6.152758598327637 validation loss: 6.3094333434591485
62%|██████▏ | 124/200 [01:47<01:06, 1.15it/s]
training loss: 6.077962398529053 validation loss: 6.305732113974435
62%|██████▎ | 125/200 [01:48<01:05, 1.15it/s]
training loss: 5.982882499694824 validation loss: 6.304498010752153
63%|██████▎ | 126/200 [01:49<01:04, 1.15it/s]
training loss: 6.154467582702637 validation loss: 6.302595858671228
64%|██████▎ | 127/200 [01:50<01:03, 1.15it/s]
training loss: 6.038270950317383 validation loss: 6.302015421341877
64%|██████▍ | 128/200 [01:51<01:02, 1.15it/s]
training loss: 6.119851112365723 validation loss: 6.302074909210205
64%|██████▍ | 129/200 [01:52<01:01, 1.15it/s]
training loss: 6.008480072021484 validation loss: 6.298991728802116
65%|██████▌ | 130/200 [01:52<01:00, 1.15it/s]
training loss: 6.137060165405273 validation loss: 6.293819904327393
66%|██████▌ | 131/200 [01:53<01:00, 1.15it/s]
training loss: 6.07685661315918 validation loss: 6.292368694227569
66%|██████▌ | 132/200 [01:54<00:59, 1.15it/s]
training loss: 6.112682342529297 validation loss: 6.289677230679259
66%|██████▋ | 133/200 [01:55<00:58, 1.14it/s]
training loss: 6.083937644958496 validation loss: 6.286848691045021
67%|██████▋ | 134/200 [01:56<00:57, 1.14it/s]
training loss: 6.105483531951904 validation loss: 6.284933129135443
68%|██████▊ | 135/200 [01:57<00:56, 1.14it/s]
training loss: 6.02985143661499 validation loss: 6.283226859812834
68%|██████▊ | 136/200 [01:58<00:56, 1.14it/s]
training loss: 6.085984706878662 validation loss: 6.278696060180664
68%|██████▊ | 137/200 [01:59<00:55, 1.14it/s]
training loss: 6.073261260986328 validation loss: 6.275632702574438
69%|██████▉ | 138/200 [01:59<00:53, 1.15it/s]
training loss: 6.017243385314941 validation loss: 6.272366971385722
70%|██████▉ | 139/200 [02:00<00:52, 1.15it/s]
training loss: 6.022629261016846 validation loss: 6.27057543579413
70%|███████ | 140/200 [02:01<00:51, 1.15it/s]
training loss: 6.011186599731445 validation loss: 6.27148031701847
70%|███████ | 141/200 [02:02<00:50, 1.16it/s]
training loss: 6.033587455749512 validation loss: 6.269955070651307
71%|███████ | 142/200 [02:03<00:49, 1.16it/s]
training loss: 6.104765892028809 validation loss: 6.266687023396394
72%|███████▏ | 143/200 [02:04<00:49, 1.16it/s]
training loss: 6.068141937255859 validation loss: 6.263849774185492
72%|███████▏ | 144/200 [02:05<00:48, 1.16it/s]
training loss: 6.077755928039551 validation loss: 6.261609632141736
72%|███████▎ | 145/200 [02:05<00:47, 1.16it/s]
training loss: 6.078707695007324 validation loss: 6.260631152561733
73%|███████▎ | 146/200 [02:06<00:46, 1.16it/s]
training loss: 6.057844638824463 validation loss: 6.258478923719757
74%|███████▎ | 147/200 [02:07<00:45, 1.16it/s]
training loss: 6.060548782348633 validation loss: 6.257633919618567
74%|███████▍ | 148/200 [02:08<00:44, 1.16it/s]
training loss: 6.053906440734863 validation loss: 6.2561231827249335
74%|███████▍ | 149/200 [02:09<00:44, 1.16it/s]
training loss: 6.057756423950195 validation loss: 6.253395129223259
75%|███████▌ | 150/200 [02:10<00:43, 1.15it/s]
training loss: 6.016613006591797 validation loss: 6.251269593530772
76%|███████▌ | 151/200 [02:11<00:42, 1.15it/s]
training loss: 6.069310188293457 validation loss: 6.249930294192567
76%|███████▌ | 152/200 [02:11<00:41, 1.15it/s]
training loss: 6.057592391967773 validation loss: 6.248094733880491
76%|███████▋ | 153/200 [02:12<00:40, 1.15it/s]
training loss: 6.055517673492432 validation loss: 6.245763856537488
77%|███████▋ | 154/200 [02:13<00:39, 1.15it/s]
training loss: 5.981128692626953 validation loss: 6.2438530921936035
78%|███████▊ | 155/200 [02:14<00:39, 1.15it/s]
training loss: 6.022617340087891 validation loss: 6.241560556450668
78%|███████▊ | 156/200 [02:15<00:38, 1.15it/s]
training loss: 6.086309432983398 validation loss: 6.237579958779471
78%|███████▊ | 157/200 [02:16<00:37, 1.16it/s]
training loss: 6.069828510284424 validation loss: 6.234373160770962
79%|███████▉ | 158/200 [02:17<00:36, 1.16it/s]
training loss: 5.977034568786621 validation loss: 6.233016325502979
80%|███████▉ | 159/200 [02:18<00:35, 1.16it/s]
training loss: 6.02243185043335 validation loss: 6.231708945060263
80%|████████ | 160/200 [02:18<00:34, 1.16it/s]
training loss: 6.022791862487793 validation loss: 6.228585700599515
80%|████████ | 161/200 [02:19<00:33, 1.16it/s]
training loss: 5.982759475708008 validation loss: 6.226058765333526
81%|████████ | 162/200 [02:20<00:32, 1.16it/s]
training loss: 6.001495361328125 validation loss: 6.22391228773156
82%|████████▏ | 163/200 [02:21<00:31, 1.16it/s]
training loss: 6.003325462341309 validation loss: 6.221862325862962
82%|████████▏ | 164/200 [02:22<00:30, 1.16it/s]
training loss: 6.012252330780029 validation loss: 6.221921395282356
82%|████████▎ | 165/200 [02:23<00:30, 1.16it/s]
training loss: 5.909038543701172 validation loss: 6.222556221241853
83%|████████▎ | 166/200 [02:24<00:29, 1.16it/s]
training loss: 5.966092109680176 validation loss: 6.22004255956533
84%|████████▎ | 167/200 [02:24<00:28, 1.16it/s]
training loss: 5.964709758758545 validation loss: 6.217550267978591
84%|████████▍ | 168/200 [02:25<00:27, 1.17it/s]
training loss: 5.990265369415283 validation loss: 6.2139473642621725
84%|████████▍ | 169/200 [02:26<00:26, 1.16it/s]
training loss: 5.978246212005615 validation loss: 6.212402859512641
85%|████████▌ | 170/200 [02:27<00:25, 1.16it/s]
training loss: 5.977949142456055 validation loss: 6.213127720112703
86%|████████▌ | 171/200 [02:28<00:24, 1.16it/s]
training loss: 6.002535820007324 validation loss: 6.210467143934601
86%|████████▌ | 172/200 [02:29<00:24, 1.17it/s]
training loss: 6.001431941986084 validation loss: 6.210143994311897
86%|████████▋ | 173/200 [02:30<00:23, 1.16it/s]
training loss: 6.107723236083984 validation loss: 6.207548559928427
87%|████████▋ | 174/200 [02:30<00:22, 1.16it/s]
training loss: 5.944747447967529 validation loss: 6.205113342830113
88%|████████▊ | 175/200 [02:31<00:21, 1.16it/s]
training loss: 5.999008655548096 validation loss: 6.20351285350566
88%|████████▊ | 176/200 [02:32<00:20, 1.16it/s]
training loss: 5.943323135375977 validation loss: 6.200359597498057
88%|████████▊ | 177/200 [02:33<00:19, 1.16it/s]
training loss: 5.962564468383789 validation loss: 6.19809255794603
89%|████████▉ | 178/200 [02:34<00:18, 1.16it/s]
training loss: 5.994581699371338 validation loss: 6.1956459064872895
90%|████████▉ | 179/200 [02:35<00:18, 1.16it/s]
training loss: 6.089620113372803 validation loss: 6.194755612587442
90%|█████████ | 180/200 [02:36<00:17, 1.16it/s]
training loss: 5.90007209777832 validation loss: 6.194045397700096
90%|█████████ | 181/200 [02:36<00:16, 1.16it/s]
training loss: 5.999163627624512 validation loss: 6.1943637789512165
91%|█████████ | 182/200 [02:37<00:15, 1.16it/s]
training loss: 6.013882637023926 validation loss: 6.192244325365339
92%|█████████▏| 183/200 [02:38<00:14, 1.16it/s]
training loss: 5.918767929077148 validation loss: 6.190416705851653
92%|█████████▏| 184/200 [02:39<00:13, 1.17it/s]
training loss: 5.998378753662109 validation loss: 6.186842869739143
92%|█████████▎| 185/200 [02:40<00:12, 1.17it/s]
training loss: 6.001276016235352 validation loss: 6.184115536358892
93%|█████████▎| 186/200 [02:41<00:12, 1.16it/s]
training loss: 5.9881744384765625 validation loss: 6.182604663226069
94%|█████████▎| 187/200 [02:42<00:11, 1.16it/s]
training loss: 6.0272016525268555 validation loss: 6.180895309058988
94%|█████████▍| 188/200 [02:42<00:10, 1.16it/s]
training loss: 6.033154487609863 validation loss: 6.18033537572744
94%|█████████▍| 189/200 [02:43<00:09, 1.15it/s]
training loss: 6.017482280731201 validation loss: 6.178543343835948
95%|█████████▌| 190/200 [02:44<00:08, 1.16it/s]
training loss: 5.964412212371826 validation loss: 6.1761249425459885
96%|█████████▌| 191/200 [02:45<00:07, 1.16it/s]
training loss: 6.065142631530762 validation loss: 6.174429805911317
96%|█████████▌| 192/200 [02:46<00:06, 1.16it/s]
training loss: 5.94199275970459 validation loss: 6.173588704089729
96%|█████████▋| 193/200 [02:47<00:06, 1.16it/s]
training loss: 5.937205791473389 validation loss: 6.170328422468536
97%|█████████▋| 194/200 [02:48<00:05, 1.16it/s]
training loss: 5.961018085479736 validation loss: 6.168160039551404
98%|█████████▊| 195/200 [02:49<00:04, 1.16it/s]
training loss: 5.929530143737793 validation loss: 6.168551522858289
98%|█████████▊| 196/200 [02:49<00:03, 1.16it/s]
training loss: 5.922443389892578 validation loss: 6.167792475953394
98%|█████████▊| 197/200 [02:50<00:02, 1.16it/s]
training loss: 5.891782283782959 validation loss: 6.164439503027468
99%|█████████▉| 198/200 [02:51<00:01, 1.16it/s]
training loss: 5.946990966796875 validation loss: 6.16373381322744
100%|█████████▉| 199/200 [02:52<00:00, 1.16it/s]
training loss: 5.919736862182617 validation loss: 6.162640055831598
100%|██████████| 200/200 [02:53<00:00, 1.15it/s]
training loss: 5.924776077270508 validation loss: 6.159777845655169
In [49]:
decode(model.generate(encode("whether tis nobler").unsqueeze(0), max_new_tokens=20)[0])
Out[49]:
'whether tis nobler king our affection amongst our end eyebrows sends oblivion dead and let be at care foul my lord i is'
In [ ]: