Lecture 21 – CS 189, Fall 2025

In [3]:
# !pip install -U plotly
In [4]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pickle
In [5]:
# import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
device = "cuda" # Colab
# device = "mps" # Mac with M1/M2
# device = "cpu" # Local CPU
In [6]:
import plotly.io as pio
pio.renderers.default = "vscode"

Sinusoidal Embeddings

In [7]:
D = 6
n = 16
L = 1000
torch.arange(0, D, 2, dtype=torch.float)
Out[7]:
tensor([0., 2., 4.])
In [8]:
def sinusoidal_pe(N, D, L=1000):
    pe = torch.zeros(N, D)
    div_term = L ** (torch.arange(0, D, 2, dtype=torch.float) / D)  # L^(2i/D); arange(0, D, 2) already gives the 2i values
    pe[:, 0::2] = torch.sin(torch.arange(N, dtype=torch.float).unsqueeze(1) / div_term)
    pe[:, 1::2] = torch.cos(torch.arange(N, dtype=torch.float).unsqueeze(1) / div_term)
    return pe
In [9]:
n = 64; D = 128; L = 10
pe = sinusoidal_pe(n, D, L)
px.imshow(pe.cpu().numpy().T, aspect='auto', color_continuous_scale='RdBu_r',
          width = 1100, height=500)
In [10]:
dist = pe @ pe.T
fig = px.imshow(dist.cpu().numpy(), color_continuous_scale='Viridis',
                width=700, height=700,
                title='Dot Product of Positional Encodings')
fig.show()
px.line(x=np.arange(0, n), y=dist[10].cpu().numpy(), width=800, height=400,
        title='Dot Product of Positional Encodings for Position 10')
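For the pure sinusoidal encoding, the dot product between positions i and j works out to a sum of terms cos((i − j)/w) over the frequencies w, so it depends only on the offset i − j. A quick way to see this (an added sketch, not part of the original cell) is to overlay the dot-product profiles of two different positions: they are shifted copies of one another, apart from truncation at the boundaries.

In [ ]:
# Added check: pe[i] @ pe[j] is a function of (i - j) only, so the similarity
# profile of position 10 is a shifted copy of position 30's.
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(n), y=dist[10].cpu().numpy(), name='position 10'))
fig.add_trace(go.Scatter(x=np.arange(n), y=dist[30].cpu().numpy(), name='position 30'))
fig.update_layout(width=800, height=400, title='Dot-product profiles are shifted copies')
fig.show()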

Tokenization

In [11]:
#!pip install transformers
In [12]:
from transformers import AutoTokenizer

# Load the Qwen tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")
tokenizer.vocab_size
Out[12]:
151643
In [13]:
tokenizer.encode("Hello, how are you?")
Out[13]:
[9707, 11, 1246, 525, 498, 30]
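We can also inspect the string pieces behind these IDs and map them back to text; convert_ids_to_tokens and decode are standard tokenizer methods (a small added check, not part of the original cell).

In [ ]:
# Added round-trip check: look at the token strings and decode back to text.
ids = tokenizer.encode("Hello, how are you?")
print(tokenizer.convert_ids_to_tokens(ids))  # byte-level BPE pieces; leading spaces may show up as a marker character
print(tokenizer.decode(ids))                 # should reproduce "Hello, how are you?"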

Byte Pair Encoding

Byte Pair Encoding (BPE) is fairly simple to implement. We start by splitting each word into its characters and appending a special end-of-word symbol </w> to each word. Then we repeatedly find the most frequent adjacent pair of symbols across all words and merge it into a new symbol. This process is repeated for a specified number of merges.

Once we have learned the merges, we can use them to tokenize new words by applying the merges in order until no more merges can be applied.
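Before the full implementation, here is a tiny standalone illustration of the pair-counting step (an added sketch, independent of the functions defined below). On the toy corpus "low low lower", the pairs ('l','o') and ('o','w') are tied for most frequent, so the first merges produce 'lo' and then 'low': frequent words quickly become single tokens while rare pieces like 'e' and 'r' stay split.

In [ ]:
# Added sketch: one pair-counting step of BPE on the toy corpus "low low lower".
from collections import Counter
toy = ['l','o','w','</w>', 'l','o','w','</w>', 'l','o','w','e','r','</w>']
pairs = Counter((toy[i], toy[i+1]) for i in range(len(toy)-1)
                if not toy[i].endswith('</w>'))   # never pair across a word boundary
print(pairs.most_common(3))   # [(('l','o'), 3), (('o','w'), 3), (('w','</w>'), 2)]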

In [14]:
from typing import Dict, List, Tuple
from collections import Counter

EOW = "</w>"
In [15]:
def build_initial_vocab(corpus: str) -> Dict[str, int]:
    """
    Create an initial vocabulary consisting of all single characters
    observed in the corpus PLUS the EOW marker. IDs are assigned
    deterministically (sorted).

    Returns:
        vocab: dict mapping token string -> integer id
    """
    import string
    ALL_ASCII = set(string.ascii_letters + string.digits + string.punctuation)
    ALL_ASCII.add(EOW)  # include the end-of-word symbol
    corpus_set = set(corpus) - set(' \n\t\r')  # exclude whitespace

    return {tok: i for i, tok in enumerate(corpus_set | ALL_ASCII)}

vocab = build_initial_vocab("CS189 is CS.")
print(vocab)
{'3': 0, 'r': 1, '4': 2, 'Z': 3, 'H': 4, '!': 5, ';': 6, 'W': 7, 'w': 8, 'f': 9, 'M': 10, 'C': 11, '/': 12, '0': 13, ',': 14, 'd': 15, 'L': 16, 'y': 17, '~': 18, '[': 19, '.': 20, 'K': 21, 'G': 22, 'U': 23, '{': 24, 'h': 25, 'm': 26, 'Y': 27, '(': 28, '8': 29, '#': 30, 'B': 31, 'A': 32, 'c': 33, 'N': 34, '7': 35, "'": 36, '6': 37, 'T': 38, 'g': 39, '&': 40, 'F': 41, 'R': 42, 'n': 43, '+': 44, 'q': 45, 's': 46, '1': 47, ']': 48, '-': 49, '2': 50, 'P': 51, '9': 52, 'z': 53, 'a': 54, 'o': 55, '"': 56, 'k': 57, 'b': 58, 'X': 59, '%': 60, 'V': 61, '=': 62, '$': 63, '>': 64, 'v': 65, ':': 66, 'I': 67, '@': 68, '`': 69, '</w>': 70, 'J': 71, 't': 72, 'S': 73, 'e': 74, 'D': 75, '}': 76, '*': 77, 'u': 78, '?': 79, 'j': 80, 'Q': 81, 'i': 82, '<': 83, 'x': 84, 'E': 85, '_': 86, 'l': 87, 'p': 88, 'O': 89, '\\': 90, '^': 91, ')': 92, '5': 93, '|': 94}
In [16]:
def corpus_to_char_seq_with_eow(corpus: str) -> List[str]:
    """
    Convert the whole corpus into a single flat sequence of symbols.
    Each word contributes its characters followed by the end-of-word marker `</w>`.
    
    Example:
        corpus = "low lower"
        returns: ['l','o','w','</w>','l','o','w','e','r','</w>']
    
    Why this matters:
    - We treat the corpus as *one long list* (not a list of per-word lists),
      which is sometimes more convenient for teaching and for demonstrating
      the role of `</w>` in preventing merges across word boundaries.
    """
    seq: List[str] = []
    for word in corpus.split():
        seq.extend(list(word))
        seq.append(EOW)
    return seq

print(corpus_to_char_seq_with_eow("CS189 is great!"))
['C', 'S', '1', '8', '9', '</w>', 'i', 's', '</w>', 'g', 'r', 'e', 'a', 't', '!', '</w>']
In [17]:
def count_pair_frequencies(seq: List[str]) -> Counter:
    """
    Count frequencies of adjacent symbol pairs over the *flat* sequence.
    
    Boundary rule:
    - We *disallow* pairs whose LEFT symbol ends with `</w>`, because merging
      such a pair would join the end of one word to the start of the next,
      crossing a word boundary. We DO allow pairs that END with `</w>`
      (e.g., ('w','</w>')), which forms tokens like 'w</w>' and is standard in BPE.
    
    Returns:
        A Counter mapping (left_symbol, right_symbol) -> count.
    """
    pair_counts = Counter()
    for i in range(len(seq) - 1):
        left, right = seq[i], seq[i + 1]
        if left.endswith(EOW): # This pair would cross a word boundary; skip it.
            continue
        pair_counts[(left, right)] += 1
    return pair_counts

corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
print(pair_freqs)
Counter({('C', 'S'): 2, ('S', '1'): 1, ('1', '8'): 1, ('8', '9'): 1, ('9', '</w>'): 1, ('i', 's'): 1, ('s', '</w>'): 1, ('S', '.'): 1, ('.', '</w>'): 1})
In [18]:
def merge_pair_in_sequence(seq: List[str], pair: Tuple[str, str]) -> List[str]:
    """
    Perform a single merge of the given pair across the flat sequence.
    Invariant:
    - Never merge if the left symbol ends with `</w>` 
       (prevents crossing word boundaries).
    - Scans left-to-right and uses a simple skip mechanic to avoid overlapping merges.
    """
    a, b = pair
    merged_token = a + b
    new_seq: List[str] = []
    i = 0
    n = len(seq)
    while i < n:
        if i < n - 1 and seq[i] == a and seq[i + 1] == b and not seq[i].endswith(EOW):
            new_seq.append(merged_token)
            i += 2  # skip the merged pair
        else:
            new_seq.append(seq[i])
            i += 1
    return new_seq

corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
pair, freq = pair_freqs.most_common(1)[0]
print("Merging pair:", pair, "with frequency:", freq)
new_seq = merge_pair_in_sequence(seq, pair)
print(new_seq)
Merging pair: ('C', 'S') with frequency: 2
['CS', '1', '8', '9', '</w>', 'i', 's', '</w>', 'CS', '.', '</w>']
In [19]:
from tqdm import tqdm  # pip install tqdm

def learn_bpe_merges(corpus: str, num_merges: int = 1000, min_frequency: int = 2) -> Tuple[List[Tuple[str, str]], dict]:
    """
    Learn BPE merge rules from the corpus by repeatedly finding the most frequent
    adjacent pair and merging it, subject to the boundary rule.

    Args:
        corpus: Raw text (spaces separate words).
        num_merges: Maximum number of merges to learn.
        min_frequency: Stop when the most frequent pair occurs fewer than this.

    Returns:
        merges: A list of (left_symbol, right_symbol) in the order they were learned.
        vocab: Final vocabulary mapping token -> id
    """
    seq = corpus_to_char_seq_with_eow(corpus)
    merges: List[Tuple[str, str]] = []
    vocab = build_initial_vocab(corpus)
    next_id = max(vocab.values()) + 1

    # Wrap the merge loop in a tqdm progress bar
    progress = tqdm(range(num_merges), desc="Learning BPE merges", ncols=80)

    for step in progress:
        pair_counts = count_pair_frequencies(seq)
        if not pair_counts:
            progress.set_postfix_str("done (no pairs left)")
            break
        (best_pair, best_count) = pair_counts.most_common(1)[0]
        if best_count < min_frequency:
            progress.set_postfix_str(f"stopped (min freq < {min_frequency})")
            break

        # Merge and update structures
        seq = merge_pair_in_sequence(seq, best_pair)
        merges.append(best_pair)
        new_token = best_pair[0] + best_pair[1]
        if new_token not in vocab:
            vocab[new_token] = next_id
            next_id += 1

        # Update the tqdm progress bar info
        progress.set_postfix_str(f"merge {best_pair} ({best_count})")

    progress.close()
    return merges, vocab

corpus = "This is the best CS class. This is CS 189."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
print("Learned merges:", merges)
print("Vocabulary:", vocab)
Learning BPE merges:   7%| | 7/100 [00:00<00:00, 1604.73it/s, stopped (min freq 
Learned merges: [('i', 's'), ('is', '</w>'), ('T', 'h'), ('Th', 'is</w>'), ('C', 'S'), ('CS', '</w>'), ('.', '</w>')]
Vocabulary: {'3': 0, 'r': 1, '4': 2, 'Z': 3, 'H': 4, '!': 5, ';': 6, 'W': 7, 'w': 8, 'f': 9, 'M': 10, 'C': 11, '/': 12, '0': 13, ',': 14, 'd': 15, 'L': 16, 'y': 17, '~': 18, '[': 19, '.': 20, 'K': 21, 'G': 22, 'U': 23, '{': 24, 'h': 25, 'm': 26, 'Y': 27, '(': 28, '8': 29, '#': 30, 'B': 31, 'A': 32, 'c': 33, 'N': 34, '7': 35, "'": 36, '6': 37, 'T': 38, 'g': 39, '&': 40, 'F': 41, 'R': 42, 'n': 43, '+': 44, 'q': 45, 's': 46, '1': 47, ']': 48, '-': 49, '2': 50, 'P': 51, '9': 52, 'z': 53, 'a': 54, 'o': 55, '"': 56, 'k': 57, 'b': 58, 'X': 59, '%': 60, 'V': 61, '=': 62, '$': 63, '>': 64, 'v': 65, ':': 66, 'I': 67, '@': 68, '`': 69, '</w>': 70, 't': 71, 'J': 72, 'e': 73, 'S': 74, 'D': 75, '}': 76, '*': 77, 'u': 78, '?': 79, 'j': 80, 'Q': 81, 'i': 82, '<': 83, 'x': 84, 'E': 85, '_': 86, 'l': 87, 'p': 88, 'O': 89, '\\': 90, '^': 91, ')': 92, '5': 93, '|': 94, 'is': 95, 'is</w>': 96, 'Th': 97, 'This</w>': 98, 'CS': 99, 'CS</w>': 100, '.</w>': 101}

In [20]:
def bpe_encode(text: str, merges: List[Tuple[str, str]], vocab: Dict[str, int]) -> Tuple[List[str], List[int]]:
    """
    Encode a string into its final token strings and their token IDs:
      1) Convert text -> flat char+EOW sequence
      2) Apply learned merges in order
      3) Map final tokens to IDs via vocab

    Note:
    - This simple teaching encoder applies merges globally; it assumes the
      learned merges were derived from a similar distribution (your corpus).
    - For speed, production systems use a 'rank' map and greedy longest-match;
      here we stick to the clearest didactic approach.
    """
    progress = tqdm(range(len(merges)), desc="Applying BPE merges", ncols=80)
    seq = corpus_to_char_seq_with_eow(text)
    for a, b in merges:
        seq = merge_pair_in_sequence(seq, (a, b))
        progress.update(1)
    progress.close()
    return seq, [vocab[tok] for tok in seq]

corpus = "This is the best CS class. This is CS 189 the best class."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
encoded_seq, token_ids = bpe_encode("CS 189 is the   best   class.", merges, vocab)
print("Encoded sequence:", encoded_seq)
print("Token IDs:", token_ids)
Learning BPE merges:  19%|▏| 19/100 [00:00<00:00, 1672.65it/s, stopped (min freq
Applying BPE merges: 100%|██████████████████| 19/19 [00:00<00:00, 165335.63it/s]
Encoded sequence: ['CS</w>', '1', '8', '9', '</w>', 'is</w>', 'the</w>', 'best</w>', 'class.</w>']
Token IDs: [107, 47, 29, 52, 70, 96, 101, 105, 113]

In [21]:
def bpe_decode(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """
    Decode token IDs back to text by inverting the vocab and then
    removing EOW markers to re-insert spaces.

    Rules:
    - Tokens that END with EOW represent end-of-word units.
      We strip the trailing `</w>` and insert a space.
    - Other tokens are just literal substrings inside a word.

    Caveat:
    - Because we concatenated strings to form merged tokens, decoding simply
      concatenates their surfaces; then we rely on `</w>` to restore spaces.
    """
    inv_vocab = {i: t.replace(EOW, " ") for t, i in vocab.items()}
    buf = [inv_vocab[tid] for tid in token_ids]
    return "".join(buf).strip()
In [22]:
decoded_text = bpe_decode(token_ids, vocab)
print(f"Decoded text: \"{decoded_text}\"")
Decoded text: "CS 189 is the best class."
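As a final round-trip check (added), encoding and then decoding should recover the original text up to whitespace normalization, since </w> is the only record of word boundaries.

In [ ]:
# Added round-trip check: decode(encode(text)) matches text up to collapsed whitespace.
text = "CS 189 is the best class."
_, ids = bpe_encode(text, merges, vocab)
print(bpe_decode(ids, vocab) == " ".join(text.split()))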

Implementing the Decoder Transformer for Generative Pre-training

Getting the Data

In [23]:
import os
if not os.path.exists("shakespeare.txt"):
    print('downloading corpus...')
    import requests
    url = "https://www.gutenberg.org/cache/epub/100/pg100.txt"
    response = requests.get(url)
    shakespeare_corpus = response.text
    with open("shakespeare.txt", "w") as f:
        f.write(shakespeare_corpus)
else:
    print('loading cached file...')
    with open("shakespeare.txt", "r") as f:
        shakespeare_corpus = f.read()
print(f"Corpus length: {len(shakespeare_corpus)} characters") 
print(shakespeare_corpus[:1000])
downloading corpus...
Corpus length: 5575062 characters
The Project Gutenberg eBook of The Complete Works of William Shakespeare
    
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
of the Project Gutenberg License included with this ebook or online
at www.gutenberg.org. If you are not located in the United States,
you will have to check the laws of the country where you are located
before using this eBook.

Title: The Complete Works of William Shakespeare

Author: William Shakespeare

Release date: January 1, 1994 [eBook #100]
                Most recently updated: August 24, 2025

Language: English



*** START OF THE PROJECT GUTENBERG EBOOK THE COMPLETE WORKS OF WILLIAM SHAKESPEARE ***




The Complete Works of William Shakespeare

by William Shakespeare




                    Contents

    THE SONNETS
    ALL’S WELL THAT ENDS WELL
    

Byte Pair Encoding

In [24]:
# if not os.path.exists("bpe_state.pkl"):
#     print('learning BPE merges on Shakespeare corpus...')
#     merges, vocab = learn_bpe_merges(shakespeare_corpus, 
#                                      num_merges=200, 
#                                      min_frequency=2)
#     with open("bpe_state.pkl", "wb") as f:
#         pickle.dump((merges, vocab), f)
# else:
#     print("loading cached BPE state...")
#     with open("bpe_state.pkl", "rb") as f:
#         merges, vocab = pickle.load(f)
# print("Learned merges:", merges)
# print("Vocabulary:", vocab)
# vocab_size = len(vocab)
# print("Vocabulary size:", vocab_size)
In [25]:
# if not os.path.exists("encoded_text_ids.pkl"):
#     print('encoding Shakespeare corpus...')
#     encoded_seq, token_ids = bpe_encode(shakespeare_corpus, merges, vocab)
#     with open("encoded_text_ids.pkl", "wb") as f:
#         pickle.dump((encoded_seq,token_ids), f)
# else:
#     print("loading cached encoded text...")
#     with open("encoded_text_ids.pkl", "rb") as f:
#         encoded_seq, token_ids = pickle.load(f)
# print("Encoded sequence length:", len(encoded_seq))
# corpus_tokens = torch.tensor(token_ids, dtype=torch.long, device=device)

# def encode(text: str) -> torch.Tensor:
#     _, token_ids = bpe_encode(text, merges, vocab)
#     return torch.tensor(token_ids, dtype=torch.long, device=device)

# def decode(token_ids: torch.Tensor) -> str:
#     return bpe_decode(token_ids.tolist(), vocab)

# tok = encode("To be, or not to be, that is the question.")
# decode(tok)

Word Encoding

In [26]:
import re
from collections import Counter

def split_words(corpus: str) -> List[str]:
    """
    Break the corpus into words using a regex that matches word characters.
    """
    pattern = r'\b\w+\b'
    corpus = re.sub(r'[._,!?;"\'`()\[\]{}<>]', '', corpus.lower())
    return re.findall(pattern, corpus)

words = split_words(shakespeare_corpus)
# counter = Counter(words)
# vocab_set = {tok for tok, cnt in counter.items() if cnt > 1}
vocab_set = set(words)
vocab = {word: i for i, word in enumerate(sorted(vocab_set), start = 1)}
vocab["<unknown>"] = 0  # reserve id 0 for unknown words; ids 1..N are the real words
inv_vocab = {i: word for word, i in vocab.items()}
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)


def encode(text: str):
    """
    Encode a string into token IDs using the provided vocabulary.
    Unknown words are mapped to the ID for <unknown>.
    """
    words = split_words(text)
    return torch.tensor(
        [vocab.get(word, vocab["<unknown>"]) for word in words],
        dtype=torch.long, device=device)

def decode(tokens: torch.Tensor) -> str:
    """
    Decode token IDs back to text by inverting the vocabulary.
    """
    words = [inv_vocab.get(t.item(), "<error>") for t in tokens]
    return " ".join(words)

corpus_tokens = encode(shakespeare_corpus)
print("Length of token IDs:", len(corpus_tokens))
decode(encode("to be or not to be that is the question tokenizer"))
Vocabulary size: 25119
Length of token IDs: 992077
Out[26]:
'to be or not to be that is the question <unknown>'
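A quick look at how this word-level vocabulary is distributed (an added aside): a handful of frequent words account for a large share of the corpus, while a long tail of words occurs only once. That long tail is exactly what subword schemes like BPE avoid paying a full embedding row for.

In [ ]:
# Added sketch: frequency distribution of the word vocabulary.
word_counts = Counter(words)
print("most common words:", word_counts.most_common(5))
print("words seen only once:", sum(1 for c in word_counts.values() if c == 1))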

Data Preparation

In [27]:
N = corpus_tokens.shape[0]
seq_length = 64
split_ratio = 0.90
seed = 189
x = corpus_tokens[:(N - (N % seq_length) + 1)]
y = x[1:].reshape(-1, seq_length)
x = x[:-1].reshape(-1, seq_length)


from torch.utils.data import random_split, TensorDataset
dataset = TensorDataset(x, y)
generator = torch.Generator().manual_seed(seed)
training_data, validation_data = random_split(
    dataset, [split_ratio, 1 - split_ratio], 
    generator=generator) 
print("training contexts", len(training_data))
print("validation contexts", len(validation_data))
training_data[0]
training contexts 13951
validation contexts 1550
Out[27]:
(tensor([14864, 21972, 19505, 20758,  1169, 19278, 10750,  3106,  8358,  3024,
          7282, 14274,  9577, 22082, 24681,  1169, 12777, 11447, 12855, 11907,
         21880, 21972, 19505,  8635, 23586, 21876, 21470, 15025,   449, 14747,
          2166, 21869, 14556, 14572,  7882,  8134, 14794, 14572,  7882, 19499,
          9993,  9970, 24472, 10466, 10524, 15491, 22710, 22178, 15092, 10588,
         13003,  4518, 22178, 20425, 22178, 10750,  1169, 11250, 19549,  2112,
          1538,  9743,  1538, 19549], device='cuda:0'),
 tensor([21972, 19505, 20758,  1169, 19278, 10750,  3106,  8358,  3024,  7282,
         14274,  9577, 22082, 24681,  1169, 12777, 11447, 12855, 11907, 21880,
         21972, 19505,  8635, 23586, 21876, 21470, 15025,   449, 14747,  2166,
         21869, 14556, 14572,  7882,  8134, 14794, 14572,  7882, 19499,  9993,
          9970, 24472, 10466, 10524, 15491, 22710, 22178, 15092, 10588, 13003,
          4518, 22178, 20425, 22178, 10750,  1169, 11250, 19549,  2112,  1538,
          9743,  1538, 19549, 18805], device='cuda:0'))
In [28]:
print(decode(training_data[0][0]))
print(decode(training_data[0][1]))
now thou shalt stay and see her bright eyes break each morning gainst thy window and let in life into thee thou shalt feed upon the sweetness of a noble beauty that nature ne er exceeded nor ne er shall good gods what happiness has palamon twenty to one he ll come to speak to her and if she be as gentle as she
thou shalt stay and see her bright eyes break each morning gainst thy window and let in life into thee thou shalt feed upon the sweetness of a noble beauty that nature ne er exceeded nor ne er shall good gods what happiness has palamon twenty to one he ll come to speak to her and if she be as gentle as she s
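The two decoded lines above differ only by a one-token shift: each target sequence is the input sequence shifted left by one position, which is exactly what the next-token prediction loss needs. A quick added check confirms this.

In [ ]:
# Added check: within every context, y is x shifted left by one token.
x0, y0 = training_data[0]
print(torch.equal(x0[1:], y0[:-1]))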

Attention

In [29]:
def scaled_dot_product_attention(Q: torch.Tensor, 
                                 K: torch.Tensor, 
                                 V: torch.Tensor, 
                                 mask=None):
    """
    Q: matrix of shape (B, N, d_k)
    K: matrix of shape (B, N, d_k)
    V: matrix of shape (B, N, d_v)
    mask: boolean matrix of shape (1, N, N). Values where mask is True will be INCLUDED
    """
    (B, N, d_k) = Q.shape
    (_, _, d_v) = V.shape
    K_T = K.transpose(-2, -1) # (B, d_k, N)
    dot_product = (Q @ K_T) / (d_k ** 0.5) # (B, N, N)
    
    if mask is not None:
        dot_product = dot_product.masked_fill(mask.logical_not(), float('-inf'))

    attention = F.softmax(dot_product, dim=-1) # (B, N, N)
    return attention @ V  # (B, N, N) * (B, N, d_v) = (B, N, d_v)
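As a sanity check (added, not part of the lecture cell), this implementation should agree with PyTorch's built-in F.scaled_dot_product_attention, which uses the same boolean-mask convention (True means the position participates in attention).

In [ ]:
# Added check: compare against PyTorch's fused attention on random inputs.
_B, _N, _dk, _dv = 2, 5, 8, 8
Q_, K_, V_ = torch.randn(_B, _N, _dk), torch.randn(_B, _N, _dk), torch.randn(_B, _N, _dv)
mask_ = torch.ones(1, _N, _N, dtype=torch.bool).tril()
print(torch.allclose(
    scaled_dot_product_attention(Q_, K_, V_, mask=mask_),
    F.scaled_dot_product_attention(Q_, K_, V_, attn_mask=mask_),
    atol=1e-5))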
In [30]:
_mask_cache = {}
def get_mask_with_cache(N, device):
    """
    Returns a lower triangular mask of shape (1, N, N) to be used for masked attention.
    """
    if N not in _mask_cache:
        _mask_cache[N] = torch.ones(
            (N, N), dtype=torch.bool, 
            device=device).tril().unsqueeze(0)  
    return _mask_cache[N] #  (1, N, N)

class MaskedAttentionHead(nn.Module):
    def __init__(self, d_model=512, d_v=512, d_k=64):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_k = nn.Linear(self.d_model, self.d_k, bias = False)
        self.W_q = nn.Linear(self.d_model, self.d_k, bias = False)
        self.W_v = nn.Linear(self.d_model, self.d_v, bias = False)


    def forward(self, x):
        """
        x: (B, N, d_model) input used to form the queries, keys, and values.
        A causal (lower-triangular) mask is applied so that each position can
        attend only to itself and to earlier positions.
        """
        (B, N, _) = x.shape
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        mask = get_mask_with_cache(N, device=x.device)
        values = scaled_dot_product_attention(Q, K, V, mask=mask) 
        # ## More efficient alternative: PyTorch's fused implementation, which
        # ## works directly on (B, N, d) tensors; the (1, N, N) mask broadcasts.
        # values = F.scaled_dot_product_attention(
        #     query=Q, key=K, value=V, 
        #     attn_mask=mask)
        return values
In [31]:
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, num_heads=8, d_model=512, d_k=64):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_model // num_heads
        self.attention_heads = nn.ModuleList(
            [
                MaskedAttentionHead(d_model=self.d_model, d_v = self.d_v, d_k=self.d_k)
                for _ in range(self.num_heads)
            ]
        )
        # Projection
        self.W_out = nn.Linear(self.num_heads * self.d_v, self.d_model) 
        

    def forward(self, x):
        (B, N, _) = x.shape
        assert(x.shape == (B, N, self.d_model))
        head_outputs = [head(x) for head in self.attention_heads]
        for head_out in head_outputs:
            assert(head_out.shape == (B, N, self.d_v))
        concatenated = torch.cat(head_outputs, dim=-1)
        assert(concatenated.shape == (B, N, self.num_heads * self.d_v))
        out = self.W_out(concatenated)
        return out

Testing the Masked Multi-Head Attention Layer:

In [32]:
emb = nn.Embedding(vocab_size, 512).to(device)
layer = MaskedMultiHeadAttention(7, 512, 64).to(device)
In [33]:
batch_size = 7
x, y = training_data[:batch_size]
emb(x).shape
Out[33]:
torch.Size([7, 64, 512])
In [34]:
layer(emb(x)).shape
Out[34]:
torch.Size([7, 64, 512])
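One property worth verifying before building the full decoder (an added check): with the causal mask, the output at position i must not depend on any later position. Zeroing out the embedding of the last token should therefore leave the outputs at all earlier positions unchanged.

In [ ]:
# Added causality check: perturbing a future token does not change earlier outputs.
with torch.no_grad():
    e1 = emb(x)
    e2 = e1.clone()
    e2[:, -1, :] = 0.0                      # perturb only the last position
    out1, out2 = layer(e1), layer(e2)
print(torch.allclose(out1[:, :-1], out2[:, :-1]))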

Decoder Architecture

In [35]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_k=64, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ffn = 4 * self.d_model
        self.d_k = d_k

        self.dropout = nn.Dropout(dropout)

        self.mh_attention = MaskedMultiHeadAttention(
            num_heads=self.num_heads, 
            d_model=self.d_model,
            d_k=self.d_k)
        self.ffn = nn.Sequential(
            nn.Linear(self.d_model, self.d_ffn),
            nn.ReLU(),
            self.dropout,
            nn.Linear(self.d_ffn, self.d_model),
        )
        self.layernorm1 = nn.LayerNorm(self.d_model)
        self.layernorm2 = nn.LayerNorm(self.d_model)
        
        
    
    def forward(self, x):
        mh = self.mh_attention(self.layernorm1(x)) # Pre-norm
        mh = self.dropout(mh)
        x = x + mh                          # residual connection around attention
        ffn = self.ffn(self.layernorm2(x))  # Pre-norm
        ffn = self.dropout(ffn)
        return x + ffn                      # residual connection around the FFN
In [36]:
block = DecoderBlock(d_model=512, num_heads=8, d_k=64, dropout=0.1).to(device)
block(emb(x)).shape
Out[36]:
torch.Size([7, 64, 512])
In [37]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 512, max_len: int = 1024, L: float = 10000.0):
        """
        Sinusoidal positional encoding as in 'Attention is All You Need'.
        """
        super().__init__()

        pos = torch.zeros(max_len, d_model, dtype=torch.float32)
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_terms = L ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        quotient = positions / div_terms
        pos[:, 0::2] = torch.sin(quotient)  # even indices
        pos[:, 1::2] = torch.cos(quotient)  # odd indices
        # Register as non-parameter buffer so it moves with the module
        self.register_buffer("pos", pos)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        # pos: (1, seq_len, d_model) broadcasts along batch dimension
        return x + self.pos[:seq_len].unsqueeze(0)
In [38]:
class TransformerDecoderOnly(nn.Module):
    def __init__(self, 
                 max_length=1024,
                 vocab_size=6000, 
                 d_model=512,
                 d_k=64,
                 num_layers=6, 
                 num_heads=8, 
                 dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_k = d_k
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.layers = nn.Sequential()
        self.layers.append(
            nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        )
        self.layers.append(
            PositionalEncoding(d_model=d_model, max_len=max_length, L=10000)
        )
        for _ in range(num_layers):
            self.layers.append(
                DecoderBlock(d_model=d_model, num_heads=num_heads, 
                             d_k=self.d_k, dropout=self.dropout)
            )
        self.layers.append(
            nn.Linear(in_features=d_model, out_features=vocab_size)
        )
        
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters())
    
    def forward(self, x):
        return self.layers(x)
    
    def generate(self, x, max_new_tokens):
        """
        x: (B, N) tensor of input token IDs
        max_new_tokens: number of tokens to generate
        """
        self.eval()
        with torch.no_grad():
            B, N = x.shape
            for _ in range(max_new_tokens):
                # Condition on at most the last seq_length tokens, but keep the
                # full sequence (prompt + generated tokens) in x.
                x_cond = x[:, -seq_length:]
                logits = self.forward(x_cond)  # (B, N', vocab_size)
                next_token_logits = logits[:, -1, :]  # (B, vocab_size)
                next_token_probs = F.softmax(next_token_logits, dim=-1)  # (B, vocab_size)
                next_tokens = torch.multinomial(next_token_probs, num_samples=1)  # (B, 1)
                x = torch.cat([x, next_tokens], dim=1)  # (B, N+1)
        return x
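The generate method samples each next token directly from the softmax distribution. A common refinement (sketched below as an addition, not part of the lecture model) is temperature scaling: dividing the logits by a temperature below 1 sharpens the distribution toward greedy decoding, while a temperature above 1 flattens it and produces more varied text.

In [ ]:
# Added sketch: temperature-scaled sampling for a single step.
# next_token_logits has shape (B, vocab_size); temperature is a hypothetical knob.
def sample_with_temperature(next_token_logits, temperature=0.8):
    probs = F.softmax(next_token_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)   # (B, 1)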
In [39]:
model = TransformerDecoderOnly(
    max_length=seq_length, 
    vocab_size=vocab_size, 
    d_model=256, d_k=16, num_layers=6, num_heads=8, 
    dropout=0.1).to(device)

print(model)
print("Number of parameters:", model.num_parameters()/1e6, "million")
TransformerDecoderOnly(
  (layers): Sequential(
    (0): Embedding(25119, 256)
    (1): PositionalEncoding()
    (2): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (3): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (4): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (5): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (6): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (7): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (8): Linear(in_features=256, out_features=25119, bias=True)
  )
)
Number of parameters: 17.226783 million
In [40]:
decode(model.generate(encode("to be or not to be").unsqueeze(0), max_new_tokens=20)[0])
Out[40]:
'to be or not to be unparalleled yorkshire revenging types worshipp descension statesman addition reedy festinate camels unlink attain captainship agitation granted righteous sparrows grace appendix'

Training Loop

In [41]:
def batch_cross_entropy(pred, y):
    # flatten the batch into a single dimension and 
    # compute cross-entropy
    return F.cross_entropy(pred.view(-1, vocab_size), y.view(-1))
In [42]:
batch_cross_entropy(model(x), y)
Out[42]:
tensor(10.3724, device='cuda:0', grad_fn=<NllLossBackward0>)
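An untrained model should assign roughly uniform probability to the 25,119 words, giving a cross-entropy near log(25119) ≈ 10.13, so a value around 10.4 is about what random initialization predicts. Equivalently, exp(loss) is the model's perplexity, which starts near the vocabulary size and should fall as training proceeds (an added sanity check).

In [ ]:
# Added sanity check: compare the untrained loss to the uniform baseline log(vocab_size).
with torch.no_grad():
    init_loss = batch_cross_entropy(model(x), y).item()
print("initial loss:", init_loss, "uniform baseline:", np.log(vocab_size))
print("initial perplexity:", np.exp(init_loss))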
In [43]:
def minibatch_gd(model, loss_fn, 
                 training_data,
                 batch_size, 
                 nsteps, 
                 learning_rate,
                 visualizer=None,
                 weight_decay=1e-4):
    generator = torch.Generator()
    generator.manual_seed(189)
    loader = DataLoader(training_data, 
                        batch_size=batch_size, 
                        shuffle=True, # shuffles each epoch
                        generator=generator)
    
    # Define the optimizer (this is the update rule).
    # AdamW is Adam with decoupled weight decay; plain Adam or SGD would also work.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model.train() # set model to training mode (important for dropout/batchnorm)
    step = 0
    # Loop through the steps
    iter_loader = iter(loader)
    for step in tqdm(range(nsteps)):
        # Get the next batch of data
        try:
            x, t = next(iter_loader)
        except StopIteration:
            iter_loader = iter(loader)
            x, t = next(iter_loader)
        # Zero the gradients to start the next step
        optimizer.zero_grad()
        # Compute prediction and loss
        pred = model(x)
        loss = loss_fn(pred, t)
        tr_loss = loss.item()
        # Backpropagation (compute the gradient)
        loss.backward()
        # Update the parameters using the optimizer's update rule
        optimizer.step()
        # Visualize the model (if a visualizer function is provided)
        if visualizer is not None:
            model.eval() # disable dropout/batchnorm
            with torch.no_grad():
                visualizer(step, model, loss_fn, tr_loss)
            model.train()
In [44]:
class LossVisualizer:
    def __init__(self, loss_fig, validation_data):
        self.loss_fig = loss_fig
        self.val_loader = DataLoader(validation_data, 
                                     batch_size=32, 
                                     shuffle=False)
        self.epochs = []
        self.losses_val = []
        self.losses_tr = []

    def reset(self):
        self.epochs = []
        self.losses_val = []
        self.losses_tr = []
        with self.loss_fig.batch_update():
            self.loss_fig.data[0].x = []
            self.loss_fig.data[0].y = []
            self.loss_fig.data[1].x = []
            self.loss_fig.data[1].y = []
    
    def __call__(self, epoch, model, loss_fn, loss_tr):
        model.eval()
        with torch.no_grad():
            losses = []
            for x_val, t_val in self.val_loader:
                loss_val = loss_fn(model(x_val), t_val).item()
                losses.append(loss_val)
            loss_val = np.mean(losses)
        self.epochs.append(epoch)
        self.losses_val.append(loss_val)
        self.losses_tr.append(loss_tr)
        print("training loss:", loss_tr, "validation loss:", loss_val)
        # Visualization Code
        with self.loss_fig.batch_update():
            self.loss_fig.data[0].x = self.epochs
            self.loss_fig.data[0].y = self.losses_val
            self.loss_fig.data[1].x = self.epochs
            self.loss_fig.data[1].y = self.losses_tr

        model.train()
In [45]:
loss_fig = go.FigureWidget()
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Val. Loss'))
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Train. Loss'))
visualizer = LossVisualizer(loss_fig, validation_data)
display(loss_fig)
FigureWidget({
    'data': [{'mode': 'lines',
              'name': 'Val. Loss',
              'type': 'scatter',
              'uid': '5ec38f84-7c4d-435f-8fb7-56afe996d547',
              'x': [0],
              'y': [0]},
             {'mode': 'lines',
              'name': 'Train. Loss',
              'type': 'scatter',
              'uid': '1f185c4c-8c4e-49d1-a8b2-e305ae279857',
              'x': [0],
              'y': [0]}],
    'layout': {'template': '...'}
})
In [48]:
visualizer.reset()
model = TransformerDecoderOnly(
    max_length=seq_length, 
    vocab_size=vocab_size, 
    d_model=1024, d_k=32, num_layers=2, num_heads=8, 
    dropout=0.0).to(device)

#model = torch.compile(model)

# model = TransformerDecoderOnly(
#     max_length=seq_length, 
#     vocab_size=vocab_size, 
#     d_model=128, d_k=32, num_layers=8, num_heads=8, 
#     dropout=0.1).to(device)

minibatch_gd(
    model=model,
    loss_fn=batch_cross_entropy,
    training_data=training_data,
    batch_size=128,
    nsteps=200,
    learning_rate=3e-4,
    weight_decay=1e-7,
    visualizer=visualizer
)
  0%|          | 1/200 [00:00<02:57,  1.12it/s]
training loss: 10.361719131469727 validation loss: 9.929148265293666
  1%|          | 2/200 [00:01<02:53,  1.14it/s]
training loss: 9.939773559570312 validation loss: 9.46023094410799
  2%|▏         | 3/200 [00:02<02:50,  1.15it/s]
training loss: 9.456949234008789 validation loss: 8.856558429951571
  2%|▏         | 4/200 [00:03<02:49,  1.15it/s]
training loss: 8.850252151489258 validation loss: 8.335775414291692
  2%|▎         | 5/200 [00:04<02:48,  1.16it/s]
training loss: 8.29616641998291 validation loss: 7.993326848866988
  3%|▎         | 6/200 [00:05<02:47,  1.16it/s]
training loss: 7.9955973625183105 validation loss: 7.675127457599251
  4%|▎         | 7/200 [00:06<02:46,  1.16it/s]
training loss: 7.692054271697998 validation loss: 7.490796225411551
  4%|▍         | 8/200 [00:06<02:45,  1.16it/s]
training loss: 7.510016441345215 validation loss: 7.413140131502735
  4%|▍         | 9/200 [00:07<02:44,  1.16it/s]
training loss: 7.368573188781738 validation loss: 7.367588490855937
  5%|▌         | 10/200 [00:08<02:43,  1.16it/s]
training loss: 7.339788436889648 validation loss: 7.327013706674381
  6%|▌         | 11/200 [00:09<02:42,  1.16it/s]
training loss: 7.325085639953613 validation loss: 7.282361157086431
  6%|▌         | 12/200 [00:10<02:41,  1.16it/s]
training loss: 7.287446022033691 validation loss: 7.238322219070123
  6%|▋         | 13/200 [00:11<02:40,  1.16it/s]
training loss: 7.1515913009643555 validation loss: 7.2038816335249924
  7%|▋         | 14/200 [00:12<02:39,  1.17it/s]
training loss: 7.1626057624816895 validation loss: 7.165999344417027
  8%|▊         | 15/200 [00:12<02:38,  1.16it/s]
training loss: 7.131077766418457 validation loss: 7.128883332622294
  8%|▊         | 16/200 [00:13<02:37,  1.16it/s]
training loss: 7.13786506652832 validation loss: 7.104381113636251
  8%|▊         | 17/200 [00:14<02:36,  1.17it/s]
training loss: 7.147603511810303 validation loss: 7.085897280245411
  9%|▉         | 18/200 [00:15<02:36,  1.17it/s]
training loss: 7.146273612976074 validation loss: 7.067882323751644
 10%|▉         | 19/200 [00:16<02:35,  1.17it/s]
training loss: 7.116576194763184 validation loss: 7.046328213750099
 10%|█         | 20/200 [00:17<02:34,  1.16it/s]
training loss: 7.081804275512695 validation loss: 7.0230551447187155
 10%|█         | 21/200 [00:18<02:34,  1.16it/s]
training loss: 7.080620765686035 validation loss: 7.001611417653609
 11%|█         | 22/200 [00:18<02:33,  1.16it/s]
training loss: 7.038182735443115 validation loss: 6.987391355086346
 12%|█▏        | 23/200 [00:19<02:31,  1.16it/s]
training loss: 6.958187580108643 validation loss: 6.9772258875321365
 12%|█▏        | 24/200 [00:20<02:31,  1.17it/s]
training loss: 7.008291244506836 validation loss: 6.961357710312824
 12%|█▎        | 25/200 [00:21<02:30,  1.16it/s]
training loss: 6.9798808097839355 validation loss: 6.942648644350013
 13%|█▎        | 26/200 [00:22<02:29,  1.16it/s]
training loss: 6.94419002532959 validation loss: 6.927445985832993
 14%|█▎        | 27/200 [00:23<02:29,  1.16it/s]
training loss: 6.9823102951049805 validation loss: 6.912888059810716
 14%|█▍        | 28/200 [00:24<02:28,  1.16it/s]
training loss: 6.891851902008057 validation loss: 6.893563085672807
 14%|█▍        | 29/200 [00:24<02:27,  1.16it/s]
training loss: 6.838163375854492 validation loss: 6.876966379126724
 15%|█▌        | 30/200 [00:25<02:26,  1.16it/s]
training loss: 6.8766961097717285 validation loss: 6.8654066299905585
 16%|█▌        | 31/200 [00:26<02:25,  1.16it/s]
training loss: 6.861605167388916 validation loss: 6.852149145943778
 16%|█▌        | 32/200 [00:27<02:24,  1.16it/s]
training loss: 6.7996416091918945 validation loss: 6.837655349653595
 16%|█▋        | 33/200 [00:28<02:23,  1.16it/s]
training loss: 6.846062183380127 validation loss: 6.826757275328344
 17%|█▋        | 34/200 [00:29<02:23,  1.16it/s]
training loss: 6.7678632736206055 validation loss: 6.815719273625588
 18%|█▊        | 35/200 [00:30<02:22,  1.16it/s]
training loss: 6.793147087097168 validation loss: 6.804327663110227
 18%|█▊        | 36/200 [00:31<02:21,  1.16it/s]
training loss: 6.738554000854492 validation loss: 6.795834551052171
 18%|█▊        | 37/200 [00:31<02:20,  1.16it/s]
training loss: 6.736570358276367 validation loss: 6.788319529319296
 19%|█▉        | 38/200 [00:32<02:19,  1.16it/s]
training loss: 6.753721237182617 validation loss: 6.776969753966039
 20%|█▉        | 39/200 [00:33<02:18,  1.16it/s]
training loss: 6.767504692077637 validation loss: 6.764109650436713
 20%|██        | 40/200 [00:34<02:18,  1.15it/s]
training loss: 6.893495082855225 validation loss: 6.758010192793243
 20%|██        | 41/200 [00:35<02:17,  1.15it/s]
training loss: 6.77290678024292 validation loss: 6.749037012761953
 21%|██        | 42/200 [00:36<02:17,  1.15it/s]
training loss: 6.768361568450928 validation loss: 6.737954120246732
 22%|██▏       | 43/200 [00:37<02:16,  1.15it/s]
training loss: 6.690580368041992 validation loss: 6.730391210439254
 22%|██▏       | 44/200 [00:37<02:15,  1.15it/s]
training loss: 6.812774658203125 validation loss: 6.720891534065713
 22%|██▎       | 45/200 [00:38<02:15,  1.15it/s]
training loss: 6.717848777770996 validation loss: 6.709689646351094
 23%|██▎       | 46/200 [00:39<02:13,  1.15it/s]
training loss: 6.644081115722656 validation loss: 6.701014956649469
 24%|██▎       | 47/200 [00:40<02:13,  1.15it/s]
training loss: 6.717618465423584 validation loss: 6.692500007395842
 24%|██▍       | 48/200 [00:41<02:12,  1.14it/s]
training loss: 6.673792839050293 validation loss: 6.6824401933319715
 24%|██▍       | 49/200 [00:42<02:12,  1.14it/s]
training loss: 6.780778408050537 validation loss: 6.672941314930818
 25%|██▌       | 50/200 [00:43<02:11,  1.14it/s]
training loss: 6.703185081481934 validation loss: 6.665250885243318
 26%|██▌       | 51/200 [00:44<02:11,  1.14it/s]
training loss: 6.724969863891602 validation loss: 6.658390424689468
 26%|██▌       | 52/200 [00:44<02:10,  1.14it/s]
training loss: 6.611624717712402 validation loss: 6.650597085758132
 26%|██▋       | 53/200 [00:45<02:09,  1.13it/s]
training loss: 6.6266632080078125 validation loss: 6.642186018885399
 27%|██▋       | 54/200 [00:46<02:08,  1.13it/s]
training loss: 6.738619327545166 validation loss: 6.635702522433534
 28%|██▊       | 55/200 [00:47<02:07,  1.13it/s]
training loss: 6.580887794494629 validation loss: 6.6281601360866
 28%|██▊       | 56/200 [00:48<02:07,  1.13it/s]
training loss: 6.584505081176758 validation loss: 6.619589659632469
 28%|██▊       | 57/200 [00:49<02:06,  1.13it/s]
training loss: 6.646501541137695 validation loss: 6.612708743737668
 29%|██▉       | 58/200 [00:50<02:05,  1.13it/s]
training loss: 6.598003387451172 validation loss: 6.60520042691912
 30%|██▉       | 59/200 [00:51<02:03,  1.14it/s]
training loss: 6.551400184631348 validation loss: 6.597341722371627
 30%|███       | 60/200 [00:52<02:02,  1.14it/s]
training loss: 6.596179962158203 validation loss: 6.59074356118027
 30%|███       | 61/200 [00:52<02:02,  1.13it/s]
training loss: 6.584847450256348 validation loss: 6.583917608066481
 31%|███       | 62/200 [00:53<02:01,  1.13it/s]
training loss: 6.4984588623046875 validation loss: 6.576665099786252
 32%|███▏      | 63/200 [00:54<02:00,  1.14it/s]
training loss: 6.626906871795654 validation loss: 6.569135636699443
 32%|███▏      | 64/200 [00:55<01:59,  1.14it/s]
training loss: 6.555837154388428 validation loss: 6.562342857827946
 32%|███▎      | 65/200 [00:56<01:58,  1.14it/s]
training loss: 6.533874034881592 validation loss: 6.5565643602487995
 33%|███▎      | 66/200 [00:57<01:56,  1.15it/s]
training loss: 6.563194751739502 validation loss: 6.550891428577657
 34%|███▎      | 67/200 [00:58<01:56,  1.14it/s]
training loss: 6.503715515136719 validation loss: 6.543573175157819
 34%|███▍      | 68/200 [00:59<01:55,  1.15it/s]
training loss: 6.514548301696777 validation loss: 6.535555382164157
 34%|███▍      | 69/200 [00:59<01:54,  1.15it/s]
training loss: 6.554697036743164 validation loss: 6.5284194459720535
 35%|███▌      | 70/200 [01:00<01:53,  1.15it/s]
training loss: 6.534619331359863 validation loss: 6.5230604385843085
 36%|███▌      | 71/200 [01:01<01:52,  1.15it/s]
training loss: 6.466627597808838 validation loss: 6.517006436172797
 36%|███▌      | 72/200 [01:02<01:51,  1.15it/s]
training loss: 6.492819786071777 validation loss: 6.510690523653614
 36%|███▋      | 73/200 [01:03<01:50,  1.15it/s]
training loss: 6.507439136505127 validation loss: 6.503489834921701
 37%|███▋      | 74/200 [01:04<01:49,  1.15it/s]
training loss: 6.539058208465576 validation loss: 6.497100265658632
 38%|███▊      | 75/200 [01:05<01:48,  1.15it/s]
training loss: 6.465414047241211 validation loss: 6.490775984160754
 38%|███▊      | 76/200 [01:05<01:47,  1.15it/s]
training loss: 6.4486589431762695 validation loss: 6.484331160175557
 38%|███▊      | 77/200 [01:06<01:46,  1.15it/s]
training loss: 6.4612884521484375 validation loss: 6.478635593336456
 39%|███▉      | 78/200 [01:07<01:45,  1.16it/s]
training loss: 6.446281433105469 validation loss: 6.473315920148577
 40%|███▉      | 79/200 [01:08<01:44,  1.16it/s]
training loss: 6.434588432312012 validation loss: 6.467980958977524
 40%|████      | 80/200 [01:09<01:43,  1.16it/s]
training loss: 6.499963760375977 validation loss: 6.463095120021275
 40%|████      | 81/200 [01:10<01:42,  1.16it/s]
training loss: 6.494248867034912 validation loss: 6.459504156696553
 41%|████      | 82/200 [01:11<01:42,  1.15it/s]
training loss: 6.364786148071289 validation loss: 6.454968384334019
 42%|████▏     | 83/200 [01:12<01:41,  1.15it/s]
training loss: 6.4450578689575195 validation loss: 6.4510873191210685
 42%|████▏     | 84/200 [01:12<01:41,  1.15it/s]
training loss: 6.448891639709473 validation loss: 6.445479120526995
 42%|████▎     | 85/200 [01:13<01:40,  1.14it/s]
training loss: 6.4035844802856445 validation loss: 6.441236778181427
 43%|████▎     | 86/200 [01:14<01:39,  1.14it/s]
training loss: 6.477339267730713 validation loss: 6.438299616988824
 44%|████▎     | 87/200 [01:15<01:39,  1.14it/s]
training loss: 6.415297508239746 validation loss: 6.435391718027543
 44%|████▍     | 88/200 [01:16<01:38,  1.14it/s]
training loss: 6.430819988250732 validation loss: 6.431323148766342
 44%|████▍     | 89/200 [01:17<01:37,  1.14it/s]
training loss: 6.412506103515625 validation loss: 6.426331520080566
 45%|████▌     | 90/200 [01:18<01:36,  1.14it/s]
training loss: 6.434972763061523 validation loss: 6.420756680624826
 46%|████▌     | 91/200 [01:19<01:34,  1.15it/s]
training loss: 6.411623001098633 validation loss: 6.416116889642209
 46%|████▌     | 92/200 [01:19<01:33,  1.15it/s]
training loss: 6.3786115646362305 validation loss: 6.411592911700813
 46%|████▋     | 93/200 [01:20<01:32,  1.16it/s]
training loss: 6.417270660400391 validation loss: 6.407266937956518
 47%|████▋     | 94/200 [01:21<01:31,  1.15it/s]
training loss: 6.399571895599365 validation loss: 6.403549680904466
 48%|████▊     | 95/200 [01:22<01:30,  1.16it/s]
training loss: 6.377954959869385 validation loss: 6.399578201527498
 48%|████▊     | 96/200 [01:23<01:29,  1.16it/s]
training loss: 6.492927551269531 validation loss: 6.395020825522287
 48%|████▊     | 97/200 [01:24<01:28,  1.16it/s]
training loss: 6.409252166748047 validation loss: 6.390357270532725
 49%|████▉     | 98/200 [01:25<01:27,  1.16it/s]
training loss: 6.410403728485107 validation loss: 6.385524010171696
 50%|████▉     | 99/200 [01:25<01:27,  1.15it/s]
training loss: 6.511599540710449 validation loss: 6.381968644200539
 50%|█████     | 100/200 [01:26<01:26,  1.15it/s]
training loss: 6.407318592071533 validation loss: 6.378166646373515
 50%|█████     | 101/200 [01:27<01:25,  1.15it/s]
training loss: 6.364468097686768 validation loss: 6.374616914865922
 51%|█████     | 102/200 [01:28<01:24,  1.16it/s]
training loss: 6.477817535400391 validation loss: 6.371596628305864
 52%|█████▏    | 103/200 [01:29<01:23,  1.16it/s]
training loss: 6.411909103393555 validation loss: 6.368403979710171
 52%|█████▏    | 104/200 [01:30<01:23,  1.16it/s]
training loss: 6.30910062789917 validation loss: 6.364959317810682
 52%|█████▎    | 105/200 [01:31<01:22,  1.15it/s]
training loss: 6.369457721710205 validation loss: 6.3607834796516265
 53%|█████▎    | 106/200 [01:32<01:21,  1.15it/s]
training loss: 6.462528228759766 validation loss: 6.355573245457241
 54%|█████▎    | 107/200 [01:32<01:20,  1.16it/s]
training loss: 6.37962532043457 validation loss: 6.351249480734066
 54%|█████▍    | 108/200 [01:33<01:19,  1.16it/s]
training loss: 6.378788948059082 validation loss: 6.347711748006392
 55%|█████▍    | 109/200 [01:34<01:19,  1.15it/s]
training loss: 6.42195987701416 validation loss: 6.344782994717968
 55%|█████▌    | 110/200 [01:35<01:18,  1.15it/s]
training loss: 6.095822334289551 validation loss: 6.342060916277827
 56%|█████▌    | 111/200 [01:36<01:17,  1.15it/s]
training loss: 6.107515335083008 validation loss: 6.34109827936912
 56%|█████▌    | 112/200 [01:37<01:16,  1.14it/s]
training loss: 6.106581211090088 validation loss: 6.338152398868483
 56%|█████▋    | 113/200 [01:38<01:15,  1.15it/s]
training loss: 6.084316730499268 validation loss: 6.333792900552555
 57%|█████▋    | 114/200 [01:38<01:15,  1.14it/s]
training loss: 6.142121315002441 validation loss: 6.333415294180111
 57%|█████▊    | 115/200 [01:39<01:13,  1.15it/s]
training loss: 6.144359111785889 validation loss: 6.329886485119255
 58%|█████▊    | 116/200 [01:40<01:12,  1.15it/s]
training loss: 6.037834167480469 validation loss: 6.326496591373366
 58%|█████▊    | 117/200 [01:41<01:11,  1.15it/s]
training loss: 6.078401565551758 validation loss: 6.327183791569301
 59%|█████▉    | 118/200 [01:42<01:11,  1.15it/s]
training loss: 6.101098537445068 validation loss: 6.323671058732636
 60%|█████▉    | 119/200 [01:43<01:10,  1.15it/s]
training loss: 6.168275833129883 validation loss: 6.320113259918836
 60%|██████    | 120/200 [01:44<01:09,  1.15it/s]
training loss: 6.099752902984619 validation loss: 6.3197658013324345
 60%|██████    | 121/200 [01:45<01:08,  1.15it/s]
training loss: 6.145869255065918 validation loss: 6.314804826463972
 61%|██████    | 122/200 [01:45<01:07,  1.15it/s]
training loss: 6.151786804199219 validation loss: 6.311663053473648
 62%|██████▏   | 123/200 [01:46<01:07,  1.15it/s]
training loss: 6.152758598327637 validation loss: 6.3094333434591485
 62%|██████▏   | 124/200 [01:47<01:06,  1.15it/s]
training loss: 6.077962398529053 validation loss: 6.305732113974435
 62%|██████▎   | 125/200 [01:48<01:05,  1.15it/s]
training loss: 5.982882499694824 validation loss: 6.304498010752153
 63%|██████▎   | 126/200 [01:49<01:04,  1.15it/s]
training loss: 6.154467582702637 validation loss: 6.302595858671228
 64%|██████▎   | 127/200 [01:50<01:03,  1.15it/s]
training loss: 6.038270950317383 validation loss: 6.302015421341877
 64%|██████▍   | 128/200 [01:51<01:02,  1.15it/s]
training loss: 6.119851112365723 validation loss: 6.302074909210205
 64%|██████▍   | 129/200 [01:52<01:01,  1.15it/s]
training loss: 6.008480072021484 validation loss: 6.298991728802116
 65%|██████▌   | 130/200 [01:52<01:00,  1.15it/s]
training loss: 6.137060165405273 validation loss: 6.293819904327393
 66%|██████▌   | 131/200 [01:53<01:00,  1.15it/s]
training loss: 6.07685661315918 validation loss: 6.292368694227569
 66%|██████▌   | 132/200 [01:54<00:59,  1.15it/s]
training loss: 6.112682342529297 validation loss: 6.289677230679259
 66%|██████▋   | 133/200 [01:55<00:58,  1.14it/s]
training loss: 6.083937644958496 validation loss: 6.286848691045021
 67%|██████▋   | 134/200 [01:56<00:57,  1.14it/s]
training loss: 6.105483531951904 validation loss: 6.284933129135443
 68%|██████▊   | 135/200 [01:57<00:56,  1.14it/s]
training loss: 6.02985143661499 validation loss: 6.283226859812834
 68%|██████▊   | 136/200 [01:58<00:56,  1.14it/s]
training loss: 6.085984706878662 validation loss: 6.278696060180664
 68%|██████▊   | 137/200 [01:59<00:55,  1.14it/s]
training loss: 6.073261260986328 validation loss: 6.275632702574438
 69%|██████▉   | 138/200 [01:59<00:53,  1.15it/s]
training loss: 6.017243385314941 validation loss: 6.272366971385722
 70%|██████▉   | 139/200 [02:00<00:52,  1.15it/s]
training loss: 6.022629261016846 validation loss: 6.27057543579413
 70%|███████   | 140/200 [02:01<00:51,  1.15it/s]
training loss: 6.011186599731445 validation loss: 6.27148031701847
 70%|███████   | 141/200 [02:02<00:50,  1.16it/s]
training loss: 6.033587455749512 validation loss: 6.269955070651307
 71%|███████   | 142/200 [02:03<00:49,  1.16it/s]
training loss: 6.104765892028809 validation loss: 6.266687023396394
 72%|███████▏  | 143/200 [02:04<00:49,  1.16it/s]
training loss: 6.068141937255859 validation loss: 6.263849774185492
 72%|███████▏  | 144/200 [02:05<00:48,  1.16it/s]
training loss: 6.077755928039551 validation loss: 6.261609632141736
 72%|███████▎  | 145/200 [02:05<00:47,  1.16it/s]
training loss: 6.078707695007324 validation loss: 6.260631152561733
 73%|███████▎  | 146/200 [02:06<00:46,  1.16it/s]
training loss: 6.057844638824463 validation loss: 6.258478923719757
 74%|███████▎  | 147/200 [02:07<00:45,  1.16it/s]
training loss: 6.060548782348633 validation loss: 6.257633919618567
 74%|███████▍  | 148/200 [02:08<00:44,  1.16it/s]
training loss: 6.053906440734863 validation loss: 6.2561231827249335
 74%|███████▍  | 149/200 [02:09<00:44,  1.16it/s]
training loss: 6.057756423950195 validation loss: 6.253395129223259
 75%|███████▌  | 150/200 [02:10<00:43,  1.15it/s]
training loss: 6.016613006591797 validation loss: 6.251269593530772
 76%|███████▌  | 151/200 [02:11<00:42,  1.15it/s]
training loss: 6.069310188293457 validation loss: 6.249930294192567
 76%|███████▌  | 152/200 [02:11<00:41,  1.15it/s]
training loss: 6.057592391967773 validation loss: 6.248094733880491
 76%|███████▋  | 153/200 [02:12<00:40,  1.15it/s]
training loss: 6.055517673492432 validation loss: 6.245763856537488
 77%|███████▋  | 154/200 [02:13<00:39,  1.15it/s]
training loss: 5.981128692626953 validation loss: 6.2438530921936035
 78%|███████▊  | 155/200 [02:14<00:39,  1.15it/s]
training loss: 6.022617340087891 validation loss: 6.241560556450668
 78%|███████▊  | 156/200 [02:15<00:38,  1.15it/s]
training loss: 6.086309432983398 validation loss: 6.237579958779471
 78%|███████▊  | 157/200 [02:16<00:37,  1.16it/s]
training loss: 6.069828510284424 validation loss: 6.234373160770962
 79%|███████▉  | 158/200 [02:17<00:36,  1.16it/s]
training loss: 5.977034568786621 validation loss: 6.233016325502979
 80%|███████▉  | 159/200 [02:18<00:35,  1.16it/s]
training loss: 6.02243185043335 validation loss: 6.231708945060263
 80%|████████  | 160/200 [02:18<00:34,  1.16it/s]
training loss: 6.022791862487793 validation loss: 6.228585700599515
 80%|████████  | 161/200 [02:19<00:33,  1.16it/s]
training loss: 5.982759475708008 validation loss: 6.226058765333526
 81%|████████  | 162/200 [02:20<00:32,  1.16it/s]
training loss: 6.001495361328125 validation loss: 6.22391228773156
 82%|████████▏ | 163/200 [02:21<00:31,  1.16it/s]
training loss: 6.003325462341309 validation loss: 6.221862325862962
 82%|████████▏ | 164/200 [02:22<00:30,  1.16it/s]
training loss: 6.012252330780029 validation loss: 6.221921395282356
 82%|████████▎ | 165/200 [02:23<00:30,  1.16it/s]
training loss: 5.909038543701172 validation loss: 6.222556221241853
 83%|████████▎ | 166/200 [02:24<00:29,  1.16it/s]
training loss: 5.966092109680176 validation loss: 6.22004255956533
 84%|████████▎ | 167/200 [02:24<00:28,  1.16it/s]
training loss: 5.964709758758545 validation loss: 6.217550267978591
 84%|████████▍ | 168/200 [02:25<00:27,  1.17it/s]
training loss: 5.990265369415283 validation loss: 6.2139473642621725
 84%|████████▍ | 169/200 [02:26<00:26,  1.16it/s]
training loss: 5.978246212005615 validation loss: 6.212402859512641
 85%|████████▌ | 170/200 [02:27<00:25,  1.16it/s]
training loss: 5.977949142456055 validation loss: 6.213127720112703
 86%|████████▌ | 171/200 [02:28<00:24,  1.16it/s]
training loss: 6.002535820007324 validation loss: 6.210467143934601
 86%|████████▌ | 172/200 [02:29<00:24,  1.17it/s]
training loss: 6.001431941986084 validation loss: 6.210143994311897
 86%|████████▋ | 173/200 [02:30<00:23,  1.16it/s]
training loss: 6.107723236083984 validation loss: 6.207548559928427
 87%|████████▋ | 174/200 [02:30<00:22,  1.16it/s]
training loss: 5.944747447967529 validation loss: 6.205113342830113
 88%|████████▊ | 175/200 [02:31<00:21,  1.16it/s]
training loss: 5.999008655548096 validation loss: 6.20351285350566
 88%|████████▊ | 176/200 [02:32<00:20,  1.16it/s]
training loss: 5.943323135375977 validation loss: 6.200359597498057
 88%|████████▊ | 177/200 [02:33<00:19,  1.16it/s]
training loss: 5.962564468383789 validation loss: 6.19809255794603
 89%|████████▉ | 178/200 [02:34<00:18,  1.16it/s]
training loss: 5.994581699371338 validation loss: 6.1956459064872895
 90%|████████▉ | 179/200 [02:35<00:18,  1.16it/s]
training loss: 6.089620113372803 validation loss: 6.194755612587442
 90%|█████████ | 180/200 [02:36<00:17,  1.16it/s]
training loss: 5.90007209777832 validation loss: 6.194045397700096
 90%|█████████ | 181/200 [02:36<00:16,  1.16it/s]
training loss: 5.999163627624512 validation loss: 6.1943637789512165
 91%|█████████ | 182/200 [02:37<00:15,  1.16it/s]
training loss: 6.013882637023926 validation loss: 6.192244325365339
 92%|█████████▏| 183/200 [02:38<00:14,  1.16it/s]
training loss: 5.918767929077148 validation loss: 6.190416705851653
 92%|█████████▏| 184/200 [02:39<00:13,  1.17it/s]
training loss: 5.998378753662109 validation loss: 6.186842869739143
 92%|█████████▎| 185/200 [02:40<00:12,  1.17it/s]
training loss: 6.001276016235352 validation loss: 6.184115536358892
 93%|█████████▎| 186/200 [02:41<00:12,  1.16it/s]
training loss: 5.9881744384765625 validation loss: 6.182604663226069
 94%|█████████▎| 187/200 [02:42<00:11,  1.16it/s]
training loss: 6.0272016525268555 validation loss: 6.180895309058988
 94%|█████████▍| 188/200 [02:42<00:10,  1.16it/s]
training loss: 6.033154487609863 validation loss: 6.18033537572744
 94%|█████████▍| 189/200 [02:43<00:09,  1.15it/s]
training loss: 6.017482280731201 validation loss: 6.178543343835948
 95%|█████████▌| 190/200 [02:44<00:08,  1.16it/s]
training loss: 5.964412212371826 validation loss: 6.1761249425459885
 96%|█████████▌| 191/200 [02:45<00:07,  1.16it/s]
training loss: 6.065142631530762 validation loss: 6.174429805911317
 96%|█████████▌| 192/200 [02:46<00:06,  1.16it/s]
training loss: 5.94199275970459 validation loss: 6.173588704089729
 96%|█████████▋| 193/200 [02:47<00:06,  1.16it/s]
training loss: 5.937205791473389 validation loss: 6.170328422468536
 97%|█████████▋| 194/200 [02:48<00:05,  1.16it/s]
training loss: 5.961018085479736 validation loss: 6.168160039551404
 98%|█████████▊| 195/200 [02:49<00:04,  1.16it/s]
training loss: 5.929530143737793 validation loss: 6.168551522858289
 98%|█████████▊| 196/200 [02:49<00:03,  1.16it/s]
training loss: 5.922443389892578 validation loss: 6.167792475953394
 98%|█████████▊| 197/200 [02:50<00:02,  1.16it/s]
training loss: 5.891782283782959 validation loss: 6.164439503027468
 99%|█████████▉| 198/200 [02:51<00:01,  1.16it/s]
training loss: 5.946990966796875 validation loss: 6.16373381322744
100%|█████████▉| 199/200 [02:52<00:00,  1.16it/s]
training loss: 5.919736862182617 validation loss: 6.162640055831598
100%|██████████| 200/200 [02:53<00:00,  1.15it/s]
training loss: 5.924776077270508 validation loss: 6.159777845655169

In [49]:
decode(model.generate(encode("whether tis nobler").unsqueeze(0), max_new_tokens=20)[0])
Out[49]:
'whether tis nobler king our affection amongst our end eyebrows sends oblivion dead and let be at care foul my lord i is'
In [ ]: