Lecture 18 – CS 189, Fall 2025
In [1]:
import os, math, itertools, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from types import SimpleNamespace
from torchvision.utils import make_grid
from sklearn.manifold import TSNE
try:
    import torchvision as tv
    from torchvision import transforms
except Exception as e:
    tv = None  # later cells check `tv is None`
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cpu
In [2]:
# Load the dataset
if tv is None:
print("torchvision not available. Please install torchvision to run MNIST demos.")
else:
transform = transforms.Compose([transforms.ToTensor()]) # Keep MNIST in [0,1], single-channel
train_ds = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_ds = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
print('Train size:', len(train_ds), ' Test size:', len(test_ds))
Train size: 60000 Test size: 10000
In [3]:
# Peek at MNIST
import matplotlib.pyplot as plt
if tv is not None:
imgs = [train_ds[i][0] for i in range(16)]
labels = [train_ds[i][1] for i in range(16)]
fig, axes = plt.subplots(4,4, figsize=(4,4))
for ax, img, lab in zip(axes.flatten(), imgs, labels):
ax.imshow(img[0].numpy(), cmap='gray')
ax.set_title(str(lab))
ax.axis('off')
plt.tight_layout(); plt.show()
In [4]:
# Identify the first 10 images of the digit 9
# (index via train_ds.targets so we don't run the transform over all 60k images)
nine_idxs = (train_ds.targets == 9).nonzero(as_tuple=True)[0][:10]
imgs_9 = [train_ds[int(i)][0] for i in nine_idxs]
# Plot the images
fig, axes = plt.subplots(1, 10, figsize=(15, 5))
for i, img in enumerate(imgs_9):
ax = axes[i]
ax.imshow(img[0].numpy(), cmap='gray')
ax.axis('off')
ax.set_title("9")
plt.tight_layout()
plt.show()
In [5]:
x = imgs_9[4].unsqueeze(0)
# Define convolution kernels
kernels = {
'horizontal line': torch.tensor([[[-1, -1, -1], [2, 2, 2], [-1, -1, -1]]], dtype=torch.float32).unsqueeze(0),
'vertical line': torch.tensor([[[-1, 2, -1], [-1, 2, -1], [-1, 2, -1]]], dtype=torch.float32).unsqueeze(0),
'diagonal line': torch.tensor([[[2, -1, -1], [-1, 2, -1], [-1, -1, 2]]], dtype=torch.float32).unsqueeze(0)
}
# Apply convolutions and visualize
fig, axes = plt.subplots(len(kernels), 3, figsize=(9, len(kernels) * 3))
for i, (name, kernel) in enumerate(kernels.items()):
# Display the kernel
axes[i, 0].imshow(kernel.squeeze().detach().numpy(), cmap='gray')
axes[i, 0].set_title(f'Kernel: {name}')
axes[i, 0].axis('off')
# Display the input image
axes[i, 1].imshow(x.squeeze().detach().numpy(), cmap='gray')
axes[i, 1].set_title('Input Image')
axes[i, 1].axis('off')
# Display the result of the convolution
y = F.conv2d(x, kernel, padding=1)
axes[i, 2].imshow(y.squeeze().detach().numpy(), cmap='gray')
axes[i, 2].set_title('Convolved Output')
axes[i, 2].axis('off')
plt.tight_layout()
plt.show()
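Under the hood, F.conv2d slides the kernel over the image and takes a dot product at each location (a cross-correlation; the kernel is not flipped). A minimal sanity check, reusing x and the kernels dict from the cell above, that reproduces one interior output pixel by hand:
k = kernels['horizontal line']               # [1,1,3,3]
y = F.conv2d(x, k, padding=1)                # [1,1,28,28]
r, c = 14, 14                                # interior pixel, so the padding plays no role
patch = x[0, 0, r-1:r+2, c-1:c+2]            # the 3x3 input patch under the kernel
manual = (patch * k[0, 0]).sum()
print(float(manual), float(y[0, 0, r, c]))   # the two values should match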
In [6]:
if tv is not None:
    nine_idx = (train_ds.targets == 9).nonzero(as_tuple=True)[0][4].item()  # avoids transforming the whole dataset
    x, y = train_ds[nine_idx]  # x: [1,28,28]
    x = x.unsqueeze(0)  # [1,1,28,28]
# Define classic kernels (normalized where reasonable)
kernels = {
'identity': torch.tensor([[0,0,0],[0,1,0],[0,0,0]], dtype=torch.float32),
'edge_h': torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=torch.float32),
'edge_v': torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=torch.float32),
'sharpen': torch.tensor([[0,-1,0],[-1,5,-1],[0,-1,0]], dtype=torch.float32),
'box_blur': (1/9.0)*torch.ones((3,3), dtype=torch.float32)
}
fig, axes = plt.subplots(2, 3, figsize=(6,4))
axes = axes.flatten()
axes[0].imshow(x[0,0].numpy(), cmap='gray'); axes[0].set_title('Input'); axes[0].axis('off')
i = 1
for name, K in kernels.items():
W = K.view(1,1,3,3)
yk = F.conv2d(x, W, padding=1) # keep size with padding=1
axes[i].imshow(yk[0,0].detach().numpy(), cmap='gray')
axes[i].set_title(name); axes[i].axis('off')
i += 1
if i >= len(axes):
break
plt.tight_layout(); plt.show()
else:
print("torchvision not available.")
In [7]:
if tv is not None:
x, _ = train_ds[1]
x = x.unsqueeze(0)
k = 3
W = torch.ones(1,1,k,k) / (k*k)
configs = [(1,0), (1,1), (2,0), (2,1)] # (stride, padding)
fig, axes = plt.subplots(1, len(configs)+1, figsize=(3*(len(configs)+1), 3))
axes[0].imshow(x[0,0].numpy(), cmap='gray'); axes[0].set_title('Input'); axes[0].axis('off')
for i, (s,p) in enumerate(configs, start=1):
y = F.conv2d(x, W, stride=s, padding=p)
axes[i].imshow(y[0,0].detach().numpy(), cmap='gray')
axes[i].set_title(f's={s}, p={p}\n{tuple(y.shape)}'); axes[i].axis('off')
plt.tight_layout()
plt.show()
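The output sizes in the titles above follow the standard convolution arithmetic: along each spatial dimension, out = floor((in + 2*padding - kernel) / stride) + 1. A self-contained sketch of that formula (not part of the lecture code):
def conv_out_size(n, k, s=1, p=0):
    # one spatial dimension: floor((n + 2p - k) / s) + 1
    return (n + 2*p - k) // s + 1

for s, p in [(1, 0), (1, 1), (2, 0), (2, 1)]:
    print(f's={s}, p={p}: 28 -> {conv_out_size(28, k=3, s=s, p=p)}')
# s=1,p=0 -> 26; s=1,p=1 -> 28; s=2,p=0 -> 13; s=2,p=1 -> 14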
In [8]:
if tv is not None:
x, _ = train_ds[2]
x = x.unsqueeze(0)
maxpool = nn.MaxPool2d(2,2)
avgpool = nn.AvgPool2d(4,4)
y_max = maxpool(x)
y_avg = avgpool(x)
fig, axes = plt.subplots(1, 3, figsize=(9,3))
    axes[0].imshow(x[0,0].numpy(), cmap='gray'); axes[0].set_title(f'Input\n{x.shape[-2]}x{x.shape[-1]}'); axes[0].axis('off')
    axes[1].imshow(y_max[0,0].detach().numpy(), cmap='gray'); axes[1].set_title(f'MaxPool 2x2\n{y_max.shape[-2]}x{y_max.shape[-1]}'); axes[1].axis('off')
    axes[2].imshow(y_avg[0,0].detach().numpy(), cmap='gray'); axes[2].set_title(f'AvgPool 4x4\n{y_avg.shape[-2]}x{y_avg.shape[-1]}'); axes[2].axis('off')
plt.tight_layout(); plt.show()
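The same output-size formula applies to pooling, but the two operators summarize each window differently: max keeps the strongest activation, average dilutes it. A toy sketch:
t = torch.tensor([[[[1., 0., 2., 0.],
                    [0., 0., 0., 0.],
                    [3., 0., 4., 0.],
                    [0., 0., 0., 0.]]]])  # [1,1,4,4]
print(F.max_pool2d(t, 2))  # [[1., 2.], [3., 4.]] -- each window's peak survives
print(F.avg_pool2d(t, 2))  # [[0.25, 0.50], [0.75, 1.00]] -- peaks are averaged away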
In [9]:
class SmallCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 8, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(8, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64*7*7, 256), # Adjusted based on the output size of the features block
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(256, 10)
)
def forward(self, x):
x = self.features(x)
return self.classifier(x)
model = SmallCNN().to(device)
summary(model, input_size=(1, 1, 28, 28), # (batch, C, H, W)
col_names=("input_size","output_size","num_params","kernel_size"),
depth=4)
Out[9]:
============================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
SmallCNN                                 [1, 1, 28, 28]            [1, 10]                   --                        --
├─Sequential: 1-1                        [1, 1, 28, 28]            [1, 64, 7, 7]             --                        --
│    └─Conv2d: 2-1                       [1, 1, 28, 28]            [1, 8, 28, 28]            80                        [3, 3]
│    └─ReLU: 2-2                         [1, 8, 28, 28]            [1, 8, 28, 28]            --                        --
│    └─MaxPool2d: 2-3                    [1, 8, 28, 28]            [1, 8, 14, 14]            --                        2
│    └─Conv2d: 2-4                       [1, 8, 14, 14]            [1, 64, 14, 14]           4,672                     [3, 3]
│    └─ReLU: 2-5                         [1, 64, 14, 14]           [1, 64, 14, 14]           --                        --
│    └─MaxPool2d: 2-6                    [1, 64, 14, 14]           [1, 64, 7, 7]             --                        2
├─Sequential: 1-2                        [1, 64, 7, 7]             [1, 10]                   --                        --
│    └─Flatten: 2-7                      [1, 64, 7, 7]             [1, 3136]                 --                        --
│    └─Linear: 2-8                       [1, 3136]                 [1, 256]                  803,072                   --
│    └─ReLU: 2-9                         [1, 256]                  [1, 256]                  --                        --
│    └─Dropout: 2-10                     [1, 256]                  [1, 256]                  --                        --
│    └─Linear: 2-11                      [1, 256]                  [1, 10]                   2,570                     --
============================================================================================================================================
Total params: 810,394
Trainable params: 810,394
Non-trainable params: 0
Total mult-adds (M): 1.78
============================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.15
Params size (MB): 3.24
Estimated Total Size (MB): 3.40
============================================================================================================================================
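The parameter counts in the summary can be verified by hand: a Conv2d layer has in_ch * out_ch * k * k weights plus out_ch biases, and a Linear layer has in_features * out_features weights plus out_features biases. A quick check:
print(1*8*3*3 + 8)                 # 80      -- Conv2d 1 -> 8
print(8*64*3*3 + 64)               # 4,672   -- Conv2d 8 -> 64
print(3136*256 + 256)              # 803,072 -- Linear 3136 -> 256
print(256*10 + 10)                 # 2,570   -- Linear 256 -> 10
print(80 + 4672 + 803072 + 2570)   # 810,394 -- matches 'Total params' above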
In [10]:
def train_one_epoch(model, loader, opt, loss_fn):
model.train()
total, correct, running_loss = 0, 0, 0.0
for xb, yb in loader:
xb, yb = xb.to(device), yb.to(device)
opt.zero_grad()
logits = model(xb)
loss = loss_fn(logits, yb)
loss.backward()
opt.step()
running_loss += loss.item()*xb.size(0)
preds = logits.argmax(dim=1)
correct += (preds==yb).sum().item()
total += xb.size(0)
return running_loss/total, correct/total
@torch.no_grad()
def evaluate(model, loader, loss_fn):
model.eval()
total, correct, running_loss = 0, 0, 0.0
for xb, yb in loader:
xb, yb = xb.to(device), yb.to(device)
logits = model(xb)
loss = loss_fn(logits, yb)
running_loss += loss.item()*xb.size(0)
preds = logits.argmax(dim=1)
correct += (preds==yb).sum().item()
total += xb.size(0)
return running_loss/total, correct/total
history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
if tv is not None:
model = SmallCNN().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
EPOCHS = 4
for epoch in range(1, EPOCHS+1):
tl, ta = train_one_epoch(model, train_loader, opt, loss_fn)
vl, va = evaluate(model, test_loader, loss_fn)
history['epoch'].append(epoch)
history['train_loss'].append(tl); history['val_loss'].append(vl)
history['train_acc'].append(ta); history['val_acc'].append(va)
print(f'E{epoch}: train_loss={tl:.4f} val_loss={vl:.4f} train_acc={ta:.3f} val_acc={va:.3f}')
else:
print("torchvision not available. Skipping training.")
E1: train_loss=0.2968 val_loss=0.0817 train_acc=0.910 val_acc=0.975
E2: train_loss=0.0768 val_loss=0.0519 train_acc=0.977 val_acc=0.982
E3: train_loss=0.0545 val_loss=0.0390 train_acc=0.983 val_acc=0.986
E4: train_loss=0.0430 val_loss=0.0330 train_acc=0.987 val_acc=0.989
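A trained model is usually worth checkpointing. This is not part of the lecture flow, but a minimal sketch (the filename smallcnn_mnist.pt is an arbitrary choice):
if tv is not None:
    torch.save(model.state_dict(), 'smallcnn_mnist.pt')
    # reload into a fresh instance and confirm the metrics survive the round trip
    model2 = SmallCNN().to(device)
    model2.load_state_dict(torch.load('smallcnn_mnist.pt', map_location=device))
    model2.eval()
    print('reloaded val acc:', evaluate(model2, test_loader, loss_fn)[1])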
In [11]:
if history and 'train_loss' in history and 'val_loss' in history:
    plt.figure()
    plt.plot(history['epoch'], history['train_loss'])
    plt.plot(history['epoch'], history['val_loss'])
    plt.legend(['train', 'val']); plt.title('Loss'); plt.xlabel('epoch')
    plt.show()
if history and 'train_acc' in history and 'val_acc' in history:
    plt.figure()
    plt.plot(history['epoch'], history['train_acc'])
    plt.plot(history['epoch'], history['val_acc'])
    plt.legend(['train', 'val']); plt.title('Accuracy'); plt.xlabel('epoch')
    plt.show()
In [12]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np
if tv is not None:
model.eval()
all_true, all_pred = [], []
with torch.no_grad():
for xb, yb in test_loader:
xb, yb = xb.to(device), yb.to(device)
logits = model(xb)
preds = logits.argmax(dim=1)
all_pred.extend(preds.cpu().numpy())
all_true.extend(yb.cpu().numpy())
# Calculate accuracy
all_true = np.array(all_true)
all_pred = np.array(all_pred)
accuracy = (all_true == all_pred).sum() / len(all_true)
print(f"Accuracy on test data: {accuracy:.4f}")
# Plot confusion matrix
cm = confusion_matrix(all_true, all_pred, labels=list(range(10)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(range(10)))
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()
else:
print("torchvision not available. Cannot evaluate predictions.")
Accuracy on test data: 0.9886
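The confusion matrix also yields per-class accuracy directly: divide each diagonal entry by its row sum. A short sketch, reusing cm from the cell above:
if tv is not None:
    per_class = cm.diagonal() / cm.sum(axis=1)
    for digit, acc in enumerate(per_class):
        print(f'digit {digit}: accuracy {acc:.3f}')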
In [13]:
alex_w = tv.models.AlexNet_Weights.IMAGENET1K_V1
MODELS = {
"alexnet": SimpleNamespace(
ctor=lambda: tv.models.alexnet(weights=alex_w).to(device).eval(),
weights=alex_w,
act_layers=["features.0", "features.3"], # conv blocks
maxact_layer="features.10",
arch="alexnet"
),
}
In [14]:
from PIL import Image  # needed by load_image below

def load_image(path, weights):
img = Image.open(path).convert('RGB')
return weights.transforms()(img).unsqueeze(0).to(device)
def show(tensor, title=None):
if tensor.ndim == 4:
grid = make_grid(tensor, nrow=int(np.ceil(np.sqrt(tensor.size(0)))))
arr = grid.permute(1,2,0).detach().cpu().numpy()
else:
arr = tensor.permute(1,2,0).detach().cpu().numpy()
plt.figure(figsize=(6,6))
plt.imshow(np.clip(arr, 0, 1))
plt.axis('off')
if title: plt.title(title)
plt.show()
# Robust per-item normalization (avoids tuple-dim min/max issues)
def _norm_per_item(t):
# t shape: [N, ...]
if hasattr(torch, "amin"):
tmin = torch.amin(t, dim=tuple(range(1, t.ndim)), keepdim=True)
tmax = torch.amax(t, dim=tuple(range(1, t.ndim)), keepdim=True)
else:
flat = t.view(t.size(0), -1)
tmin = flat.min(dim=1, keepdim=True)[0].view(-1, *([1]*(t.ndim-1)))
tmax = flat.max(dim=1, keepdim=True)[0].view(-1, *([1]*(t.ndim-1)))
return (t - tmin) / (tmax - tmin + 1e-8)
# Helper: resolve "features.23" dotted path to a module
def resolve_module(root, name):
mod = root
for part in name.split('.'):
if part.isdigit():
mod = mod[int(part)]
else:
mod = getattr(mod, part)
return mod
def first_conv_module(model):
for m in model.modules():
if isinstance(m, nn.Conv2d):
return m
raise RuntimeError("No Conv2d found.")
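As a quick illustration of the helpers above: resolve_module walks a dotted path through both attribute access and integer indexing into nn.Sequential, so 'features.0' on AlexNet resolves to its first conv layer. A sketch (it builds an untrained AlexNet just to inspect the structure):
m = tv.models.alexnet(weights=None)
print(resolve_module(m, 'features.0'))                           # Conv2d(3, 64, kernel_size=(11, 11), ...)
print(first_conv_module(m) is resolve_module(m, 'features.0'))   # True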
In [15]:
def visualize_first_layer_filters(model, max_filters=64, label="model"):
conv1 = None
# Try common entry points first
for attr in ["conv1", "features"]:
if hasattr(model, attr):
m = getattr(model, attr)
if isinstance(m, nn.Conv2d):
conv1 = m
break
# If Sequential, first layer likely Conv2d
if isinstance(m, nn.Sequential):
for x in m:
if isinstance(x, nn.Conv2d):
conv1 = x; break
if conv1 is not None: break
if conv1 is None:
conv1 = first_conv_module(model)
w = conv1.weight.data.clone().cpu() # [out, in, k, k]
w = _norm_per_item(w)
show(w[:max_filters], title=f"{label}: first-layer conv filters")
In [16]:
def visualize_activations(model, img_tensor, layer_names, label="model"):
feats, hooks = {}, []
def hook(name):
return lambda m, i, o: feats.setdefault(name, o.detach().cpu())
# Register hooks
for name in layer_names:
try:
module = resolve_module(model, name)
hooks.append(module.register_forward_hook(hook(name)))
except Exception as e:
print(f"[warn] could not hook '{name}': {e}")
with torch.no_grad():
_ = model(img_tensor)
for h in hooks: h.remove()
for name, feat in feats.items():
fmap = feat[0]
if fmap.ndim != 3:
print(f"[info] {label}:{name} is non-spatial (shape {feat.shape}), skipping grid")
continue
C = min(64, fmap.size(0))
fm = fmap[:C]
fm = _norm_per_item(fm.unsqueeze(1)).squeeze(1)
show(fm.unsqueeze(1), title=f"{label}: activations @ {name}")
In [17]:
@torch.no_grad()
def predict_probs(model, x):
logits = model(x)
return F.softmax(logits, dim=1)
def _resize_heat_with_torch(heat, H, W):
t = torch.from_numpy(heat)[None, None]
t = F.interpolate(t.float(), size=(H, W), mode="bilinear", align_corners=False)
return t[0,0].numpy()
def occlusion_heatmap(model, img_tensor, idx_to_label=None, target_class=None, patch=32, stride=16, baseline=0.0, label="model"):
model.eval()
x = img_tensor.clone()
probs = predict_probs(model, x)[0]
if target_class is None:
target_class = probs.argmax().item()
base_p = probs[target_class].item()
_, _, H, W = x.shape
heat = np.zeros(((H - patch)//stride + 1, (W - patch)//stride + 1), dtype=np.float32)
for i, y in enumerate(range(0, H - patch + 1, stride)):
for j, z in enumerate(range(0, W - patch + 1, stride)):
x_ = x.clone()
x_[:,:, y:y+patch, z:z+patch] = baseline
p = predict_probs(model, x_)[0, target_class].item()
heat[i, j] = base_p - p
heat_resized = _resize_heat_with_torch(heat, H, W)
    # quick unnormalization for display (using ImageNet stats)
im = x[0].detach().cpu()
im = (im * torch.tensor([0.229,0.224,0.225])[:,None,None] + torch.tensor([0.485,0.456,0.406])[:,None,None]).permute(1,2,0).numpy()
plt.figure(figsize=(6,6)); plt.imshow(np.clip(im,0,1)); plt.imshow(heat_resized, alpha=0.5); plt.axis('off')
if idx_to_label:
tname = idx_to_label[target_class]
else:
tname = str(target_class)
plt.title(f"{label}: occlusion (target='{tname}', base p={base_p:.3f})")
plt.show()
return heat_resized
In [18]:
def saliency_map(model, img_tensor, target_class=None, label="model"):
model.eval()
x = img_tensor.clone().requires_grad_(True)
logits = model(x)
if target_class is None:
target_class = logits.argmax(dim=1).item()
loss = logits[0, target_class]
model.zero_grad()
loss.backward()
g = x.grad.detach()[0] # [3,H,W]
sal = g.abs().max(dim=0)[0] # [H,W]
sal = (sal - sal.min())/(sal.max()-sal.min()+1e-8)
plt.figure(figsize=(6,6)); plt.imshow(sal.cpu(), cmap='gray'); plt.axis('off'); plt.title(f"{label}: saliency")
plt.show()
return sal
class GuidedBackpropReLU(nn.Module):
def forward(self, x):
self.saved = x
return F.relu(x)
def backward_hook(self, module, grad_in, grad_out):
positive_grad = torch.clamp(grad_out[0], min=0.0)
positive_mask = (self.saved > 0).float()
return (positive_grad * positive_mask,)
def guided_backprop(model_ctor, weights, img_tensor, target_class=None, label="model"):
# Create a fresh copy to freely patch ReLUs
gb_model = model_ctor().to(device).eval()
# Swap all ReLUs
relus = []
for name, module in gb_model.named_modules():
if isinstance(module, nn.ReLU):
relu = GuidedBackpropReLU()
relus.append(relu)
parent = gb_model
*parents, leaf = name.split('.')
for p in parents:
parent = getattr(parent, p)
setattr(parent, leaf, relu)
x = img_tensor.clone().requires_grad_(True)
logits = gb_model(x)
if target_class is None:
target_class = logits.argmax(dim=1).item()
loss = logits[0, target_class]
gb_model.zero_grad()
hooks = [relu.register_full_backward_hook(relu.backward_hook) for relu in relus]
loss.backward()
for h in hooks: h.remove()
g = x.grad.detach()[0]
g = (g - g.min())/(g.max()-g.min()+1e-8)
g = g.permute(1,2,0).cpu().numpy()
plt.figure(figsize=(6,6)); plt.imshow(g); plt.axis('off'); plt.title(f"{label}: guided backprop")
plt.show()
return g
In [19]:
class FeatExtractor(nn.Module):
"""Return a fixed-dim feature vector (penultimate-ish) for each arch."""
def __init__(self, model, arch):
super().__init__()
self.arch = arch
self.model = model
if arch == "resnet":
# body up to layer4 GAP
self.body = nn.Sequential(
model.conv1, model.bn1, model.relu, model.maxpool,
model.layer1, model.layer2, model.layer3, model.layer4,
nn.AdaptiveAvgPool2d((1,1))
)
self.out_dim = model.fc.in_features
elif arch == "vgg" or arch == "alexnet":
            self.features = model.features
            self.pool = model.avgpool  # AlexNet pools to 6x6, VGG to 7x7; reuse the model's own pool so the classifier input size matches
# classifier: take everything except final Linear
self.prefix = nn.Sequential(*list(model.classifier.children())[:-1])
# out_dim is the in_features of final Linear
last_linear = list(model.classifier.children())[-1]
self.out_dim = last_linear.in_features
else:
raise ValueError("Unknown arch")
def forward(self, x):
if self.arch == "resnet":
x = self.body(x).flatten(1)
return x
else:
x = self.features(x)
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.prefix(x)
return x
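FeatExtractor is not exercised below, and the TSNE import from the first cell is otherwise unused; a hedged sketch that ties them together, embedding a few hundred MNIST test digits through the pretrained AlexNet features (MNIST is resized to 224x224 and repeated to 3 channels to fit AlexNet's input; purely illustrative, and downloads the weights if not cached):
if tv is not None:
    alex = MODELS['alexnet'].ctor()
    fx = FeatExtractor(alex, arch='alexnet').to(device).eval()
    xs = torch.stack([test_ds[i][0] for i in range(256)])   # [256,1,28,28]
    ys = torch.tensor([test_ds[i][1] for i in range(256)])
    xs = F.interpolate(xs, size=(224, 224), mode='bilinear', align_corners=False)
    xs = xs.repeat(1, 3, 1, 1)                              # grayscale -> fake RGB
    with torch.no_grad():
        feats = fx(xs.to(device)).cpu().numpy()             # [256, 4096]
    emb = TSNE(n_components=2, init='pca', perplexity=30).fit_transform(feats)
    plt.figure(figsize=(5, 5))
    plt.scatter(emb[:, 0], emb[:, 1], c=ys.numpy(), cmap='tab10', s=10)
    plt.colorbar(); plt.title('t-SNE of AlexNet features on MNIST digits'); plt.show()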
In [20]:
def max_activating_images(model, dataset, layer_name, topk=16, label="model"):
target = resolve_module(model, layer_name)
acts = []
imgs_cache = []
def fhook(m, i, o):
if o.ndim == 4:
a = o.detach().cpu().mean(dim=(2,3)) # GAP over H,W → [B, C]
else:
a = o.detach().cpu()
acts.append(a)
h = target.register_forward_hook(fhook)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2)
with torch.no_grad():
for xb, yb in loader:
imgs_cache.append(xb)
_ = model(xb.to(device))
h.remove()
A = torch.cat(acts, 0).numpy() # [N, C]
imgs_cache = torch.cat(imgs_cache, 0)
# Choose an arbitrary channel to inspect (customize this)
channel = min(5, A.shape[1]-1)
idxs = np.argsort(-A[:, channel])[:topk]
grid = imgs_cache[idxs]
# unnormalize for viewing (ImageNet stats)
grid = grid*torch.tensor([0.229,0.224,0.225])[None,:,None,None] + torch.tensor([0.485,0.456,0.406])[None,:,None,None]
grid = grid.clamp(0,1)
show(grid, title=f"{label}: top-{topk} images for channel {channel} @ {layer_name}")
In [21]:
img_path = "pishi.png"
models = {}
for name, cfg in MODELS.items():
m = cfg.ctor()
models[name] = SimpleNamespace(
model=m, weights=cfg.weights, act_layers=cfg.act_layers,
maxact_layer=cfg.maxact_layer, arch=cfg.arch,
idx_to_label=cfg.weights.meta.get("categories", None)
)
# Ensure the input image is resized to 224x224
images = {name: load_image(img_path, cfg.weights) for name, cfg in models.items()}
for name, img in images.items():
assert img.shape[-2:] == (224, 224), f"Image for model {name} is not resized to 224x224"
# 1) First-layer filters comparison
for name, cfg in models.items():
visualize_first_layer_filters(cfg.model, max_filters=64, label=name)
# 2) Activation maps at key layers
for name, cfg in models.items():
visualize_activations(cfg.model, images[name], cfg.act_layers, label=name)
# 3) Occlusion sensitivity (same target class per model by default)
for name, cfg in models.items():
_ = occlusion_heatmap(cfg.model, images[name], idx_to_label=cfg.idx_to_label, patch=32, stride=16, label=name)
# 4) Saliency and Guided Backprop
for name, cfg in models.items():
_ = saliency_map(cfg.model, images[name], label=name)
_ = guided_backprop(MODELS[name].ctor, cfg.weights, images[name], label=name)
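max_activating_images is defined above but never invoked; it needs an ImageNet-style dataset on disk. A hedged sketch of the call, assuming a hypothetical image folder data/imgs laid out for ImageFolder:
# 5) Max-activating images (the path 'data/imgs' is hypothetical)
if os.path.isdir('data/imgs'):
    ds = tv.datasets.ImageFolder('data/imgs', transform=alex_w.transforms())
    max_activating_images(models['alexnet'].model, ds,
                          MODELS['alexnet'].maxact_layer, topk=16, label='alexnet')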