In [1]:
import plotly.io as pio; pio.renderers.default = "notebook_connected"

Lecture 22 – CS 189, Fall 2025

This is a copy of lecture 21 with some minor modifications.

In [2]:
#!pip install -U plotly
In [3]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pickle
In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# device = "cuda" # Colab
device = "mps" # Mac with M1/M2
# device = "cpu" # Local CPU
In [5]:
# import plotly.io as pio
# pio.renderers.default = "vscode" # VSCode
# pio.renderers.default = "colab" # Colab support

Sinusoidal Embeddings

In [6]:
D = 6
n = 16
L = 1000
torch.arange(0, D, 2, dtype=torch.float)
Out[6]:
tensor([0., 2., 4.])
In [7]:
def sinusoidal_pe(N, D, L=1000):
    pe = torch.zeros(N, D)
    div_term = L ** (2 * torch.arange(0, D, 2, dtype=torch.float) / D)
    pe[:, 0::2] = torch.sin(torch.arange(N, dtype=torch.float).unsqueeze(1) / div_term)
    pe[:, 1::2] = torch.cos(torch.arange(N, dtype=torch.float).unsqueeze(1) / div_term)
    return pe
In [8]:
n = 64; D = 128; L = 10
pe = sinusoidal_pe(n, D, L)
px.imshow(pe.cpu().numpy().T, aspect='auto', color_continuous_scale='RdBu_r',
          width = 1100, height=500)
In [9]:
dist = pe @ pe.T
fig = px.imshow(dist.cpu().numpy(), color_continuous_scale='Viridis',
                width=700, height=700,
                title='Dot Product of Positional Encodings')
fig.show()
px.line(x=np.arange(0, n), y=dist[10].cpu().numpy(), width=800, height=400,
        title='Dot Product of Positional Encodings for Position 200')

Tokenization

In [10]:
#!pip install transformers
In [11]:
from transformers import AutoTokenizer

# Load the Qwen tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")
tokenizer.vocab_size
Out[11]:
151643
In [12]:
tokenizer.encode("Hello, how are you?")
Out[12]:
[9707, 11, 1246, 525, 498, 30]

Byte Pair Encoding

The Byte Pair Encoding is fairly simple to implement. We start by splitting each word into its characters, appending a special end-of-word symbol </w> to the end of each word. Then, we repeatedly find the most common adjacent pair of symbols across all words and merge them into a new symbol. This process is repeated for a specified number of merges.

Once we have learned the merges, we can use them to tokenize new words by applying the merges in order until no more merges can be applied.

In [13]:
from typing import Dict, List, Tuple
from collections import Counter

EOW = "</w>"
In [14]:
def build_initial_vocab(corpus: str) -> Dict[str, int]:
    """
    Create an initial vocabulary consisting of all single characters
    observed in the corpus PLUS the EOW marker. IDs are assigned
    deterministically (sorted).

    Returns:
        vocab: dict mapping token string -> integer id
    """
    import string
    ALL_ASCII = set(string.ascii_letters + string.digits + string.punctuation)
    ALL_ASCII.add(EOW)  # include the end-of-word symbol
    corpus_set = set(corpus) - set(' \n\t\r')  # exclude whitespace

    return {tok: i for i, tok in enumerate(corpus_set | ALL_ASCII)}

vocab = build_initial_vocab("CS189 is CS.")
print(vocab)
{'R': 0, '?': 1, 'w': 2, '<': 3, 'N': 4, ')': 5, 'v': 6, 'h': 7, '=': 8, '9': 9, 'K': 10, ',': 11, 'G': 12, 'D': 13, 'q': 14, '"': 15, '4': 16, 'Q': 17, 'z': 18, 'd': 19, 'I': 20, 'J': 21, 'n': 22, '0': 23, 'Y': 24, 'E': 25, 'b': 26, ']': 27, 'L': 28, '`': 29, 'F': 30, '+': 31, 'r': 32, 'Z': 33, '.': 34, 'i': 35, 'x': 36, '3': 37, '8': 38, 'e': 39, 'O': 40, 'X': 41, '_': 42, 'y': 43, 'f': 44, '1': 45, 'p': 46, 'B': 47, '$': 48, 'C': 49, '!': 50, '2': 51, ':': 52, '>': 53, '[': 54, 'c': 55, '%': 56, 'm': 57, "'": 58, 'U': 59, '~': 60, '^': 61, 's': 62, '5': 63, '(': 64, 'P': 65, 't': 66, '\\': 67, 'W': 68, 'j': 69, '*': 70, '&': 71, 'k': 72, '|': 73, 'S': 74, 'l': 75, 'V': 76, '#': 77, 'o': 78, '}': 79, '@': 80, '-': 81, 'H': 82, 'a': 83, ';': 84, '7': 85, 'g': 86, '</w>': 87, 'T': 88, 'u': 89, '6': 90, '/': 91, '{': 92, 'M': 93, 'A': 94}
In [15]:
def corpus_to_char_seq_with_eow(corpus: str) -> List[str]:
    """
    Convert the whole corpus into a single flat sequence of symbols.
    Each word contributes its characters followed by the end-of-word marker `</w>`.
    
    Example:
        corpus = "low lower"
        returns: ['l','o','w','</w>','l','o','w','e','r','</w>']
    
    Why this matters:
    - We treat the corpus as *one long list* (not a list of per-word lists),
      which is sometimes more convenient for teaching and for demonstrating
      the role of `</w>` in preventing merges across word boundaries.
    """
    seq: List[str] = []
    for word in corpus.split():
        seq.extend(list(word))
        seq.append(EOW)
    return seq

print(corpus_to_char_seq_with_eow("CS189 is great!"))
['C', 'S', '1', '8', '9', '</w>', 'i', 's', '</w>', 'g', 'r', 'e', 'a', 't', '!', '</w>']
In [16]:
def count_pair_frequencies(seq: List[str]) -> Counter:
    """
    Count frequencies of adjacent symbol pairs over the *flat* sequence.
    
    Boundary rule:
    - We *disallow* pairs that START with `</w>` because that would cross a 
      word boundary on merge (i.e., merging `</w>` with the next word's first
      character). We DO allow pairs that END with `</w>` (e.g., ('w','</w>')),
      which forms tokens like 'w</w>' and is standard in BPE.
    
    Returns:
        A Counter mapping (left_symbol, right_symbol) -> count.
    """
    pair_counts = Counter()
    for i in range(len(seq) - 1):
        left, right = seq[i], seq[i + 1]
        if left.endswith(EOW): # This pair would cross a word boundary; skip it.
            continue
        pair_counts[(left, right)] += 1
    return pair_counts

corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
print(pair_freqs)
Counter({('C', 'S'): 2, ('S', '1'): 1, ('1', '8'): 1, ('8', '9'): 1, ('9', '</w>'): 1, ('i', 's'): 1, ('s', '</w>'): 1, ('S', '.'): 1, ('.', '</w>'): 1})
In [17]:
def merge_pair_in_sequence(seq: List[str], pair: Tuple[str, str]) -> List[str]:
    """
    Perform a single merge of the given pair across the flat sequence.
    Invariant:
    - Never merge if the left symbol ends with `</w>` 
       (prevents crossing word boundaries).
    - Scans left-to-right and uses a simple skip mechanic to avoid overlapping merges.
    """
    a, b = pair
    merged_token = a + b
    new_seq: List[str] = []
    i = 0
    n = len(seq)
    while i < n:
        if i < n - 1 and seq[i] == a and seq[i + 1] == b and seq[i] != EOW:
            new_seq.append(merged_token)
            i += 2  # skip the merged pair
        else:
            new_seq.append(seq[i])
            i += 1
    return new_seq

corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
pair, freq = pair_freqs.most_common(1)[0]
print("Merging pair:", pair, "with frequency:", freq)
new_seq = merge_pair_in_sequence(seq, pair)
print(new_seq)
Merging pair: ('C', 'S') with frequency: 2
['CS', '1', '8', '9', '</w>', 'i', 's', '</w>', 'CS', '.', '</w>']
In [18]:
from tqdm import tqdm  # pip install tqdm

def learn_bpe_merges(corpus: str, num_merges: int = 1000, 
                     min_frequency: int = 2) -> Tuple[List[Tuple[str, str]], dict]:
    """
    Learn BPE merge rules from the corpus by repeatedly finding the most frequent
    adjacent pair and merging it, subject to the boundary rule.

    Args:
        corpus: Raw text (spaces separate words).
        num_merges: Maximum number of merges to learn.
        min_frequency: Stop when the most frequent pair occurs fewer than this.

    Returns:
        merges: A list of (left_symbol, right_symbol) in the order they were learned.
        vocab: Final vocabulary mapping token -> id
    """
    seq = corpus_to_char_seq_with_eow(corpus)
    merges: List[Tuple[str, str]] = []
    vocab = build_initial_vocab(corpus)
    next_id = max(vocab.values()) + 1

    # Wrap the merge loop in a tqdm progress bar
    progress = tqdm(range(num_merges), desc="Learning BPE merges", ncols=80)

    for step in progress:
        pair_counts = count_pair_frequencies(seq)
        if not pair_counts:
            progress.set_postfix_str("done (no pairs left)")
            break
        (best_pair, best_count) = pair_counts.most_common(1)[0]
        if best_count < min_frequency:
            progress.set_postfix_str(f"stopped (min freq < {min_frequency})")
            break

        # Merge and update structures
        seq = merge_pair_in_sequence(seq, best_pair)
        merges.append(best_pair)
        new_token = best_pair[0] + best_pair[1]
        if new_token not in vocab:
            vocab[new_token] = next_id
            next_id += 1

        # Update the tqdm progress bar info
        progress.set_postfix_str(f"merge {best_pair} ({best_count})")

    progress.close()
    return merges, vocab

corpus = "This is the best CS class. This is CS 189."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
print("Learned merges:", merges)
print("Vocabulary:", vocab)
Learning BPE merges:   0%|                              | 0/100 [00:00<?, ?it/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('i', 's') (4)]
Learning BPE merges:   0%|    | 0/100 [00:00<?, ?it/s, merge ('is', '</w>') (4)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('T', 'h') (2)]
Learning BPE merges:   0%|  | 0/100 [00:00<?, ?it/s, merge ('Th', 'is</w>') (2)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('C', 'S') (2)]
Learning BPE merges:   0%|    | 0/100 [00:00<?, ?it/s, merge ('CS', '</w>') (2)]
Learning BPE merges:   0%|     | 0/100 [00:00<?, ?it/s, merge ('.', '</w>') (2)]
Learning BPE merges:   0%|      | 0/100 [00:00<?, ?it/s, stopped (min freq < 2)]
Learning BPE merges:   7%| | 7/100 [00:00<00:00, 2457.94it/s, stopped (min freq 
Learned merges: [('i', 's'), ('is', '</w>'), ('T', 'h'), ('Th', 'is</w>'), ('C', 'S'), ('CS', '</w>'), ('.', '</w>')]
Vocabulary: {'R': 0, '?': 1, 'w': 2, '<': 3, 'N': 4, ')': 5, 'v': 6, 'h': 7, '=': 8, '9': 9, 'K': 10, ',': 11, 'G': 12, 'D': 13, 'q': 14, '"': 15, '4': 16, 'Q': 17, 'z': 18, 'd': 19, 'I': 20, 'J': 21, 'n': 22, '0': 23, 'Y': 24, 'E': 25, 'b': 26, ']': 27, 'L': 28, '`': 29, 'F': 30, '+': 31, 'r': 32, 'Z': 33, '.': 34, 'i': 35, 'x': 36, '3': 37, '8': 38, 'e': 39, 'O': 40, 'X': 41, '_': 42, 'y': 43, 'f': 44, '1': 45, 'p': 46, 'B': 47, '$': 48, 'C': 49, '!': 50, '2': 51, ':': 52, '>': 53, '[': 54, 'c': 55, '%': 56, 'm': 57, "'": 58, 'U': 59, '~': 60, '^': 61, 's': 62, '5': 63, '(': 64, 'P': 65, 't': 66, '\\': 67, 'W': 68, 'j': 69, '*': 70, '&': 71, 'k': 72, 'l': 73, 'S': 74, '|': 75, 'V': 76, '#': 77, 'o': 78, '}': 79, '@': 80, '-': 81, 'H': 82, 'a': 83, ';': 84, '7': 85, 'g': 86, '</w>': 87, 'T': 88, 'u': 89, '6': 90, '/': 91, '{': 92, 'M': 93, 'A': 94, 'is': 95, 'is</w>': 96, 'Th': 97, 'This</w>': 98, 'CS': 99, 'CS</w>': 100, '.</w>': 101}

In [19]:
def bpe_encode(text: str, merges: List[Tuple[str, str]], vocab: Dict[str, int]) -> List[int]:
    """
    Encode a string into token IDs:
      1) Convert text -> flat char+EOW sequence
      2) Apply learned merges in order
      3) Map final tokens to IDs via vocab

    Note:
    - This simple teaching encoder applies merges globally; it assumes the
      learned merges were derived from a similar distribution (your corpus).
    - For speed, production systems use a 'rank' map and greedy longest-match;
      here we stick to the clearest didactic approach.
    """
    progress = tqdm(range(len(merges)), desc="Applying BPE merges", ncols=80)
    seq = corpus_to_char_seq_with_eow(text)
    for a, b in merges:
        seq = merge_pair_in_sequence(seq, (a, b))
        progress.update(1)
    progress.close()
    return seq, [vocab[tok] for tok in seq]

corpus = "This is the best CS class. This is CS 189 the best class."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
encoded_seq, token_ids = bpe_encode("CS 189 is the   best   class.", merges, vocab)
print("Encoded sequence:", encoded_seq)
print("Token IDs:", token_ids)
Learning BPE merges:   0%|                              | 0/100 [00:00<?, ?it/s]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('i', 's') (4)]
Learning BPE merges:   0%|    | 0/100 [00:00<?, ?it/s, merge ('is', '</w>') (4)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('T', 'h') (2)]
Learning BPE merges:   0%|  | 0/100 [00:00<?, ?it/s, merge ('Th', 'is</w>') (2)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('t', 'h') (2)]
Learning BPE merges:   0%|       | 0/100 [00:00<?, ?it/s, merge ('th', 'e') (2)]
Learning BPE merges:   0%|   | 0/100 [00:00<?, ?it/s, merge ('the', '</w>') (2)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('b', 'e') (2)]
Learning BPE merges:   0%|       | 0/100 [00:00<?, ?it/s, merge ('be', 's') (2)]
Learning BPE merges:   0%|      | 0/100 [00:00<?, ?it/s, merge ('bes', 't') (2)]
Learning BPE merges:   0%|  | 0/100 [00:00<?, ?it/s, merge ('best', '</w>') (2)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('C', 'S') (2)]
Learning BPE merges:   0%|    | 0/100 [00:00<?, ?it/s, merge ('CS', '</w>') (2)]
Learning BPE merges:   0%|        | 0/100 [00:00<?, ?it/s, merge ('c', 'l') (2)]
Learning BPE merges:   0%|       | 0/100 [00:00<?, ?it/s, merge ('cl', 'a') (2)]
Learning BPE merges:   0%|      | 0/100 [00:00<?, ?it/s, merge ('cla', 's') (2)]
Learning BPE merges:   0%|     | 0/100 [00:00<?, ?it/s, merge ('clas', 's') (2)]
Learning BPE merges:   0%|    | 0/100 [00:00<?, ?it/s, merge ('class', '.') (2)]
Learning BPE merges:   0%| | 0/100 [00:00<?, ?it/s, merge ('class.', '</w>') (2)
Learning BPE merges:   0%|      | 0/100 [00:00<?, ?it/s, stopped (min freq < 2)]
Learning BPE merges:  19%|▏| 19/100 [00:00<00:00, 3921.45it/s, stopped (min freq

Applying BPE merges:   0%|                               | 0/19 [00:00<?, ?it/s]
Applying BPE merges: 100%|██████████████████| 19/19 [00:00<00:00, 332049.07it/s]
Encoded sequence: ['CS</w>', '1', '8', '9', '</w>', 'is</w>', 'the</w>', 'best</w>', 'class.</w>']
Token IDs: [107, 45, 38, 9, 87, 96, 101, 105, 113]

In [20]:
def bpe_decode(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """
    Decode token IDs back to text by inverting the vocab and then
    removing EOW markers to re-insert spaces.

    Rules:
    - Tokens that END with EOW represent end-of-word units.
      We strip the trailing `</w>` and insert a space.
    - Other tokens are just literal substrings inside a word.

    Caveat:
    - Because we concatenated strings to form merged tokens, decoding simply
      concatenates their surfaces; then we rely on `</w>` to restore spaces.
    """
    inv_vocab = {i: t.replace(EOW, " ") for t, i in vocab.items()}
    out_words: List[str] = []
    buf = [inv_vocab[tid] for tid in token_ids]
    return "".join(buf).strip()
In [21]:
decoded_text = bpe_decode(token_ids, vocab)
print(f"Decoded text: \"{decoded_text}\"")
Decoded text: "CS 189 is the best class."

Implementing the Decoder Transformer for Generative Pre-training

Getting the Data

In [22]:
def fetch_and_cache_corpus(url: str, filename: str) -> str:
    import os
    if not os.path.exists(filename):
        print('downloading corpus...')
        import requests
        response = requests.get(url)
        corpus = response.text
        with open(filename, "w") as f:
            f.write(corpus)
    else:
        with open(filename, "r") as f:
            corpus = f.read()
    return corpus
In [23]:
corpus = fetch_and_cache_corpus(
    "https://www.gutenberg.org/cache/epub/1342/pg1342.txt",
    "jane_austen.txt"
)

# corpus = fetch_and_cache_corpus(
#     "https://www.gutenberg.org/cache/epub/100/pg100.txt",
#     "shakespear.txt")
    
print(f"Corpus length: {len(corpus)} characters") 
print(corpus[:1000])
downloading corpus...
Corpus length: 763043 characters
The Project Gutenberg eBook of Pride and Prejudice
    
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
of the Project Gutenberg License included with this ebook or online
at www.gutenberg.org. If you are not located in the United States,
you will have to check the laws of the country where you are located
before using this eBook.

Title: Pride and Prejudice

Author: Jane Austen

Release date: June 1, 1998 [eBook #1342]
                Most recently updated: September 22, 2025

Language: English

Credits: Chuck Greif and the Online Distributed Proofreading Team at http://www.pgdp.net (This file was produced from images available at The Internet Archive)


*** START OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***




                            [Illustration:

                           

Byte Pair Encoding

In [24]:
import os
if not os.path.exists("bpe_state.pkl"):
    print('learning BPE merges on corpus...')
    merges, vocab = learn_bpe_merges(corpus, 
                                     num_merges=1000, 
                                     min_frequency=2)
    with open("bpe_state.pkl", "wb") as f:
        pickle.dump((merges, vocab), f)
else:
    print("loading cached BPE state...")
    with open("bpe_state.pkl", "rb") as f:
        merges, vocab = pickle.load(f)
print("Learned merges:", merges)
print("Vocabulary:", vocab)
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)
loading cached BPE state...
Learned merges: [('e', '</w>'), ('t', 'h'), (',', '</w>'), ('.', '</w>'), ('t', '</w>'), ('s', '</w>'), ('d', '</w>'), ('e', 'r'), ('o', 'u'), ('i', 'n'), ('a', 'n'), ('y', '</w>'), ('o', 'r'), ('o', '</w>'), ('e', 'n'), ('a', 'r'), ('o', 'n'), ('l', 'l'), ('h', 'a'), ('th', 'e</w>'), ('f', '</w>'), ('i', 's</w>'), ('e', 's'), ('an', 'd</w>'), ('I', '</w>'), ('ll', '</w>'), ('y', 'ou'), ('er', '</w>'), ('e', 'a'), ('t', 'o</w>'), ('o', 'w'), ('e', ',</w>'), ('o', 'f</w>'), ('in', 'g'), ('w', 'i'), ('r', '</w>'), ('o', 'm'), ('s', 't'), ('th', '</w>'), ('a', '</w>'), ('c', 'h'), ('in', '</w>'), ('v', 'e</w>'), (';', '</w>'), ('or', '</w>'), ('T', 'h'), ('n', 'o'), ('h', 'i'), ('m', 'y</w>'), ('e', 'd</w>'), ('l', 'i'), ('?', '</w>'), ('a', 't</w>'), ('ing', '</w>'), ('th', 'e'), ('e', '.</w>'), ('r', 'i'), ('s', ',</w>'), ('r', 'e'), ('g', 'h'), ('en', '</w>'), ('A', 'n'), ('t', 'i'), ('o', 'o'), ('you', '</w>'), ('e', 'ar'), ('s', 't</w>'), ('s', 'e'), ('s', 'h'), ('d', ',</w>'), ('r', 'a'), ('on', '</w>'), ('m', 'a'), ('m', '</w>'), ('e', 's</w>'), ('’', 's</w>'), ('l', 'a'), ('no', 't</w>'), ('th', 'at</w>'), ('An', 'd</w>'), ('t', ',</w>'), ('l', 'd</w>'), ('ow', '</w>'), ('s', 'i'), ('m', 'e</w>'), ('y', ',</w>'), ('wi', 'th</w>'), ('an', '</w>'), ('u', 'n'), ('d', 'i'), ('!', '</w>'), ('k', '</w>'), ('u', 't</w>'), ('u', 'r'), ('a', 's</w>'), ('ha', 't</w>'), ('ou', '</w>'), ('ch', '</w>'), ('h', 'is</w>'), ('a', 'l'), ('i', 't</w>'), ('e', 'e'), ('k', 'e</w>'), ('b', 'e'), ('b', 'e</w>'), ('r', 'o'), ('er', 'e</w>'), ('s', 'e</w>'), ('you', 'r</w>'), ('c', 'e</w>'), ('f', 'or</w>'), ('i', 't'), ('S', '.</w>'), ('l', 'e'), ('l', 'o'), ('f', 'or'), ('ha', 've</w>'), ('f', 'a'), ('h', 'e</w>'), ('s', '.</w>'), ('a', 't'), ('s', 'a'), ('c', 'om'), ('n', '</w>'), ('th', 'is</w>'), ('ou', 'r</w>'), ('E', 'N'), ('d', '.</w>'), ('b', 'l'), ('r', 'ea'), ('ou', 'ld</w>'), ('w', 'h'), ('Th', 'e</w>'), ('s', 'ha'), ('s', 'p'), ('t', 'er</w>'), ('A', 'N'), ('th', 'ou</w>'), ('W', 'h'), ('i', 'r'), ('wi', 'll</w>'), ('e', 'l'), ('t', 'a'), ('E', 'R'), (']', '</w>'), ('s', 'u'), ('th', 'er'), ('th', 'y</w>'), ('om', '</w>'), ('A', 'R'), ('[', '_'), ('o', 'l'), ('en', 't'), ('ar', 'e</w>'), ('i', 'r</w>'), ('e', 'v'), ('w', 'or'), ('r', 'u'), ('I', 'N'), ('O', '.</w>'), ('a', 'll</w>'), ('’', 'd</w>'), ('b', 'ut</w>'), ('i', 'l'), ('T', 'o</w>'), ('.', '_'), ('U', 'S.</w>'), ('e', 't</w>'), ('k', 'n'), ('t', '.</w>'), ('y', '.</w>'), ('._', ']</w>'), ('h', 'er</w>'), ('oo', 'd</w>'), ('p', '</w>'), (':', '</w>'), ('d', 'a'), ('S', 'T'), ('e', 'll'), ('ear', '</w>'), ('i', 'gh'), ('th', 'er</w>'), ('ou', 'n'), ('hi', 'm</w>'), ('d', 'e'), ('m', 'e'), ('m', 'u'), ('sha', 'll</w>'), ('s', 'o</w>'), ('E', '.</w>'), ('w', 'a'), ('r', 'es'), ('l', 'y</w>'), ('m', 'or'), ('c', 'on'), ('t', 'er'), ('b', 'y</w>'), ('d', 'o'), ('O', 'N'), ('d', 'o</w>')]
Vocabulary: {')': 0, 'A': 1, 'è': 2, '[': 3, 'E': 4, 'M': 5, '@': 6, '}': 7, 'g': 8, '6': 9, '+': 10, 'P': 11, ']': 12, 'a': 13, 'Z': 14, 'ê': 15, 'î': 16, '5': 17, '8': 18, 'C': 19, '~': 20, '{': 21, '_': 22, '9': 23, 'e': 24, 'c': 25, '2': 26, '/': 27, '!': 28, 'I': 29, '7': 30, '1': 31, '0': 32, 'o': 33, 'æ': 34, "'": 35, 'K': 36, '`': 37, 'O': 38, 'G': 39, 'Æ': 40, 'É': 41, '>': 42, '•': 43, '$': 44, 'n': 45, 'h': 46, '—': 47, 'V': 48, 'm': 49, 'u': 50, 'd': 51, '™': 52, 'v': 53, '"': 54, 'Q': 55, ',': 56, 't': 57, 'z': 58, '*': 59, 'p': 60, 'ç': 61, 'ë': 62, '\\': 63, 'S': 64, 'B': 65, '’': 66, '</w>': 67, 'y': 68, 'x': 69, '<': 70, 'r': 71, 'f': 72, 'k': 73, 'L': 74, 'U': 75, '|': 76, 'b': 77, 'à': 78, 's': 79, '‘': 80, 'R': 81, '%': 82, 'œ': 83, 'i': 84, 'l': 85, '=': 86, '”': 87, '(': 88, 'D': 89, 'Y': 90, '\ufeff': 91, '4': 92, 'j': 93, '…': 94, 'H': 95, 'é': 96, 'À': 97, 'X': 98, 'Ç': 99, '-': 100, 'J': 101, 'w': 102, 'â': 103, '^': 104, 'W': 105, 'N': 106, ':': 107, '3': 108, '&': 109, '.': 110, 'q': 111, 'T': 112, '?': 113, '“': 114, '#': 115, ';': 116, 'F': 117, 'e</w>': 118, 'th': 119, ',</w>': 120, '.</w>': 121, 't</w>': 122, 's</w>': 123, 'd</w>': 124, 'er': 125, 'ou': 126, 'in': 127, 'an': 128, 'y</w>': 129, 'or': 130, 'o</w>': 131, 'en': 132, 'ar': 133, 'on': 134, 'll': 135, 'ha': 136, 'the</w>': 137, 'f</w>': 138, 'is</w>': 139, 'es': 140, 'and</w>': 141, 'I</w>': 142, 'll</w>': 143, 'you': 144, 'er</w>': 145, 'ea': 146, 'to</w>': 147, 'ow': 148, 'e,</w>': 149, 'of</w>': 150, 'ing': 151, 'wi': 152, 'r</w>': 153, 'om': 154, 'st': 155, 'th</w>': 156, 'a</w>': 157, 'ch': 158, 'in</w>': 159, 've</w>': 160, ';</w>': 161, 'or</w>': 162, 'Th': 163, 'no': 164, 'hi': 165, 'my</w>': 166, 'ed</w>': 167, 'li': 168, '?</w>': 169, 'at</w>': 170, 'ing</w>': 171, 'the': 172, 'e.</w>': 173, 'ri': 174, 's,</w>': 175, 're': 176, 'gh': 177, 'en</w>': 178, 'An': 179, 'ti': 180, 'oo': 181, 'you</w>': 182, 'ear': 183, 'st</w>': 184, 'se': 185, 'sh': 186, 'd,</w>': 187, 'ra': 188, 'on</w>': 189, 'ma': 190, 'm</w>': 191, 'es</w>': 192, '’s</w>': 193, 'la': 194, 'not</w>': 195, 'that</w>': 196, 'And</w>': 197, 't,</w>': 198, 'ld</w>': 199, 'ow</w>': 200, 'si': 201, 'me</w>': 202, 'y,</w>': 203, 'with</w>': 204, 'an</w>': 205, 'un': 206, 'di': 207, '!</w>': 208, 'k</w>': 209, 'ut</w>': 210, 'ur': 211, 'as</w>': 212, 'hat</w>': 213, 'ou</w>': 214, 'ch</w>': 215, 'his</w>': 216, 'al': 217, 'it</w>': 218, 'ee': 219, 'ke</w>': 220, 'be': 221, 'be</w>': 222, 'ro': 223, 'ere</w>': 224, 'se</w>': 225, 'your</w>': 226, 'ce</w>': 227, 'for</w>': 228, 'it': 229, 'S.</w>': 230, 'le': 231, 'lo': 232, 'for': 233, 'have</w>': 234, 'fa': 235, 'he</w>': 236, 's.</w>': 237, 'at': 238, 'sa': 239, 'com': 240, 'n</w>': 241, 'this</w>': 242, 'our</w>': 243, 'EN': 244, 'd.</w>': 245, 'bl': 246, 'rea': 247, 'ould</w>': 248, 'wh': 249, 'The</w>': 250, 'sha': 251, 'sp': 252, 'ter</w>': 253, 'AN': 254, 'thou</w>': 255, 'Wh': 256, 'ir': 257, 'will</w>': 258, 'el': 259, 'ta': 260, 'ER': 261, ']</w>': 262, 'su': 263, 'ther': 264, 'thy</w>': 265, 'om</w>': 266, 'AR': 267, '[_': 268, 'ol': 269, 'ent': 270, 'are</w>': 271, 'ir</w>': 272, 'ev': 273, 'wor': 274, 'ru': 275, 'IN': 276, 'O.</w>': 277, 'all</w>': 278, '’d</w>': 279, 'but</w>': 280, 'il': 281, 'To</w>': 282, '._': 283, 'US.</w>': 284, 'et</w>': 285, 'kn': 286, 't.</w>': 287, 'y.</w>': 288, '._]</w>': 289, 'her</w>': 290, 'ood</w>': 291, 'p</w>': 292, ':</w>': 293, 'da': 294, 'ST': 295, 'ell': 296, 'ear</w>': 297, 'igh': 298, 'ther</w>': 299, 'oun': 300, 'him</w>': 301, 'de': 302, 'me': 303, 'mu': 304, 'shall</w>': 305, 'so</w>': 306, 'E.</w>': 307, 'wa': 308, 'res': 309, 'ly</w>': 310, 'mor': 311, 'con': 312, 'ter': 313, 'by</w>': 314, 'do': 315, 'ON': 316, 'do</w>': 317}
Vocabulary size: 318
In [25]:
if not os.path.exists("encoded_text_ids.pkl"):
    print('encoding corpus...')
    encoded_seq, token_ids = bpe_encode(corpus, merges, vocab)
    with open("encoded_text_ids.pkl", "wb") as f:
        pickle.dump((encoded_seq,token_ids), f)
else:
    print("loading cached encoded text...")
    with open("encoded_text_ids.pkl", "rb") as f:
        encoded_seq, token_ids = pickle.load(f)
print("Encoded sequence length:", len(encoded_seq))
corpus_tokens = torch.tensor(token_ids, dtype=torch.long, device=device)

def encode(text: str) -> torch.Tensor:
    _, token_ids = bpe_encode(text, merges, vocab)
    return torch.tensor(token_ids, dtype=torch.long, device=device)

def decode(token_ids: torch.Tensor) -> str:
    return bpe_decode(token_ids.tolist(), vocab)

tok = encode("Know your own")
decode(tok)
loading cached encoded text...
Encoded sequence length: 2758830
Applying BPE merges:   0%|                              | 0/200 [00:00<?, ?it/s]
Applying BPE merges: 100%|████████████████| 200/200 [00:00<00:00, 913791.72it/s]

Out[25]:
'Know your own'

Word Encoding

In [26]:
import re
from collections import Counter

def split_words(corpus: str) -> List[str]:
    """
    Break the corpus into words using a regex that matches word characters.
    """
    pattern = r'\b\w+\b'
    corpus = re.sub(r'[._,!?;"\'`()\[\]{}<>]', '', corpus.lower())
    return re.findall(pattern, corpus.lower())

words = split_words(corpus)
# counter = Counter(words)
# vocab_set = {tok for tok, cnt in counter.items() if cnt > 1}
vocab_set = set(words)
vocab = {word: i for i, word in enumerate(sorted(vocab_set), start = 1)}
vocab["<unknown>"] = 1
inv_vocab = {i: word for word, i in vocab.items()}
vocab_size = len(vocab)
print("Vocabulary size:", vocab_size)


def encode(text: str):
    """
    Encode a string into token IDs using the provided vocabulary.
    Unknown words are mapped to the ID for <unknown>.
    """
    words = split_words(text)
    return torch.tensor(
        [vocab.get(word, 1) for word in words],
        dtype=torch.long, device=device)

def decode(tokens: torch.Tensor) -> str:
    """
    Decode token IDs back to text by inverting the vocabulary.
    """
    words = [inv_vocab.get(t.item(), "<error>") for t in tokens]
    return " ".join(words)

corpus_tokens = encode(corpus)
print("Length of token IDs:", len(corpus_tokens))
decode(encode("to be or not to be that is the question tokenizer"))
Vocabulary size: 7123
Length of token IDs: 131762
Out[26]:
'to be or not to be that is the question <unknown>'

Data Preparation

In [27]:
N = corpus_tokens.shape[0]
seq_length = 64
split_ratio = 0.90
seed = 189
x = corpus_tokens[:(N - (N % (seq_length + 1)))]
x = x.reshape(-1, seq_length + 1)
y = x[:, 1:]
x = x[:, :-1]


from torch.utils.data import random_split, TensorDataset
dataset = TensorDataset(x, y)
generator = torch.Generator().manual_seed(seed)
training_data, validation_data = random_split(
    dataset, [split_ratio, 1 - split_ratio], 
    generator=generator) 
print("training contexts", len(training_data))
print("validation contexts", len(validation_data))
training_data[0]
training contexts 1825
validation contexts 202
Out[27]:
(tensor([3849,  765,  435, 4427, 3132, 4614, 3087,  579, 3132, 3252, 5759, 7041,
         2047, 3157, 6340, 3176, 4400, 3578, 3157, 6414, 3132, 2806, 6368, 6874,
          146, 6104, 4400, 1148, 2714, 6934, 5759, 6874, 5112, 6665,  435, 5759,
         1535, 3053, 6209,  146, 5886,  616, 3164,  767, 4332, 5666, 6340,  228,
         4400, 5921, 4400, 6374, 6777, 4638,  335, 6951, 3164, 4919, 3023, 5478,
         3337, 3164, 4408, 6414], device='mps:0'),
 tensor([ 765,  435, 4427, 3132, 4614, 3087,  579, 3132, 3252, 5759, 7041, 2047,
         3157, 6340, 3176, 4400, 3578, 3157, 6414, 3132, 2806, 6368, 6874,  146,
         6104, 4400, 1148, 2714, 6934, 5759, 6874, 5112, 6665,  435, 5759, 1535,
         3053, 6209,  146, 5886,  616, 3164,  767, 4332, 5666, 6340,  228, 4400,
         5921, 4400, 6374, 6777, 4638,  335, 6951, 3164, 4919, 3023, 5478, 3337,
         3164, 4408, 6414, 3144], device='mps:0'))
In [28]:
print(decode(training_data[1000][0]))
print(decode(training_data[1000][1]))
always down for the summer months except thought elizabeth when she goes to ramsgate if your master would marry you might see more of him yes sir but i do not know when that will be i do not know who is good enough for him mr and mrs gardiner smiled elizabeth could not help saying it is very much to his credit i
down for the summer months except thought elizabeth when she goes to ramsgate if your master would marry you might see more of him yes sir but i do not know when that will be i do not know who is good enough for him mr and mrs gardiner smiled elizabeth could not help saying it is very much to his credit i am

Attention

In [29]:
def scaled_dot_product_attention(Q: torch.Tensor, 
                                 K: torch.Tensor, 
                                 V: torch.Tensor, 
                                 mask=None):
    """
    Q: matrix of shape (B, N, d_k)
    K: matrix of shape (B, N, d_k)
    V: matrix of shape (B, N, d_v)
    mask: boolean matrix of shape (1, N, N). Values where mask is True will be INCLUDED
    """
    (B, N, d_k) = Q.shape
    (_, _, d_v) = V.shape
    K_T = K.transpose(-2, -1) # (B, d_k, N)
    dot_product = (Q @ K_T) / (d_k ** 0.5) # (B, N, N)
    
    if mask is not None:
        dot_product = dot_product.masked_fill(mask.logical_not(), float('-inf'))

    attention = F.softmax(dot_product, dim=-1) # (B, N, N)
    return attention @ V  # (B, N, N) * (B, N, d_v) = (B, N, d_v)
In [30]:
_mask_cache = {}
def get_mask_with_cache(N, device):
    """
    Returns a lower triangular mask of shape (1, N, N) to be used for masked attention.
    """
    if N not in _mask_cache:
        _mask_cache[N] = torch.ones(
            (N, N), dtype=torch.bool, 
            device=device).tril().unsqueeze(0)  
    return _mask_cache[N] #  (1, N, N)

class MaskedAttentionHead(nn.Module):
    def __init__(self, d_model=512, d_v=512, d_k=64):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v

        self.W_k = nn.Linear(self.d_model, self.d_k, bias = False)
        self.W_q = nn.Linear(self.d_model, self.d_k, bias = False)
        self.W_v = nn.Linear(self.d_model, self.d_v, bias = False)


    def forward(self, x):
        """
        x is the input to use for the queries, keys, and values
        encoder_output is the output from the encoder (used for cross-attention)
        mask is the mask to use for the attention
        """
        (B, N, _) = x.shape
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        mask = get_mask_with_cache(N, device=x.device)
        values = scaled_dot_product_attention(Q, K, V, mask=mask) 
        # ##  more efficient implementation:
        # ##  Need to unsqueeze the head dimension for F.scaled_dot_product_attention
        # values = F.scaled_dot_product_attention(
        #     query=Q, key=K, value=V, 
        #     attn_mask=mask)
        return values
In [31]:
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, num_heads=8, d_model=512, d_k=64):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_model // num_heads
        self.attention_heads = nn.ModuleList(
            [
                MaskedAttentionHead(d_model=self.d_model, d_v = self.d_v, d_k=self.d_k)
                for _ in range(self.num_heads)
            ]
        )
        # Projection
        self.W_out = nn.Linear(self.num_heads * self.d_v, self.d_model) 
        

    def forward(self, x):
        (B, N, _) = x.shape
        head_outputs = [head(x) for head in self.attention_heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        out = self.W_out(concatenated)
        return out

Testing the Layer with Masked Multi-Head Attention:

In [32]:
emb = nn.Embedding(vocab_size, 512).to(device)
layer = MaskedMultiHeadAttention(7, 512, 64).to(device)
In [33]:
batch_size = 7
x, y = training_data[:batch_size]
emb(x).shape
Out[33]:
torch.Size([7, 64, 512])
In [34]:
layer(emb(x)).shape
Out[34]:
torch.Size([7, 64, 512])

Decoder Architecture

In [35]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_k=64, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ffn = 4 * self.d_model
        self.d_k = d_k

        self.dropout = nn.Dropout(dropout)

        self.mh_attention = MaskedMultiHeadAttention(
            num_heads=self.num_heads, 
            d_model=self.d_model,
            d_k=self.d_k)
        self.ffn = nn.Sequential(
            nn.Linear(self.d_model, self.d_ffn),
            nn.ReLU(),
            self.dropout,
            nn.Linear(self.d_ffn, self.d_model),
        )
        self.layernorm1 = nn.LayerNorm(self.d_model)
        self.layernorm2 = nn.LayerNorm(self.d_model)
        
        
    
    def forward(self, x):
        mha = self.mh_attention(self.layernorm1(x)) 
        mha = self.dropout(mha)
        x = x + mha
        ffn = self.ffn(self.layernorm2(x)) 
        ffn = self.dropout(ffn)
        x = x + ffn
        return x
In [36]:
block = DecoderBlock(d_model=512, num_heads=8, d_k=64, dropout=0.1).to(device)
block(emb(x)).shape
Out[36]:
torch.Size([7, 64, 512])
In [37]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 512, max_len: int = 1024, L: float = 10000.0):
        """
        Sinusoidal positional encoding as in 'Attention is All You Need'.
        """
        super().__init__()

        pos = torch.zeros(max_len, d_model, dtype=torch.float32)
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_terms = L ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        quotient = positions / div_terms
        pos[:, 0::2] = torch.sin(quotient)  # even indices
        pos[:, 1::2] = torch.cos(quotient)  # odd indices
        # Register as non-parameter buffer so it moves with the module
        self.register_buffer("pos", pos)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        # pos: (1, seq_len, d_model) broadcasts along batch dimension
        return x + self.pos[:seq_len].unsqueeze(0)
In [38]:
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 512, max_len: int = 1024):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)
        nn.init.normal_(self.pe.weight, mean=0.0, std=1/(d_model ** 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (B, N, D) = x.shape
        positions = torch.arange(N, device=x.device).unsqueeze(0)
        return x + self.pe(positions)
In [39]:
class TransformerDecoderOnly(nn.Module):
    def __init__(self, 
                 max_length=1024,
                 vocab_size=6000, 
                 d_model=512,
                 d_k=64,
                 num_layers=6, 
                 num_heads=8, 
                 dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_k = d_k
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        # embedding model initialized with small values
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=1.0 / (d_model**0.5))

        self.layers = nn.Sequential()
        self.layers.append(self.embedding)
        self.layers.append(
            # PositionalEncoding(d_model=d_model, max_len=max_length, L=10000)
            LearnedPositionalEncoding(d_model=d_model, max_len=max_length)
        )
        for _ in range(num_layers):
            self.layers.append(
                DecoderBlock(d_model=d_model, num_heads=num_heads, 
                             d_k=self.d_k, dropout=self.dropout)
            )
        self.layers.append(nn.LayerNorm(d_model))
        
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters())
    
    def forward(self, x):
        hidden = self.layers(x)
        return hidden @ self.embedding.weight.T 
    
    def generate(self, x, max_new_tokens):
        """
        x: (B, N) tensor of input token IDs
        max_new_tokens: number of tokens to generate
        """
        self.eval()
        with torch.no_grad():
            B, N = x.shape
            for _ in range(max_new_tokens):
                x = x[:, -seq_length:]  # crop to last seq_length tokens
                logits = self.forward(x)  # (B, N, vocab_size)
                next_token_logits = logits[:, -1, :]  # (B, vocab_size)
                next_token_probs = F.softmax(next_token_logits, dim=-1)  # (B, vocab_size)
                next_tokens = torch.multinomial(next_token_probs, num_samples=1)  # (B, 1)
                x = torch.cat([x, next_tokens], dim=1)  # (B, N+1)
        return x
In [40]:
model = TransformerDecoderOnly(
    max_length=seq_length, 
    vocab_size=vocab_size, 
    d_model=256, d_k=16, num_layers=6, num_heads=8, 
    dropout=0.1).to(device)

print(model)
print("Number of parameters:", model.num_parameters()/1e6, "million")
TransformerDecoderOnly(
  (embedding): Embedding(7123, 256)
  (layers): Sequential(
    (0): Embedding(7123, 256)
    (1): LearnedPositionalEncoding(
      (pe): Embedding(64, 256)
    )
    (2): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (3): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (4): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (5): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (6): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (7): DecoderBlock(
      (dropout): Dropout(p=0.1, inplace=False)
      (mh_attention): MaskedMultiHeadAttention(
        (attention_heads): ModuleList(
          (0-7): 8 x MaskedAttentionHead(
            (W_k): Linear(in_features=256, out_features=16, bias=False)
            (W_q): Linear(in_features=256, out_features=16, bias=False)
            (W_v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (W_out): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
)
Number of parameters: 6.18112 million
In [41]:
decode(model.generate(encode("Love is").unsqueeze(0), max_new_tokens=20)[0])
Out[41]:
'love is uncompanionable vivacity separation speculation returned carved proportionate crossing hurrying increases such negligence passages living divert complaining intend rapacity hesitation trouble'

Training Loop

In [42]:
def batch_cross_entropy(pred, y):
    # flatten the batch into a single dimension and 
    # compute cross-entropy
    return F.cross_entropy(pred.view(-1, vocab_size), y.view(-1))
In [43]:
batch_cross_entropy(model(x), y)
Out[43]:
tensor(9.3094, device='mps:0', grad_fn=<NllLossBackward0>)
In [44]:
def minibatch_gd(model, loss_fn, 
                 training_data,
                 batch_size, 
                 nsteps, 
                 learning_rate,
                 visualizer=None,
                 weight_decay=1e-4):
    generator = torch.Generator()
    generator.manual_seed(189)
    loader = DataLoader(training_data, 
                        batch_size=batch_size, 
                        shuffle=True, # shuffles each epoch
                        generator=generator)
    
    # Define the optimizer (this is the update rule)
    # Alternatively, you can use Adam optimizer
    optimizer = torch.optim.AdamW(model.parameters(), learning_rate, weight_decay=weight_decay)
    model.train() # set model to training mode (important for dropout/batchnorm)
    step = 0
    # Loop through the steps
    iter_loader = iter(loader)
    progress = tqdm(range(nsteps), desc="Tr:", ncols=100)
    for step in progress:
        # Get the next batch of data
        try:
            x, t = next(iter_loader)
        except StopIteration:
            iter_loader = iter(loader)
            x, t = next(iter_loader)
        # Zero the gradients to start the next step
        optimizer.zero_grad()
        # Compute prediction and loss
        pred = model(x)
        loss = loss_fn(pred, t)
        tr_loss = loss.item()
        # Backpropagation (compute the gradient)
        loss.backward()
        # Update the parameters using the optimizer's update rule
        optimizer.step()
        # Visualize the model (if a visualizer function is provided)
        if visualizer is not None:
            model.eval() # disable dropout/batchnorm
            with torch.no_grad():
                visualizer(step, model, loss_fn, tr_loss, progress)
            model.train()
In [45]:
class LossVisualizer:
    def __init__(self, loss_fig, validation_data):
        self.loss_fig = loss_fig
        self.val_loader = DataLoader(validation_data, 
                                     batch_size=32, 
                                     shuffle=False)
        self.epochs = []
        self.losses_val = []
        self.losses_tr = []

    def reset(self):
        self.epochs = []
        self.losses_val = []
        self.losses_tr = []
        with self.loss_fig.batch_update():
            self.loss_fig.data[0].x = []
            self.loss_fig.data[0].y = []
            self.loss_fig.data[1].x = []
            self.loss_fig.data[1].y = []
    
    def __call__(self, epoch, model, loss_fn, loss_tr, progress):
        model.eval()
        with torch.no_grad():
            losses = []
            for x_val, t_val in self.val_loader:
                loss_val = loss_fn(model(x_val), t_val).item()
                losses.append(loss_val)
            loss_val = np.mean(losses)
        self.epochs.append(epoch)
        self.losses_val.append(loss_val)
        self.losses_tr.append(loss_tr)
        progress.set_postfix_str(
            f"tr.: {loss_tr:.5f}, val.: {loss_val:.5f}")
        # Visualization Code
        with self.loss_fig.batch_update():
            self.loss_fig.data[0].x = self.epochs
            self.loss_fig.data[0].y = self.losses_val
            self.loss_fig.data[1].x = self.epochs
            self.loss_fig.data[1].y = self.losses_tr

        model.train()
In [46]:
loss_fig = go.FigureWidget()
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Val. Loss'))
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Train. Loss'))
visualizer = LossVisualizer(loss_fig, validation_data)
display(loss_fig)
In [47]:
visualizer.reset()
model = TransformerDecoderOnly(
    max_length=seq_length, 
    vocab_size=vocab_size, 
    d_model=1024, d_k=32, num_layers=4, num_heads=8, 
    dropout=0.1).to(device)

#model = torch.compile(model)

# model = TransformerDecoderOnly(
#     max_length=seq_length, 
#     vocab_size=vocab_size, 
#     d_model=128, d_k=32, num_layers=8, num_heads=8, 
#     dropout=0.1).to(device)

minibatch_gd(
    model=model,
    loss_fn=batch_cross_entropy,
    training_data=training_data,
    batch_size=128,
    nsteps=150,
    learning_rate=3e-4,
    weight_decay=1e-4,
    visualizer=visualizer
)
Tr::   0%|                                                                  | 0/150 [00:00<?, ?it/s]
Tr::   0%|                                     | 0/150 [00:02<?, ?it/s, tr.: 9.26017, val.: 8.57399]
Tr::   1%|▏                            | 1/150 [00:02<06:35,  2.66s/it, tr.: 9.26017, val.: 8.57399]
Tr::   1%|▏                            | 1/150 [00:04<06:35,  2.66s/it, tr.: 8.51608, val.: 8.27931]
Tr::   1%|▍                            | 2/150 [00:04<05:49,  2.36s/it, tr.: 8.51608, val.: 8.27931]
Tr::   1%|▍                            | 2/150 [00:06<05:49,  2.36s/it, tr.: 8.20752, val.: 8.17923]
Tr::   2%|▌                            | 3/150 [00:06<05:25,  2.22s/it, tr.: 8.20752, val.: 8.17923]
Tr::   2%|▌                            | 3/150 [00:08<05:25,  2.22s/it, tr.: 8.15492, val.: 7.99680]
Tr::   3%|▊                            | 4/150 [00:08<05:14,  2.15s/it, tr.: 8.15492, val.: 7.99680]
Tr::   3%|▊                            | 4/150 [00:10<05:14,  2.15s/it, tr.: 8.03511, val.: 7.72505]
Tr::   3%|▉                            | 5/150 [00:10<05:06,  2.11s/it, tr.: 8.03511, val.: 7.72505]
Tr::   3%|▉                            | 5/150 [00:13<05:06,  2.11s/it, tr.: 7.72619, val.: 7.45971]
Tr::   4%|█▏                           | 6/150 [00:13<05:01,  2.09s/it, tr.: 7.72619, val.: 7.45971]
Tr::   4%|█▏                           | 6/150 [00:15<05:01,  2.09s/it, tr.: 7.47771, val.: 7.23197]
Tr::   5%|█▎                           | 7/150 [00:15<04:57,  2.08s/it, tr.: 7.47771, val.: 7.23197]
Tr::   5%|█▎                           | 7/150 [00:17<04:57,  2.08s/it, tr.: 7.20689, val.: 7.04672]
Tr::   5%|█▌                           | 8/150 [00:17<04:53,  2.07s/it, tr.: 7.20689, val.: 7.04672]
Tr::   5%|█▌                           | 8/150 [00:19<04:53,  2.07s/it, tr.: 7.06993, val.: 6.88123]
Tr::   6%|█▋                           | 9/150 [00:19<04:50,  2.06s/it, tr.: 7.06993, val.: 6.88123]
Tr::   6%|█▋                           | 9/150 [00:21<04:50,  2.06s/it, tr.: 6.85742, val.: 6.73982]
Tr::   7%|█▊                          | 10/150 [00:21<04:49,  2.07s/it, tr.: 6.85742, val.: 6.73982]
Tr::   7%|█▊                          | 10/150 [00:23<04:49,  2.07s/it, tr.: 6.73456, val.: 6.64857]
Tr::   7%|██                          | 11/150 [00:23<04:47,  2.07s/it, tr.: 6.73456, val.: 6.64857]
Tr::   7%|██                          | 11/150 [00:25<04:47,  2.07s/it, tr.: 6.64194, val.: 6.60583]
Tr::   8%|██▏                         | 12/150 [00:25<04:44,  2.06s/it, tr.: 6.64194, val.: 6.60583]
Tr::   8%|██▏                         | 12/150 [00:27<04:44,  2.06s/it, tr.: 6.61373, val.: 6.55918]
Tr::   9%|██▍                         | 13/150 [00:27<04:43,  2.07s/it, tr.: 6.61373, val.: 6.55918]
Tr::   9%|██▍                         | 13/150 [00:29<04:43,  2.07s/it, tr.: 6.59091, val.: 6.51956]
Tr::   9%|██▌                         | 14/150 [00:29<04:41,  2.07s/it, tr.: 6.59091, val.: 6.51956]
Tr::   9%|██▌                         | 14/150 [00:30<04:41,  2.07s/it, tr.: 6.49098, val.: 6.50559]
Tr::  10%|██▊                         | 15/150 [00:30<04:01,  1.79s/it, tr.: 6.49098, val.: 6.50559]
Tr::  10%|██▊                         | 15/150 [00:32<04:01,  1.79s/it, tr.: 6.40490, val.: 6.52665]
Tr::  11%|██▉                         | 16/150 [00:32<04:12,  1.89s/it, tr.: 6.40490, val.: 6.52665]
Tr::  11%|██▉                         | 16/150 [00:34<04:12,  1.89s/it, tr.: 6.46465, val.: 6.54742]
Tr::  11%|███▏                        | 17/150 [00:34<04:17,  1.93s/it, tr.: 6.46465, val.: 6.54742]
Tr::  11%|███▏                        | 17/150 [00:36<04:17,  1.93s/it, tr.: 6.39025, val.: 6.52070]
Tr::  12%|███▎                        | 18/150 [00:36<04:20,  1.97s/it, tr.: 6.39025, val.: 6.52070]
Tr::  12%|███▎                        | 18/150 [00:38<04:20,  1.97s/it, tr.: 6.44382, val.: 6.49633]
Tr::  13%|███▌                        | 19/150 [00:38<04:21,  2.00s/it, tr.: 6.44382, val.: 6.49633]
Tr::  13%|███▌                        | 19/150 [00:40<04:21,  2.00s/it, tr.: 6.42298, val.: 6.49545]
Tr::  13%|███▋                        | 20/150 [00:40<04:21,  2.01s/it, tr.: 6.42298, val.: 6.49545]
Tr::  13%|███▋                        | 20/150 [00:42<04:21,  2.01s/it, tr.: 6.47870, val.: 6.50568]
Tr::  14%|███▉                        | 21/150 [00:42<04:19,  2.01s/it, tr.: 6.47870, val.: 6.50568]
Tr::  14%|███▉                        | 21/150 [00:45<04:19,  2.01s/it, tr.: 6.40526, val.: 6.51623]
Tr::  15%|████                        | 22/150 [00:45<04:17,  2.01s/it, tr.: 6.40526, val.: 6.51623]
Tr::  15%|████                        | 22/150 [00:47<04:17,  2.01s/it, tr.: 6.46902, val.: 6.52116]
Tr::  15%|████▎                       | 23/150 [00:47<04:17,  2.03s/it, tr.: 6.46902, val.: 6.52116]
Tr::  15%|████▎                       | 23/150 [00:49<04:17,  2.03s/it, tr.: 6.45687, val.: 6.51970]
Tr::  16%|████▍                       | 24/150 [00:49<04:17,  2.05s/it, tr.: 6.45687, val.: 6.51970]
Tr::  16%|████▍                       | 24/150 [00:51<04:17,  2.05s/it, tr.: 6.52357, val.: 6.51229]
Tr::  17%|████▋                       | 25/150 [00:51<04:15,  2.05s/it, tr.: 6.52357, val.: 6.51229]
Tr::  17%|████▋                       | 25/150 [00:53<04:15,  2.05s/it, tr.: 6.43745, val.: 6.50206]
Tr::  17%|████▊                       | 26/150 [00:53<04:13,  2.04s/it, tr.: 6.43745, val.: 6.50206]
Tr::  17%|████▊                       | 26/150 [00:55<04:13,  2.04s/it, tr.: 6.45267, val.: 6.49171]
Tr::  18%|█████                       | 27/150 [00:55<04:10,  2.04s/it, tr.: 6.45267, val.: 6.49171]
Tr::  18%|█████                       | 27/150 [00:57<04:10,  2.04s/it, tr.: 6.51279, val.: 6.48364]
Tr::  19%|█████▏                      | 28/150 [00:57<04:08,  2.04s/it, tr.: 6.51279, val.: 6.48364]
Tr::  19%|█████▏                      | 28/150 [00:59<04:08,  2.04s/it, tr.: 6.44977, val.: 6.47999]
Tr::  19%|█████▍                      | 29/150 [00:59<04:06,  2.04s/it, tr.: 6.44977, val.: 6.47999]
Tr::  19%|█████▍                      | 29/150 [01:00<04:06,  2.04s/it, tr.: 6.39228, val.: 6.48065]
Tr::  20%|█████▌                      | 30/150 [01:00<03:31,  1.76s/it, tr.: 6.39228, val.: 6.48065]
Tr::  20%|█████▌                      | 30/150 [01:02<03:31,  1.76s/it, tr.: 6.30892, val.: 6.48404]
Tr::  21%|█████▊                      | 31/150 [01:02<03:40,  1.85s/it, tr.: 6.30892, val.: 6.48404]
Tr::  21%|█████▊                      | 31/150 [01:04<03:40,  1.85s/it, tr.: 6.35643, val.: 6.48654]
Tr::  21%|█████▉                      | 32/150 [01:04<03:49,  1.94s/it, tr.: 6.35643, val.: 6.48654]
Tr::  21%|█████▉                      | 32/150 [01:06<03:49,  1.94s/it, tr.: 6.39056, val.: 6.48681]
Tr::  22%|██████▏                     | 33/150 [01:06<03:51,  1.98s/it, tr.: 6.39056, val.: 6.48681]
Tr::  22%|██████▏                     | 33/150 [01:08<03:51,  1.98s/it, tr.: 6.41833, val.: 6.48502]
Tr::  23%|██████▎                     | 34/150 [01:08<03:52,  2.00s/it, tr.: 6.41833, val.: 6.48502]
Tr::  23%|██████▎                     | 34/150 [01:10<03:52,  2.00s/it, tr.: 6.34807, val.: 6.48221]
Tr::  23%|██████▌                     | 35/150 [01:10<03:52,  2.02s/it, tr.: 6.34807, val.: 6.48221]
Tr::  23%|██████▌                     | 35/150 [01:12<03:52,  2.02s/it, tr.: 6.40535, val.: 6.48021]
Tr::  24%|██████▋                     | 36/150 [01:12<03:51,  2.03s/it, tr.: 6.40535, val.: 6.48021]
Tr::  24%|██████▋                     | 36/150 [01:14<03:51,  2.03s/it, tr.: 6.46230, val.: 6.47989]
Tr::  25%|██████▉                     | 37/150 [01:14<03:50,  2.04s/it, tr.: 6.46230, val.: 6.47989]
Tr::  25%|██████▉                     | 37/150 [01:17<03:50,  2.04s/it, tr.: 6.42953, val.: 6.48023]
Tr::  25%|███████                     | 38/150 [01:17<03:51,  2.06s/it, tr.: 6.42953, val.: 6.48023]
Tr::  25%|███████                     | 38/150 [01:19<03:51,  2.06s/it, tr.: 6.41710, val.: 6.48023]
Tr::  26%|███████▎                    | 39/150 [01:19<03:49,  2.06s/it, tr.: 6.41710, val.: 6.48023]
Tr::  26%|███████▎                    | 39/150 [01:21<03:49,  2.06s/it, tr.: 6.43126, val.: 6.47949]
Tr::  27%|███████▍                    | 40/150 [01:21<03:46,  2.06s/it, tr.: 6.43126, val.: 6.47949]
Tr::  27%|███████▍                    | 40/150 [01:23<03:46,  2.06s/it, tr.: 6.41258, val.: 6.47832]
Tr::  27%|███████▋                    | 41/150 [01:23<03:44,  2.06s/it, tr.: 6.41258, val.: 6.47832]
Tr::  27%|███████▋                    | 41/150 [01:25<03:44,  2.06s/it, tr.: 6.38118, val.: 6.47707]
Tr::  28%|███████▊                    | 42/150 [01:25<03:41,  2.06s/it, tr.: 6.38118, val.: 6.47707]
Tr::  28%|███████▊                    | 42/150 [01:27<03:41,  2.06s/it, tr.: 6.37949, val.: 6.47611]
Tr::  29%|████████                    | 43/150 [01:27<03:39,  2.05s/it, tr.: 6.37949, val.: 6.47611]
Tr::  29%|████████                    | 43/150 [01:29<03:39,  2.05s/it, tr.: 6.38229, val.: 6.47540]
Tr::  29%|████████▏                   | 44/150 [01:29<03:36,  2.04s/it, tr.: 6.38229, val.: 6.47540]
Tr::  29%|████████▏                   | 44/150 [01:30<03:36,  2.04s/it, tr.: 6.44006, val.: 6.47491]
Tr::  30%|████████▍                   | 45/150 [01:30<03:04,  1.76s/it, tr.: 6.44006, val.: 6.47491]
Tr::  30%|████████▍                   | 45/150 [01:32<03:04,  1.76s/it, tr.: 6.41151, val.: 6.47578]
Tr::  31%|████████▌                   | 46/150 [01:32<03:12,  1.85s/it, tr.: 6.41151, val.: 6.47578]
Tr::  31%|████████▌                   | 46/150 [01:34<03:12,  1.85s/it, tr.: 6.38228, val.: 6.47734]
Tr::  31%|████████▊                   | 47/150 [01:34<03:16,  1.91s/it, tr.: 6.38228, val.: 6.47734]
Tr::  31%|████████▊                   | 47/150 [01:36<03:16,  1.91s/it, tr.: 6.39361, val.: 6.47886]
Tr::  32%|████████▉                   | 48/150 [01:36<03:18,  1.95s/it, tr.: 6.39361, val.: 6.47886]
Tr::  32%|████████▉                   | 48/150 [01:38<03:18,  1.95s/it, tr.: 6.39943, val.: 6.48033]
Tr::  33%|█████████▏                  | 49/150 [01:38<03:19,  1.97s/it, tr.: 6.39943, val.: 6.48033]
Tr::  33%|█████████▏                  | 49/150 [01:40<03:19,  1.97s/it, tr.: 6.39246, val.: 6.48105]
Tr::  33%|█████████▎                  | 50/150 [01:40<03:20,  2.00s/it, tr.: 6.39246, val.: 6.48105]
Tr::  33%|█████████▎                  | 50/150 [01:42<03:20,  2.00s/it, tr.: 6.32136, val.: 6.48118]
Tr::  34%|█████████▌                  | 51/150 [01:42<03:19,  2.02s/it, tr.: 6.32136, val.: 6.48118]
Tr::  34%|█████████▌                  | 51/150 [01:44<03:19,  2.02s/it, tr.: 6.35105, val.: 6.48088]
Tr::  35%|█████████▋                  | 52/150 [01:44<03:18,  2.03s/it, tr.: 6.35105, val.: 6.48088]
Tr::  35%|█████████▋                  | 52/150 [01:46<03:18,  2.03s/it, tr.: 6.38217, val.: 6.47985]
Tr::  35%|█████████▉                  | 53/150 [01:46<03:17,  2.03s/it, tr.: 6.38217, val.: 6.47985]
Tr::  35%|█████████▉                  | 53/150 [01:48<03:17,  2.03s/it, tr.: 6.37171, val.: 6.47898]
Tr::  36%|██████████                  | 54/150 [01:48<03:15,  2.04s/it, tr.: 6.37171, val.: 6.47898]
Tr::  36%|██████████                  | 54/150 [01:50<03:15,  2.04s/it, tr.: 6.37553, val.: 6.47787]
Tr::  37%|██████████▎                 | 55/150 [01:50<03:13,  2.04s/it, tr.: 6.37553, val.: 6.47787]
Tr::  37%|██████████▎                 | 55/150 [01:52<03:13,  2.04s/it, tr.: 6.38610, val.: 6.47697]
Tr::  37%|██████████▍                 | 56/150 [01:52<03:11,  2.04s/it, tr.: 6.38610, val.: 6.47697]
Tr::  37%|██████████▍                 | 56/150 [01:55<03:11,  2.04s/it, tr.: 6.41366, val.: 6.47603]
Tr::  38%|██████████▋                 | 57/150 [01:55<03:10,  2.05s/it, tr.: 6.41366, val.: 6.47603]
Tr::  38%|██████████▋                 | 57/150 [01:57<03:10,  2.05s/it, tr.: 6.36871, val.: 6.47519]
Tr::  39%|██████████▊                 | 58/150 [01:57<03:07,  2.04s/it, tr.: 6.36871, val.: 6.47519]
Tr::  39%|██████████▊                 | 58/150 [01:59<03:07,  2.04s/it, tr.: 6.36546, val.: 6.47410]
Tr::  39%|███████████                 | 59/150 [01:59<03:05,  2.04s/it, tr.: 6.36546, val.: 6.47410]
Tr::  39%|███████████                 | 59/150 [02:00<03:05,  2.04s/it, tr.: 6.43403, val.: 6.47300]
Tr::  40%|███████████▏                | 60/150 [02:00<02:37,  1.74s/it, tr.: 6.43403, val.: 6.47300]
Tr::  40%|███████████▏                | 60/150 [02:02<02:37,  1.74s/it, tr.: 6.35815, val.: 6.47367]
Tr::  41%|███████████▍                | 61/150 [02:02<02:43,  1.84s/it, tr.: 6.35815, val.: 6.47367]
Tr::  41%|███████████▍                | 61/150 [02:04<02:43,  1.84s/it, tr.: 6.36388, val.: 6.47567]
Tr::  41%|███████████▌                | 62/150 [02:04<02:48,  1.91s/it, tr.: 6.36388, val.: 6.47567]
Tr::  41%|███████████▌                | 62/150 [02:06<02:48,  1.91s/it, tr.: 6.37898, val.: 6.47770]
Tr::  42%|███████████▊                | 63/150 [02:06<02:52,  1.98s/it, tr.: 6.37898, val.: 6.47770]
Tr::  42%|███████████▊                | 63/150 [02:08<02:52,  1.98s/it, tr.: 6.32280, val.: 6.47990]
Tr::  43%|███████████▉                | 64/150 [02:08<02:52,  2.01s/it, tr.: 6.32280, val.: 6.47990]
Tr::  43%|███████████▉                | 64/150 [02:10<02:52,  2.01s/it, tr.: 6.34614, val.: 6.48144]
Tr::  43%|████████████▏               | 65/150 [02:10<02:51,  2.02s/it, tr.: 6.34614, val.: 6.48144]
Tr::  43%|████████████▏               | 65/150 [02:12<02:51,  2.02s/it, tr.: 6.39769, val.: 6.48202]
Tr::  44%|████████████▎               | 66/150 [02:12<02:51,  2.04s/it, tr.: 6.39769, val.: 6.48202]
Tr::  44%|████████████▎               | 66/150 [02:14<02:51,  2.04s/it, tr.: 6.35113, val.: 6.48154]
Tr::  45%|████████████▌               | 67/150 [02:14<02:49,  2.04s/it, tr.: 6.35113, val.: 6.48154]
Tr::  45%|████████████▌               | 67/150 [02:16<02:49,  2.04s/it, tr.: 6.33740, val.: 6.48119]
Tr::  45%|████████████▋               | 68/150 [02:16<02:47,  2.04s/it, tr.: 6.33740, val.: 6.48119]
Tr::  45%|████████████▋               | 68/150 [02:18<02:47,  2.04s/it, tr.: 6.40936, val.: 6.48089]
Tr::  46%|████████████▉               | 69/150 [02:18<02:45,  2.04s/it, tr.: 6.40936, val.: 6.48089]
Tr::  46%|████████████▉               | 69/150 [02:20<02:45,  2.04s/it, tr.: 6.38993, val.: 6.48091]
Tr::  47%|█████████████               | 70/150 [02:20<02:43,  2.04s/it, tr.: 6.38993, val.: 6.48091]
Tr::  47%|█████████████               | 70/150 [02:22<02:43,  2.04s/it, tr.: 6.41119, val.: 6.48102]
Tr::  47%|█████████████▎              | 71/150 [02:22<02:41,  2.04s/it, tr.: 6.41119, val.: 6.48102]
Tr::  47%|█████████████▎              | 71/150 [02:24<02:41,  2.04s/it, tr.: 6.41574, val.: 6.48149]
Tr::  48%|█████████████▍              | 72/150 [02:24<02:39,  2.04s/it, tr.: 6.41574, val.: 6.48149]
Tr::  48%|█████████████▍              | 72/150 [02:26<02:39,  2.04s/it, tr.: 6.39029, val.: 6.48120]
Tr::  49%|█████████████▋              | 73/150 [02:26<02:36,  2.04s/it, tr.: 6.39029, val.: 6.48120]
Tr::  49%|█████████████▋              | 73/150 [02:28<02:36,  2.04s/it, tr.: 6.40099, val.: 6.47983]
Tr::  49%|█████████████▊              | 74/150 [02:29<02:35,  2.05s/it, tr.: 6.40099, val.: 6.47983]
Tr::  49%|█████████████▊              | 74/150 [02:30<02:35,  2.05s/it, tr.: 6.39231, val.: 6.47906]
Tr::  50%|██████████████              | 75/150 [02:30<02:12,  1.76s/it, tr.: 6.39231, val.: 6.47906]
Tr::  50%|██████████████              | 75/150 [02:32<02:12,  1.76s/it, tr.: 6.35079, val.: 6.47933]
Tr::  51%|██████████████▏             | 76/150 [02:32<02:17,  1.86s/it, tr.: 6.35079, val.: 6.47933]
Tr::  51%|██████████████▏             | 76/150 [02:34<02:17,  1.86s/it, tr.: 6.35080, val.: 6.47965]
Tr::  51%|██████████████▎             | 77/150 [02:34<02:19,  1.91s/it, tr.: 6.35080, val.: 6.47965]
Tr::  51%|██████████████▎             | 77/150 [02:36<02:19,  1.91s/it, tr.: 6.39703, val.: 6.48046]
Tr::  52%|██████████████▌             | 78/150 [02:36<02:20,  1.95s/it, tr.: 6.39703, val.: 6.48046]
Tr::  52%|██████████████▌             | 78/150 [02:38<02:20,  1.95s/it, tr.: 6.41750, val.: 6.48173]
Tr::  53%|██████████████▋             | 79/150 [02:38<02:20,  1.98s/it, tr.: 6.41750, val.: 6.48173]
Tr::  53%|██████████████▋             | 79/150 [02:40<02:20,  1.98s/it, tr.: 6.38557, val.: 6.48357]
Tr::  53%|██████████████▉             | 80/150 [02:40<02:19,  1.99s/it, tr.: 6.38557, val.: 6.48357]
Tr::  53%|██████████████▉             | 80/150 [02:42<02:19,  1.99s/it, tr.: 6.36894, val.: 6.48554]
Tr::  54%|███████████████             | 81/150 [02:42<02:18,  2.01s/it, tr.: 6.36894, val.: 6.48554]
Tr::  54%|███████████████             | 81/150 [02:44<02:18,  2.01s/it, tr.: 6.36188, val.: 6.48658]
Tr::  55%|███████████████▎            | 82/150 [02:44<02:18,  2.03s/it, tr.: 6.36188, val.: 6.48658]
Tr::  55%|███████████████▎            | 82/150 [02:46<02:18,  2.03s/it, tr.: 6.39892, val.: 6.48652]
Tr::  55%|███████████████▍            | 83/150 [02:46<02:16,  2.04s/it, tr.: 6.39892, val.: 6.48652]
Tr::  55%|███████████████▍            | 83/150 [02:48<02:16,  2.04s/it, tr.: 6.36481, val.: 6.48604]
Tr::  56%|███████████████▋            | 84/150 [02:48<02:14,  2.04s/it, tr.: 6.36481, val.: 6.48604]
Tr::  56%|███████████████▋            | 84/150 [02:50<02:14,  2.04s/it, tr.: 6.38687, val.: 6.48493]
Tr::  57%|███████████████▊            | 85/150 [02:50<02:12,  2.04s/it, tr.: 6.38687, val.: 6.48493]
Tr::  57%|███████████████▊            | 85/150 [02:52<02:12,  2.04s/it, tr.: 6.32967, val.: 6.48391]
Tr::  57%|████████████████            | 86/150 [02:52<02:10,  2.04s/it, tr.: 6.32967, val.: 6.48391]
Tr::  57%|████████████████            | 86/150 [02:54<02:10,  2.04s/it, tr.: 6.36747, val.: 6.48233]
Tr::  58%|████████████████▏           | 87/150 [02:54<02:08,  2.04s/it, tr.: 6.36747, val.: 6.48233]
Tr::  58%|████████████████▏           | 87/150 [02:56<02:08,  2.04s/it, tr.: 6.39878, val.: 6.48025]
Tr::  59%|████████████████▍           | 88/150 [02:56<02:06,  2.04s/it, tr.: 6.39878, val.: 6.48025]
Tr::  59%|████████████████▍           | 88/150 [02:58<02:06,  2.04s/it, tr.: 6.35713, val.: 6.47796]
Tr::  59%|████████████████▌           | 89/150 [02:58<02:05,  2.06s/it, tr.: 6.35713, val.: 6.47796]
Tr::  59%|████████████████▌           | 89/150 [02:59<02:05,  2.06s/it, tr.: 6.36751, val.: 6.47583]
Tr::  60%|████████████████▊           | 90/150 [02:59<01:46,  1.78s/it, tr.: 6.36751, val.: 6.47583]
Tr::  60%|████████████████▊           | 90/150 [03:01<01:46,  1.78s/it, tr.: 6.33446, val.: 6.47511]
Tr::  61%|████████████████▉           | 91/150 [03:01<01:49,  1.86s/it, tr.: 6.33446, val.: 6.47511]
Tr::  61%|████████████████▉           | 91/150 [03:04<01:49,  1.86s/it, tr.: 6.37525, val.: 6.47486]
Tr::  61%|█████████████████▏          | 92/150 [03:04<01:53,  1.96s/it, tr.: 6.37525, val.: 6.47486]
Tr::  61%|█████████████████▏          | 92/150 [03:06<01:53,  1.96s/it, tr.: 6.37664, val.: 6.47503]
Tr::  62%|█████████████████▎          | 93/150 [03:06<02:00,  2.11s/it, tr.: 6.37664, val.: 6.47503]
Tr::  62%|█████████████████▎          | 93/150 [03:08<02:00,  2.11s/it, tr.: 6.35476, val.: 6.47419]
Tr::  63%|█████████████████▌          | 94/150 [03:08<02:00,  2.14s/it, tr.: 6.35476, val.: 6.47419]
Tr::  63%|█████████████████▌          | 94/150 [03:11<02:00,  2.14s/it, tr.: 6.34937, val.: 6.47269]
Tr::  63%|█████████████████▋          | 95/150 [03:11<01:59,  2.17s/it, tr.: 6.34937, val.: 6.47269]
Tr::  63%|█████████████████▋          | 95/150 [03:13<01:59,  2.17s/it, tr.: 6.33869, val.: 6.47105]
Tr::  64%|█████████████████▉          | 96/150 [03:13<01:58,  2.19s/it, tr.: 6.33869, val.: 6.47105]
Tr::  64%|█████████████████▉          | 96/150 [03:15<01:58,  2.19s/it, tr.: 6.31834, val.: 6.46993]
Tr::  65%|██████████████████          | 97/150 [03:15<01:56,  2.20s/it, tr.: 6.31834, val.: 6.46993]
Tr::  65%|██████████████████          | 97/150 [03:17<01:56,  2.20s/it, tr.: 6.33033, val.: 6.46788]
Tr::  65%|██████████████████▎         | 98/150 [03:17<01:54,  2.20s/it, tr.: 6.33033, val.: 6.46788]
Tr::  65%|██████████████████▎         | 98/150 [03:19<01:54,  2.20s/it, tr.: 6.35902, val.: 6.46463]
Tr::  66%|██████████████████▍         | 99/150 [03:19<01:50,  2.17s/it, tr.: 6.35902, val.: 6.46463]
Tr::  66%|██████████████████▍         | 99/150 [03:21<01:50,  2.17s/it, tr.: 6.41684, val.: 6.46023]
Tr::  67%|██████████████████         | 100/150 [03:22<01:48,  2.16s/it, tr.: 6.41684, val.: 6.46023]
Tr::  67%|██████████████████         | 100/150 [03:24<01:48,  2.16s/it, tr.: 6.36686, val.: 6.45515]
Tr::  67%|██████████████████▏        | 101/150 [03:24<01:45,  2.15s/it, tr.: 6.36686, val.: 6.45515]
Tr::  67%|██████████████████▏        | 101/150 [03:26<01:45,  2.15s/it, tr.: 6.33651, val.: 6.45097]
Tr::  68%|██████████████████▎        | 102/150 [03:26<01:41,  2.12s/it, tr.: 6.33651, val.: 6.45097]
Tr::  68%|██████████████████▎        | 102/150 [03:28<01:41,  2.12s/it, tr.: 6.35646, val.: 6.44793]
Tr::  69%|██████████████████▌        | 103/150 [03:28<01:38,  2.10s/it, tr.: 6.35646, val.: 6.44793]
Tr::  69%|██████████████████▌        | 103/150 [03:30<01:38,  2.10s/it, tr.: 6.34463, val.: 6.44461]
Tr::  69%|██████████████████▋        | 104/150 [03:30<01:36,  2.09s/it, tr.: 6.34463, val.: 6.44461]
Tr::  69%|██████████████████▋        | 104/150 [03:31<01:36,  2.09s/it, tr.: 6.39250, val.: 6.44054]
Tr::  70%|██████████████████▉        | 105/150 [03:31<01:20,  1.80s/it, tr.: 6.39250, val.: 6.44054]
Tr::  70%|██████████████████▉        | 105/150 [03:33<01:20,  1.80s/it, tr.: 6.30680, val.: 6.43773]
Tr::  71%|███████████████████        | 106/150 [03:33<01:22,  1.88s/it, tr.: 6.30680, val.: 6.43773]
Tr::  71%|███████████████████        | 106/150 [03:35<01:22,  1.88s/it, tr.: 6.27472, val.: 6.43530]
Tr::  71%|███████████████████▎       | 107/150 [03:35<01:22,  1.93s/it, tr.: 6.27472, val.: 6.43530]
Tr::  71%|███████████████████▎       | 107/150 [03:37<01:22,  1.93s/it, tr.: 6.31800, val.: 6.43184]
Tr::  72%|███████████████████▍       | 108/150 [03:37<01:23,  1.98s/it, tr.: 6.31800, val.: 6.43184]
Tr::  72%|███████████████████▍       | 108/150 [03:39<01:23,  1.98s/it, tr.: 6.26018, val.: 6.42614]
Tr::  73%|███████████████████▌       | 109/150 [03:39<01:22,  2.02s/it, tr.: 6.26018, val.: 6.42614]
Tr::  73%|███████████████████▌       | 109/150 [03:41<01:22,  2.02s/it, tr.: 6.33947, val.: 6.42059]
Tr::  73%|███████████████████▊       | 110/150 [03:41<01:21,  2.04s/it, tr.: 6.33947, val.: 6.42059]
Tr::  73%|███████████████████▊       | 110/150 [03:43<01:21,  2.04s/it, tr.: 6.26722, val.: 6.41580]
Tr::  74%|███████████████████▉       | 111/150 [03:43<01:20,  2.05s/it, tr.: 6.26722, val.: 6.41580]
Tr::  74%|███████████████████▉       | 111/150 [03:46<01:20,  2.05s/it, tr.: 6.26408, val.: 6.40998]
Tr::  75%|████████████████████▏      | 112/150 [03:46<01:18,  2.06s/it, tr.: 6.26408, val.: 6.40998]
Tr::  75%|████████████████████▏      | 112/150 [03:48<01:18,  2.06s/it, tr.: 6.31438, val.: 6.40106]
Tr::  75%|████████████████████▎      | 113/150 [03:48<01:19,  2.16s/it, tr.: 6.31438, val.: 6.40106]
Tr::  75%|████████████████████▎      | 113/150 [03:51<01:19,  2.16s/it, tr.: 6.31369, val.: 6.39174]
Tr::  76%|████████████████████▌      | 114/150 [03:51<01:31,  2.54s/it, tr.: 6.31369, val.: 6.39174]
Tr::  76%|████████████████████▌      | 114/150 [03:54<01:31,  2.54s/it, tr.: 6.30225, val.: 6.38241]
Tr::  77%|████████████████████▋      | 115/150 [03:54<01:33,  2.66s/it, tr.: 6.30225, val.: 6.38241]
Tr::  77%|████████████████████▋      | 115/150 [03:57<01:33,  2.66s/it, tr.: 6.26241, val.: 6.37256]
Tr::  77%|████████████████████▉      | 116/150 [03:57<01:29,  2.63s/it, tr.: 6.26241, val.: 6.37256]
Tr::  77%|████████████████████▉      | 116/150 [03:59<01:29,  2.63s/it, tr.: 6.24745, val.: 6.36299]
Tr::  78%|█████████████████████      | 117/150 [03:59<01:22,  2.50s/it, tr.: 6.24745, val.: 6.36299]
Tr::  78%|█████████████████████      | 117/150 [04:01<01:22,  2.50s/it, tr.: 6.26583, val.: 6.35345]
Tr::  79%|█████████████████████▏     | 118/150 [04:01<01:16,  2.38s/it, tr.: 6.26583, val.: 6.35345]
Tr::  79%|█████████████████████▏     | 118/150 [04:03<01:16,  2.38s/it, tr.: 6.25690, val.: 6.34343]
Tr::  79%|█████████████████████▍     | 119/150 [04:03<01:10,  2.28s/it, tr.: 6.25690, val.: 6.34343]
Tr::  79%|█████████████████████▍     | 119/150 [04:04<01:10,  2.28s/it, tr.: 6.28463, val.: 6.33458]
Tr::  80%|█████████████████████▌     | 120/150 [04:04<00:57,  1.92s/it, tr.: 6.28463, val.: 6.33458]
Tr::  80%|█████████████████████▌     | 120/150 [04:06<00:57,  1.92s/it, tr.: 6.20973, val.: 6.32611]
Tr::  81%|█████████████████████▊     | 121/150 [04:06<00:58,  2.01s/it, tr.: 6.20973, val.: 6.32611]
Tr::  81%|█████████████████████▊     | 121/150 [04:09<00:58,  2.01s/it, tr.: 6.16768, val.: 6.31758]
Tr::  81%|█████████████████████▉     | 122/150 [04:09<00:57,  2.05s/it, tr.: 6.16768, val.: 6.31758]
Tr::  81%|█████████████████████▉     | 122/150 [04:11<00:57,  2.05s/it, tr.: 6.15056, val.: 6.30986]
Tr::  82%|██████████████████████▏    | 123/150 [04:11<00:55,  2.05s/it, tr.: 6.15056, val.: 6.30986]
Tr::  82%|██████████████████████▏    | 123/150 [04:13<00:55,  2.05s/it, tr.: 6.19903, val.: 6.30275]
Tr::  83%|██████████████████████▎    | 124/150 [04:13<00:53,  2.05s/it, tr.: 6.19903, val.: 6.30275]
Tr::  83%|██████████████████████▎    | 124/150 [04:15<00:53,  2.05s/it, tr.: 6.15069, val.: 6.29647]
Tr::  83%|██████████████████████▌    | 125/150 [04:15<00:51,  2.05s/it, tr.: 6.15069, val.: 6.29647]
Tr::  83%|██████████████████████▌    | 125/150 [04:17<00:51,  2.05s/it, tr.: 6.14519, val.: 6.28868]
Tr::  84%|██████████████████████▋    | 126/150 [04:17<00:49,  2.05s/it, tr.: 6.14519, val.: 6.28868]
Tr::  84%|██████████████████████▋    | 126/150 [04:19<00:49,  2.05s/it, tr.: 6.12481, val.: 6.27966]
Tr::  85%|██████████████████████▊    | 127/150 [04:19<00:47,  2.05s/it, tr.: 6.12481, val.: 6.27966]
Tr::  85%|██████████████████████▊    | 127/150 [04:21<00:47,  2.05s/it, tr.: 6.19412, val.: 6.26930]
Tr::  85%|███████████████████████    | 128/150 [04:21<00:45,  2.05s/it, tr.: 6.19412, val.: 6.26930]
Tr::  85%|███████████████████████    | 128/150 [04:23<00:45,  2.05s/it, tr.: 6.15645, val.: 6.25746]
Tr::  86%|███████████████████████▏   | 129/150 [04:23<00:43,  2.05s/it, tr.: 6.15645, val.: 6.25746]
Tr::  86%|███████████████████████▏   | 129/150 [04:25<00:43,  2.05s/it, tr.: 6.14513, val.: 6.24536]
Tr::  87%|███████████████████████▍   | 130/150 [04:25<00:41,  2.06s/it, tr.: 6.14513, val.: 6.24536]
Tr::  87%|███████████████████████▍   | 130/150 [04:27<00:41,  2.06s/it, tr.: 6.08703, val.: 6.23102]
Tr::  87%|███████████████████████▌   | 131/150 [04:27<00:39,  2.08s/it, tr.: 6.08703, val.: 6.23102]
Tr::  87%|███████████████████████▌   | 131/150 [04:29<00:39,  2.08s/it, tr.: 6.11892, val.: 6.21586]
Tr::  88%|███████████████████████▊   | 132/150 [04:29<00:37,  2.07s/it, tr.: 6.11892, val.: 6.21586]
Tr::  88%|███████████████████████▊   | 132/150 [04:31<00:37,  2.07s/it, tr.: 6.08110, val.: 6.20246]
Tr::  89%|███████████████████████▉   | 133/150 [04:31<00:34,  2.06s/it, tr.: 6.08110, val.: 6.20246]
Tr::  89%|███████████████████████▉   | 133/150 [04:33<00:34,  2.06s/it, tr.: 6.08367, val.: 6.18954]
Tr::  89%|████████████████████████   | 134/150 [04:33<00:33,  2.08s/it, tr.: 6.08367, val.: 6.18954]
Tr::  89%|████████████████████████   | 134/150 [04:34<00:33,  2.08s/it, tr.: 6.00658, val.: 6.17737]
Tr::  90%|████████████████████████▎  | 135/150 [04:34<00:26,  1.80s/it, tr.: 6.00658, val.: 6.17737]
Tr::  90%|████████████████████████▎  | 135/150 [04:37<00:26,  1.80s/it, tr.: 5.97928, val.: 6.16352]
Tr::  91%|████████████████████████▍  | 136/150 [04:37<00:26,  1.91s/it, tr.: 5.97928, val.: 6.16352]
Tr::  91%|████████████████████████▍  | 136/150 [04:39<00:26,  1.91s/it, tr.: 5.96036, val.: 6.15136]
Tr::  91%|████████████████████████▋  | 137/150 [04:39<00:26,  2.01s/it, tr.: 5.96036, val.: 6.15136]
Tr::  91%|████████████████████████▋  | 137/150 [04:41<00:26,  2.01s/it, tr.: 5.96377, val.: 6.14167]
Tr::  92%|████████████████████████▊  | 138/150 [04:41<00:24,  2.06s/it, tr.: 5.96377, val.: 6.14167]
Tr::  92%|████████████████████████▊  | 138/150 [04:43<00:24,  2.06s/it, tr.: 5.93892, val.: 6.12982]
Tr::  93%|█████████████████████████  | 139/150 [04:43<00:23,  2.09s/it, tr.: 5.93892, val.: 6.12982]
Tr::  93%|█████████████████████████  | 139/150 [04:45<00:23,  2.09s/it, tr.: 6.05893, val.: 6.11714]
Tr::  93%|█████████████████████████▏ | 140/150 [04:45<00:21,  2.11s/it, tr.: 6.05893, val.: 6.11714]
Tr::  93%|█████████████████████████▏ | 140/150 [04:47<00:21,  2.11s/it, tr.: 5.95827, val.: 6.10700]
Tr::  94%|█████████████████████████▍ | 141/150 [04:47<00:18,  2.11s/it, tr.: 5.95827, val.: 6.10700]
Tr::  94%|█████████████████████████▍ | 141/150 [04:50<00:18,  2.11s/it, tr.: 6.02291, val.: 6.09442]
Tr::  95%|█████████████████████████▌ | 142/150 [04:50<00:16,  2.11s/it, tr.: 6.02291, val.: 6.09442]
Tr::  95%|█████████████████████████▌ | 142/150 [04:52<00:16,  2.11s/it, tr.: 5.95933, val.: 6.07900]
Tr::  95%|█████████████████████████▋ | 143/150 [04:52<00:14,  2.13s/it, tr.: 5.95933, val.: 6.07900]
Tr::  95%|█████████████████████████▋ | 143/150 [04:54<00:14,  2.13s/it, tr.: 5.88652, val.: 6.06254]
Tr::  96%|█████████████████████████▉ | 144/150 [04:54<00:12,  2.12s/it, tr.: 5.88652, val.: 6.06254]
Tr::  96%|█████████████████████████▉ | 144/150 [04:56<00:12,  2.12s/it, tr.: 5.86789, val.: 6.04795]
Tr::  97%|██████████████████████████ | 145/150 [04:56<00:10,  2.11s/it, tr.: 5.86789, val.: 6.04795]
Tr::  97%|██████████████████████████ | 145/150 [04:58<00:10,  2.11s/it, tr.: 5.88085, val.: 6.03679]
Tr::  97%|██████████████████████████▎| 146/150 [04:58<00:08,  2.11s/it, tr.: 5.88085, val.: 6.03679]
Tr::  97%|██████████████████████████▎| 146/150 [05:00<00:08,  2.11s/it, tr.: 5.90000, val.: 6.02413]
Tr::  98%|██████████████████████████▍| 147/150 [05:00<00:06,  2.09s/it, tr.: 5.90000, val.: 6.02413]
Tr::  98%|██████████████████████████▍| 147/150 [05:02<00:06,  2.09s/it, tr.: 5.92445, val.: 6.01237]
Tr::  99%|██████████████████████████▋| 148/150 [05:02<00:04,  2.09s/it, tr.: 5.92445, val.: 6.01237]
Tr::  99%|██████████████████████████▋| 148/150 [05:04<00:04,  2.09s/it, tr.: 5.83969, val.: 6.00430]
Tr::  99%|██████████████████████████▊| 149/150 [05:04<00:02,  2.08s/it, tr.: 5.83969, val.: 6.00430]
Tr::  99%|██████████████████████████▊| 149/150 [05:05<00:02,  2.08s/it, tr.: 5.90774, val.: 5.99814]
Tr:: 100%|███████████████████████████| 150/150 [05:05<00:00,  1.79s/it, tr.: 5.90774, val.: 5.99814]
Tr:: 100%|███████████████████████████| 150/150 [05:05<00:00,  2.04s/it, tr.: 5.90774, val.: 5.99814]

In [48]:
decode(model.generate(encode("Love is").unsqueeze(0), max_new_tokens=20)[0])
Out[48]:
'love is your believe on very been delicacy orders at in a always complexion in tenderly turn of what it s very'
In [ ]: