Lecture 18 – CS 189, Fall 2025

In [1]:
import os, math, itertools, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from types import SimpleNamespace
from sklearn.manifold import TSNE

# torchvision is optional: keep all of its imports inside the try so a missing
# install only disables the image demos instead of crashing the notebook.
try:
    import torchvision as tv
    from torchvision import transforms
    from torchvision.utils import make_grid
except Exception:
    tv = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cpu
In [2]:
# Load the dataset
if tv is None:
    print("torchvision not available. Please install torchvision to run MNIST demos.")
else:
    transform = transforms.Compose([transforms.ToTensor()])  # Keep MNIST in [0,1], single-channel
    train_ds = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_ds  = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    test_loader  = DataLoader(test_ds, batch_size=256, shuffle=False)
    print('Train size:', len(train_ds), ' Test size:', len(test_ds))
Train size: 60000  Test size: 10000
In [3]:
# Peek at MNIST
import matplotlib.pyplot as plt

if tv is not None:
    imgs = [train_ds[i][0] for i in range(16)]
    labels = [train_ds[i][1] for i in range(16)]
    fig, axes = plt.subplots(4,4, figsize=(4,4))
    for ax, img, lab in zip(axes.flatten(), imgs, labels):
        ax.imshow(img[0].numpy(), cmap='gray')
        ax.set_title(str(lab))
        ax.axis('off')
    plt.tight_layout(); plt.show()
[Figure: 4x4 grid of sample MNIST digits with their labels]
In [4]:
# Identify the first 10 images of the digit 9
imgs_9 = [train_ds[i][0] for i in range(len(train_ds)) if train_ds[i][1] == 9][:10]

# Plot the images
fig, axes = plt.subplots(1, 10, figsize=(15, 5))
for i, img in enumerate(imgs_9):
    ax = axes[i]
    ax.imshow(img[0].numpy(), cmap='gray')  
    ax.axis('off')
    ax.set_title("9")
plt.tight_layout()
plt.show()
[Figure: the first 10 training images of the digit 9]
In [5]:
x = imgs_9[4].unsqueeze(0)  

# Define convolution kernels
kernels = {
    'horizontal line': torch.tensor([[[-1, -1, -1], [2, 2, 2], [-1, -1, -1]]], dtype=torch.float32).unsqueeze(0),
    'vertical line': torch.tensor([[[-1, 2, -1], [-1, 2, -1], [-1, 2, -1]]], dtype=torch.float32).unsqueeze(0),
    'diagonal line': torch.tensor([[[2, -1, -1], [-1, 2, -1], [-1, -1, 2]]], dtype=torch.float32).unsqueeze(0)
}

# Apply convolutions and visualize
fig, axes = plt.subplots(len(kernels), 3, figsize=(9, len(kernels) * 3))

for i, (name, kernel) in enumerate(kernels.items()):
    # Display the kernel
    axes[i, 0].imshow(kernel.squeeze().detach().numpy(), cmap='gray')
    axes[i, 0].set_title(f'Kernel: {name}')
    axes[i, 0].axis('off')

    # Display the input image
    axes[i, 1].imshow(x.squeeze().detach().numpy(), cmap='gray')
    axes[i, 1].set_title('Input Image')
    axes[i, 1].axis('off')

    # Display the result of the convolution
    y = F.conv2d(x, kernel, padding=1)
    axes[i, 2].imshow(y.squeeze().detach().numpy(), cmap='gray')
    axes[i, 2].set_title('Convolved Output')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()
[Figure: for each kernel (horizontal, vertical, diagonal line), the kernel, the input digit, and the convolved output]
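One detail the cell above glosses over: F.conv2d actually computes a cross-correlation (the kernel is not flipped). All three kernels above are symmetric under a 180° rotation, so it makes no difference here, but for an asymmetric kernel the two operations differ. A quick check, using a Sobel-style kernel purely as an illustration:

# F.conv2d is cross-correlation; a true convolution flips the kernel first.
k_asym = torch.tensor([[[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]]).unsqueeze(0)  # [1,1,3,3], not 180°-symmetric
xcorr = F.conv2d(x, k_asym, padding=1)
conv  = F.conv2d(x, torch.flip(k_asym, dims=(2, 3)), padding=1)
print(torch.allclose(xcorr, conv))  # False: flipping this kernel negates it, so the outputs differ by a sign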
In [6]:
if tv is not None:
    x, y = [train_ds[i] for i in range(len(train_ds)) if train_ds[i][1] == 9][4]   # x: [1,28,28]
    x = x.unsqueeze(0)  # [1,1,28,28]
    
    # Define classic kernels (normalized where reasonable)
    kernels = {
        'identity': torch.tensor([[0,0,0],[0,1,0],[0,0,0]], dtype=torch.float32),
        'edge_h':  torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=torch.float32),
        'edge_v':  torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=torch.float32),
        'sharpen': torch.tensor([[0,-1,0],[-1,5,-1],[0,-1,0]], dtype=torch.float32),
        'box_blur': (1/9.0)*torch.ones((3,3), dtype=torch.float32)
    }

    fig, axes = plt.subplots(2, 3, figsize=(6,4))
    axes = axes.flatten()

    axes[0].imshow(x[0,0].numpy(), cmap='gray'); axes[0].set_title('Input'); axes[0].axis('off')

    i = 1
    for name, K in kernels.items():
        W = K.view(1,1,3,3)
        yk = F.conv2d(x, W, padding=1)  # keep size with padding=1
        axes[i].imshow(yk[0,0].detach().numpy(), cmap='gray')
        axes[i].set_title(name); axes[i].axis('off')
        i += 1
        if i >= len(axes):
            break

    plt.tight_layout(); plt.show()
else:
    print("torchvision not available.")
[Figure: input digit alongside the identity, edge_h, edge_v, sharpen, and box_blur outputs]
In [7]:
if tv is not None:
    x, _ = train_ds[1]
    x = x.unsqueeze(0)
    k = 3
    W = torch.ones(1,1,k,k) / (k*k)

    configs = [(1,0), (1,1), (2,0), (2,1)]  # (stride, padding)
    fig, axes = plt.subplots(1, len(configs)+1, figsize=(3*(len(configs)+1), 3))
    axes[0].imshow(x[0,0].numpy(), cmap='gray'); axes[0].set_title('Input'); axes[0].axis('off')

    for i, (s,p) in enumerate(configs, start=1):
        y = F.conv2d(x, W, stride=s, padding=p)
        axes[i].imshow(y[0,0].detach().numpy(), cmap='gray')
        axes[i].set_title(f's={s}, p={p}\n{tuple(y.shape)}'); axes[i].axis('off')

    plt.tight_layout() 
    plt.show()
[Figure: input digit and box-blur outputs for each (stride, padding) configuration, with output shapes in the titles]
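The output sizes in the subplot titles follow the usual formula H_out = floor((H + 2p - k) / s) + 1. A quick check against F.conv2d, reusing x, W, k, and configs from the cell above:

# Verify the conv output-size formula against F.conv2d for each (stride, padding) pair.
H = x.shape[-1]  # 28 for MNIST
for s, p in configs:
    h_out = (H + 2 * p - k) // s + 1
    y = F.conv2d(x, W, stride=s, padding=p)
    print(f'stride={s}, padding={p}: formula={h_out}, conv2d={y.shape[-1]}')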
In [8]:
if tv is not None:
    x, _ = train_ds[2]
    x = x.unsqueeze(0)
    maxpool = nn.MaxPool2d(2,2)
    avgpool = nn.AvgPool2d(4,4)

    y_max = maxpool(x)
    y_avg = avgpool(x)

    fig, axes = plt.subplots(1, 3, figsize=(9,3))
    axes[0].imshow(x[0,0].numpy(), cmap='gray'); axes[0].set_title(f'{x.shape[-2]}x{x.shape[-1]}\nInput'); axes[0].axis('off')
    axes[1].imshow(y_max[0,0].detach().numpy(), cmap='gray'); axes[1].set_title(f'{y_max.shape[-2]}x{y_max.shape[-1]}\nMaxPool 2x2'); axes[1].axis('off')
    axes[2].imshow(y_avg[0,0].detach().numpy(), cmap='gray'); axes[2].set_title(f'{y_avg.shape[-2]}x{y_avg.shape[-1]}\nAvgPool 4x4'); axes[2].axis('off')
    plt.tight_layout(); plt.show()
[Figure: input digit, 2x2 max-pooled output, and 4x4 average-pooled output]
In [9]:
class SmallCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),  
            nn.ReLU(),
            nn.MaxPool2d(2),                 
            nn.Conv2d(8, 64, 3, padding=1), 
            nn.ReLU(),
            nn.MaxPool2d(2),               
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*7*7, 256),  # Adjusted based on the output size of the features block
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 10)
        )
    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

model = SmallCNN().to(device)
summary(model, input_size=(1, 1, 28, 28),  # (batch, C, H, W)
        col_names=("input_size","output_size","num_params","kernel_size"),
        depth=4)
Out[9]:
============================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
============================================================================================================================================
SmallCNN                                 [1, 1, 28, 28]            [1, 10]                   --                        --
├─Sequential: 1-1                        [1, 1, 28, 28]            [1, 64, 7, 7]             --                        --
│    └─Conv2d: 2-1                       [1, 1, 28, 28]            [1, 8, 28, 28]            80                        [3, 3]
│    └─ReLU: 2-2                         [1, 8, 28, 28]            [1, 8, 28, 28]            --                        --
│    └─MaxPool2d: 2-3                    [1, 8, 28, 28]            [1, 8, 14, 14]            --                        2
│    └─Conv2d: 2-4                       [1, 8, 14, 14]            [1, 64, 14, 14]           4,672                     [3, 3]
│    └─ReLU: 2-5                         [1, 64, 14, 14]           [1, 64, 14, 14]           --                        --
│    └─MaxPool2d: 2-6                    [1, 64, 14, 14]           [1, 64, 7, 7]             --                        2
├─Sequential: 1-2                        [1, 64, 7, 7]             [1, 10]                   --                        --
│    └─Flatten: 2-7                      [1, 64, 7, 7]             [1, 3136]                 --                        --
│    └─Linear: 2-8                       [1, 3136]                 [1, 256]                  803,072                   --
│    └─ReLU: 2-9                         [1, 256]                  [1, 256]                  --                        --
│    └─Dropout: 2-10                     [1, 256]                  [1, 256]                  --                        --
│    └─Linear: 2-11                      [1, 256]                  [1, 10]                   2,570                     --
============================================================================================================================================
Total params: 810,394
Trainable params: 810,394
Non-trainable params: 0
Total mult-adds (M): 1.78
============================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.15
Params size (MB): 3.24
Estimated Total Size (MB): 3.40
============================================================================================================================================
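The Linear(64*7*7, 256) flatten size can be double-checked with a throwaway forward pass through the features block (a quick sanity check, not part of the original cell):

# Confirm that features() maps a 28x28 input to [64, 7, 7] -> 3136 flattened features.
with torch.no_grad():
    feat = model.features(torch.zeros(1, 1, 28, 28, device=device))
print(feat.shape, '->', feat.flatten(1).shape[1])  # torch.Size([1, 64, 7, 7]) -> 3136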
In [10]:
def train_one_epoch(model, loader, opt, loss_fn):
    model.train()
    total, correct, running_loss = 0, 0, 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        opt.step()
        running_loss += loss.item()*xb.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds==yb).sum().item()
        total += xb.size(0)
    return running_loss/total, correct/total

@torch.no_grad()
def evaluate(model, loader, loss_fn):
    model.eval()
    total, correct, running_loss = 0, 0, 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = loss_fn(logits, yb)
        running_loss += loss.item()*xb.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds==yb).sum().item()
        total += xb.size(0)
    return running_loss/total, correct/total

history = {'epoch': [], 'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
if tv is not None:
    model = SmallCNN().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    EPOCHS = 4 
    for epoch in range(1, EPOCHS+1):
        tl, ta = train_one_epoch(model, train_loader, opt, loss_fn)
        vl, va = evaluate(model, test_loader, loss_fn)
        history['epoch'].append(epoch)
        history['train_loss'].append(tl); history['val_loss'].append(vl)
        history['train_acc'].append(ta);  history['val_acc'].append(va)
        print(f'E{epoch}: train_loss={tl:.4f} val_loss={vl:.4f} train_acc={ta:.3f} val_acc={va:.3f}')
else:
    print("torchvision not available. Skipping training.")
E1: train_loss=0.2968 val_loss=0.0817 train_acc=0.910 val_acc=0.975
E2: train_loss=0.0768 val_loss=0.0519 train_acc=0.977 val_acc=0.982
E3: train_loss=0.0545 val_loss=0.0390 train_acc=0.983 val_acc=0.986
E4: train_loss=0.0430 val_loss=0.0330 train_acc=0.987 val_acc=0.989
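To reuse the trained weights later (e.g., for the evaluation and visualization cells) without retraining, a standard state_dict checkpoint is enough; the filename here is arbitrary:

# Optional: persist/restore the trained weights (filename is arbitrary).
torch.save(model.state_dict(), 'small_cnn_mnist.pt')
# model.load_state_dict(torch.load('small_cnn_mnist.pt', map_location=device))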
In [11]:
if history and 'train_loss' in history and 'val_loss' in history:
    plt.figure(); plt.plot(history['epoch'], history['train_loss']); plt.plot(history['epoch'], history['val_loss']); plt.legend(['train','val']); plt.title('Loss'); plt.xlabel('epoch'); plt.show()
if history and 'train_acc' in history and 'val_acc' in history:
    plt.figure(); plt.plot(history['epoch'], history['train_acc']); plt.plot(history['epoch'], history['val_acc']); plt.legend(['train','val']); plt.title('Accuracy'); plt.xlabel('epoch'); plt.show()
[Figures: training/validation loss and accuracy versus epoch]
In [12]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np

if tv is not None:
    model.eval()
    all_true, all_pred = [], []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            preds = logits.argmax(dim=1)
            all_pred.extend(preds.cpu().numpy())
            all_true.extend(yb.cpu().numpy())

    # Calculate accuracy
    all_true = np.array(all_true)
    all_pred = np.array(all_pred)
    accuracy = (all_true == all_pred).sum() / len(all_true)
    print(f"Accuracy on test data: {accuracy:.4f}")

    # Plot confusion matrix
    cm = confusion_matrix(all_true, all_pred, labels=list(range(10)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(range(10)))
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()
else:
    print("torchvision not available. Cannot evaluate predictions.")
Accuracy on test data: 0.9886
[Figure: 10x10 confusion matrix on the MNIST test set]
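Because the test loader uses shuffle=False, all_true and all_pred line up with test_ds indices, so the few errors behind the ~98.9% accuracy can be inspected directly. A minimal sketch:

# Show a few misclassified test digits; indices match test_ds since shuffle=False.
wrong = np.where(all_true != all_pred)[0][:8]
fig, axes = plt.subplots(1, len(wrong), figsize=(2 * len(wrong), 2))
for ax, idx in zip(np.atleast_1d(axes), wrong):
    ax.imshow(test_ds[idx][0][0].numpy(), cmap='gray')
    ax.set_title(f'true {all_true[idx]}\npred {all_pred[idx]}')
    ax.axis('off')
plt.tight_layout(); plt.show()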
In [13]:
alex_w   = tv.models.AlexNet_Weights.IMAGENET1K_V1

MODELS = {
    "alexnet": SimpleNamespace(
        ctor=lambda: tv.models.alexnet(weights=alex_w).to(device).eval(),
        weights=alex_w,
        act_layers=["features.0", "features.3"],  # conv blocks
        maxact_layer="features.10",
        arch="alexnet"
    ),
}
In [14]:
from PIL import Image  # needed by load_image below

def load_image(path, weights):
    img = Image.open(path).convert('RGB')
    return weights.transforms()(img).unsqueeze(0).to(device)

def show(tensor, title=None):
    if tensor.ndim == 4:
        grid = make_grid(tensor, nrow=int(np.ceil(np.sqrt(tensor.size(0)))))
        arr = grid.permute(1,2,0).detach().cpu().numpy()
    else:
        arr = tensor.permute(1,2,0).detach().cpu().numpy()
    plt.figure(figsize=(6,6))
    plt.imshow(np.clip(arr, 0, 1))
    plt.axis('off')
    if title: plt.title(title)
    plt.show()

# Robust per-item normalization (avoids tuple-dim min/max issues)
def _norm_per_item(t):
    # t shape: [N, ...]
    if hasattr(torch, "amin"):
        tmin = torch.amin(t, dim=tuple(range(1, t.ndim)), keepdim=True)
        tmax = torch.amax(t, dim=tuple(range(1, t.ndim)), keepdim=True)
    else:
        flat = t.view(t.size(0), -1)
        tmin = flat.min(dim=1, keepdim=True)[0].view(-1, *([1]*(t.ndim-1)))
        tmax = flat.max(dim=1, keepdim=True)[0].view(-1, *([1]*(t.ndim-1)))
    return (t - tmin) / (tmax - tmin + 1e-8)

# Helper: resolve "features.23" dotted path to a module
def resolve_module(root, name):
    mod = root
    for part in name.split('.'):
        if part.isdigit():
            mod = mod[int(part)]
        else:
            mod = getattr(mod, part)
    return mod

def first_conv_module(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            return m
    raise RuntimeError("No Conv2d found.")
In [15]:
def visualize_first_layer_filters(model, max_filters=64, label="model"):
    conv1 = None
    # Try common entry points first
    for attr in ["conv1", "features"]:
        if hasattr(model, attr):
            m = getattr(model, attr)
            if isinstance(m, nn.Conv2d):
                conv1 = m
                break
            # If Sequential, first layer likely Conv2d
            if isinstance(m, nn.Sequential):
                for x in m:
                    if isinstance(x, nn.Conv2d):
                        conv1 = x; break
        if conv1 is not None: break
    if conv1 is None:
        conv1 = first_conv_module(model)

    w = conv1.weight.data.clone().cpu()  # [out, in, k, k]
    w = _norm_per_item(w)
    show(w[:max_filters], title=f"{label}: first-layer conv filters")
In [16]:
def visualize_activations(model, img_tensor, layer_names, label="model"):
    feats, hooks = {}, []
    def hook(name): 
        return lambda m, i, o: feats.setdefault(name, o.detach().cpu())
    # Register hooks
    for name in layer_names:
        try:
            module = resolve_module(model, name)
            hooks.append(module.register_forward_hook(hook(name)))
        except Exception as e:
            print(f"[warn] could not hook '{name}': {e}")

    with torch.no_grad():
        _ = model(img_tensor)
    for h in hooks: h.remove()

    for name, feat in feats.items():
        fmap = feat[0]                    
        if fmap.ndim != 3:
            print(f"[info] {label}:{name} is non-spatial (shape {feat.shape}), skipping grid")
            continue
        C = min(64, fmap.size(0))
        fm = fmap[:C]
        fm = _norm_per_item(fm.unsqueeze(1)).squeeze(1)
        show(fm.unsqueeze(1), title=f"{label}: activations @ {name}")
In [17]:
@torch.no_grad()
def predict_probs(model, x):
    logits = model(x)
    return F.softmax(logits, dim=1)

def _resize_heat_with_torch(heat, H, W):
    t = torch.from_numpy(heat)[None, None]
    t = F.interpolate(t.float(), size=(H, W), mode="bilinear", align_corners=False)
    return t[0,0].numpy()

def occlusion_heatmap(model, img_tensor, idx_to_label=None, target_class=None, patch=32, stride=16, baseline=0.0, label="model"):
    model.eval()
    x = img_tensor.clone()
    probs = predict_probs(model, x)[0]
    if target_class is None:
        target_class = probs.argmax().item()
    base_p = probs[target_class].item()

    _, _, H, W = x.shape
    heat = np.zeros(((H - patch)//stride + 1, (W - patch)//stride + 1), dtype=np.float32)

    for i, y in enumerate(range(0, H - patch + 1, stride)):
        for j, z in enumerate(range(0, W - patch + 1, stride)):
            x_ = x.clone()
            x_[:,:, y:y+patch, z:z+patch] = baseline
            p = predict_probs(model, x_)[0, target_class].item()
            heat[i, j] = base_p - p

    heat_resized = _resize_heat_with_torch(heat, H, W)
    # quick unnormalize for show (using Imagenet stats)
    im = x[0].detach().cpu()
    im = (im * torch.tensor([0.229,0.224,0.225])[:,None,None] + torch.tensor([0.485,0.456,0.406])[:,None,None]).permute(1,2,0).numpy()
    plt.figure(figsize=(6,6)); plt.imshow(np.clip(im,0,1)); plt.imshow(heat_resized, alpha=0.5); plt.axis('off')
    if idx_to_label:
        tname = idx_to_label[target_class]
    else:
        tname = str(target_class)
    plt.title(f"{label}: occlusion (target='{tname}', base p={base_p:.3f})")
    plt.show()
    return heat_resized
In [18]:
def saliency_map(model, img_tensor, target_class=None, label="model"):
    model.eval()
    x = img_tensor.clone().requires_grad_(True)
    logits = model(x)
    if target_class is None:
        target_class = logits.argmax(dim=1).item()
    loss = logits[0, target_class]
    model.zero_grad()
    loss.backward()
    g = x.grad.detach()[0]               # [3,H,W]
    sal = g.abs().max(dim=0)[0]          # [H,W]
    sal = (sal - sal.min())/(sal.max()-sal.min()+1e-8)
    plt.figure(figsize=(6,6)); plt.imshow(sal.cpu(), cmap='gray'); plt.axis('off'); plt.title(f"{label}: saliency")
    plt.show()
    return sal

class GuidedBackpropReLU(nn.Module):
    def forward(self, x):
        self.saved = x
        return F.relu(x)
    def backward_hook(self, module, grad_in, grad_out):
        positive_grad = torch.clamp(grad_out[0], min=0.0)
        positive_mask = (self.saved > 0).float()
        return (positive_grad * positive_mask,)

def guided_backprop(model_ctor, weights, img_tensor, target_class=None, label="model"):
    # Create a fresh copy to freely patch ReLUs
    gb_model = model_ctor().to(device).eval()
    # Swap all ReLUs
    relus = []
    for name, module in gb_model.named_modules():
        if isinstance(module, nn.ReLU):
            relu = GuidedBackpropReLU()
            relus.append(relu)
            parent = gb_model
            *parents, leaf = name.split('.')
            for p in parents:
                parent = getattr(parent, p)
            setattr(parent, leaf, relu)
    x = img_tensor.clone().requires_grad_(True)
    logits = gb_model(x)
    if target_class is None:
        target_class = logits.argmax(dim=1).item()
    loss = logits[0, target_class]
    gb_model.zero_grad()
    hooks = [relu.register_full_backward_hook(relu.backward_hook) for relu in relus]
    loss.backward()
    for h in hooks: h.remove()

    g = x.grad.detach()[0]
    g = (g - g.min())/(g.max()-g.min()+1e-8)
    g = g.permute(1,2,0).cpu().numpy()
    plt.figure(figsize=(6,6)); plt.imshow(g); plt.axis('off'); plt.title(f"{label}: guided backprop")
    plt.show()
    return g
In [19]:
class FeatExtractor(nn.Module):
    """Return a fixed-dim feature vector (penultimate-ish) for each arch."""
    def __init__(self, model, arch):
        super().__init__()
        self.arch = arch
        self.model = model
        if arch == "resnet":
            # body up to layer4 GAP
            self.body = nn.Sequential(
                model.conv1, model.bn1, model.relu, model.maxpool,
                model.layer1, model.layer2, model.layer3, model.layer4,
                nn.AdaptiveAvgPool2d((1,1))
            )
            self.out_dim = model.fc.in_features
        elif arch == "vgg" or arch == "alexnet":
            self.features = model.features
            self.pool = nn.AdaptiveAvgPool2d((7,7))  # match VGG/Alex input to classifier
            # classifier: take everything except final Linear
            self.prefix = nn.Sequential(*list(model.classifier.children())[:-1])
            # out_dim is the in_features of final Linear
            last_linear = list(model.classifier.children())[-1]
            self.out_dim = last_linear.in_features
        else:
            raise ValueError("Unknown arch")
    def forward(self, x):
        if self.arch == "resnet":
            x = self.body(x).flatten(1)
            return x
        else:
            x = self.features(x)
            x = self.pool(x)
            x = torch.flatten(x, 1)
            x = self.prefix(x)
            return x
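FeatExtractor is not called in the cells shown here, but combined with the TSNE import at the top it supports a 2-D embedding of penultimate features. A minimal sketch, assuming a `loader` that yields (image, label) batches already preprocessed with the model's own transforms:

# Sketch only: 2-D t-SNE of penultimate features (adjust `loader` to your data).
def plot_feature_tsne(extractor, loader, max_batches=4):
    extractor.eval()
    feats, labels = [], []
    with torch.no_grad():
        for b, (xb, yb) in enumerate(loader):
            if b >= max_batches:
                break
            feats.append(extractor(xb.to(device)).cpu())
            labels.append(yb)
    X = torch.cat(feats).numpy()
    y = torch.cat(labels).numpy()
    emb = TSNE(n_components=2, init='pca', perplexity=min(30, len(X) - 1)).fit_transform(X)
    plt.figure(figsize=(5, 5))
    plt.scatter(emb[:, 0], emb[:, 1], c=y, cmap='tab10', s=10)
    plt.title('t-SNE of extracted features')
    plt.show()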
In [20]:
def max_activating_images(model, dataset, layer_name, topk=16, label="model"):
    target = resolve_module(model, layer_name)
    acts = []
    imgs_cache = []
    def fhook(m, i, o):
        if o.ndim == 4:
            a = o.detach().cpu().mean(dim=(2,3))  # GAP over H,W → [B, C]
        else:
            a = o.detach().cpu()
        acts.append(a)
    h = target.register_forward_hook(fhook)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2)
    with torch.no_grad():
        for xb, yb in loader:
            imgs_cache.append(xb)
            _ = model(xb.to(device))
    h.remove()
    A = torch.cat(acts, 0).numpy()      # [N, C]
    imgs_cache = torch.cat(imgs_cache, 0)
    # Choose an arbitrary channel to inspect (customize this)
    channel = min(5, A.shape[1]-1)
    idxs = np.argsort(-A[:, channel])[:topk]
    grid = imgs_cache[idxs]
    # unnormalize for viewing (ImageNet stats)
    grid = grid*torch.tensor([0.229,0.224,0.225])[None,:,None,None] + torch.tensor([0.485,0.456,0.406])[None,:,None,None]
    grid = grid.clamp(0,1)
    show(grid, title=f"{label}: top-{topk} images for channel {channel} @ {layer_name}")
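max_activating_images is likewise only defined here; a hypothetical call might look like the commented sketch below, where './some_images' is a placeholder folder arranged for ImageFolder:

# Hypothetical usage (placeholder path; any ImageNet-preprocessed dataset works):
# ds = tv.datasets.ImageFolder('./some_images', transform=alex_w.transforms())
# max_activating_images(MODELS['alexnet'].ctor(), ds, MODELS['alexnet'].maxact_layer, topk=16, label='alexnet')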
In [21]:
img_path = "pishi.png" 
from PIL import Image

models = {}
for name, cfg in MODELS.items():
    m = cfg.ctor()
    models[name] = SimpleNamespace(
        model=m, weights=cfg.weights, act_layers=cfg.act_layers,
        maxact_layer=cfg.maxact_layer, arch=cfg.arch,
        idx_to_label=cfg.weights.meta.get("categories", None)
    )

# Ensure the input image is resized to 224x224
images = {name: load_image(img_path, cfg.weights) for name, cfg in models.items()}
for name, img in images.items():
    assert img.shape[-2:] == (224, 224), f"Image for model {name} is not resized to 224x224"

# 1) First-layer filters comparison
for name, cfg in models.items():
    visualize_first_layer_filters(cfg.model, max_filters=64, label=name)

# 2) Activation maps at key layers
for name, cfg in models.items():
    visualize_activations(cfg.model, images[name], cfg.act_layers, label=name)

# 3) Occlusion sensitivity (same target class per model by default)
for name, cfg in models.items():
    _ = occlusion_heatmap(cfg.model, images[name], idx_to_label=cfg.idx_to_label, patch=32, stride=16, label=name)

# 4) Saliency and Guided Backprop
for name, cfg in models.items():
    _ = saliency_map(cfg.model, images[name], label=name)
    _ = guided_backprop(MODELS[name].ctor, cfg.weights, images[name], label=name)
[Figures: AlexNet first-layer filters; activation maps at features.0 and features.3; occlusion sensitivity heatmap; saliency map; guided-backprop visualization]