Using PyTorch to Investigate Catastrophic Forgetting in Continual Learning

I’ve been working on this for a while, and I want to start writing more about PyTorch. One topic that has been taking up a lot of my reading time these days is catastrophic forgetting, so let’s dive into it. Catastrophic forgetting is a well-documented failure mode in artificial neural networks where previously learned knowledge is rapidly overwritten when a model is trained on new tasks. This phenomenon presents a major obstacle for systems intended to perform continual or lifelong learning. While human learning appears to consolidate past experiences in ways that allow for incremental acquisition of new knowledge (a huge fucking maybe here, btw; a lot of this is a deep maybe), deep learning systems, especially those trained with stochastic gradient descent, lack native mechanisms for preserving older knowledge. In this article, we explore how PyTorch can be used to simulate and mitigate this effect using a controlled experiment involving disjoint classification tasks and a technique called Elastic Weight Consolidation (EWC).

Why do you care? Recently at work, my boss had to explain in detail how my company makes sure that no data is left behind when processing nodes are reused. That really got me thinking about this…

We construct a continual learning environment using MNIST by creating two disjoint tasks: one involving classification of digits 0 through 4 and another involving digits 5 through 9. The dataset is filtered using torchvision utilities to extract samples corresponding to each task. A shared multilayer perceptron model is defined in PyTorch using two fully connected hidden layers followed by a single classification head, allowing us to isolate the effects of sequential training on a common representation space. The model is first trained exclusively on Task A using standard cross-entropy loss and Adam optimization. Performance is evaluated on Task A using a held-out test set. Following this, the model is trained on Task B without revisiting Task A, and evaluation is repeated on both tasks.
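
Here’s a minimal sketch of the task split, assuming the standard torchvision MNIST dataset; the helper name make_task_loader and the batch size are my own arbitrary choices:

import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

def make_task_loader(classes, train=True, batch_size=128):
    # Filter MNIST down to the digits belonging to one task.
    ds = datasets.MNIST(root="./data", train=train, download=True,
                        transform=transforms.ToTensor())
    idx = [i for i, t in enumerate(ds.targets) if int(t) in classes]
    return DataLoader(Subset(ds, idx), batch_size=batch_size, shuffle=train)

task_a_train = make_task_loader(set(range(5)))              # digits 0-4
task_a_test  = make_task_loader(set(range(5)), train=False)
task_b_train = make_task_loader(set(range(5, 10)))          # digits 5-9
task_b_test  = make_task_loader(set(range(5, 10)), train=False)

The shared model itself is a plain two-hidden-layer MLP with a single 10-way head: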

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Two shared hidden layers and a single classification head
        # covering all 10 digits across both tasks.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 28x28 images to 784-dim vectors
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)
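
The training and evaluation loops are unremarkable. Here’s a rough sketch of the sequential experiment; the Adam learning rate and epoch count are assumptions (tune to taste), and extra_loss is a hypothetical hook we’ll use for the EWC penalty below:

import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # assumed default lr

def train_task(loader, epochs=3, extra_loss=None):
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(x), y)
            if extra_loss is not None:   # hook for the EWC penalty later
                loss = loss + extra_loss()
            loss.backward()
            optimizer.step()

@torch.no_grad()
def evaluate(loader):
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return correct / total

train_task(task_a_train)
print("Task A accuracy:", evaluate(task_a_test))
train_task(task_b_train)                 # naive sequential training
print("Task A accuracy after Task B:", evaluate(task_a_test))
print("Task B accuracy:", evaluate(task_b_test))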

As expected, the model exhibits catastrophic forgetting: accuracy on Task A degrades sharply after Task B training, despite the underlying model architecture remaining unchanged. This result matches the conventional understanding that deep networks, when trained naively on non-overlapping tasks, tend to overwrite the internal representations the earlier task relied on. To counteract this, we implement Elastic Weight Consolidation, which penalizes updates to parameters deemed important for previously learned tasks. The first ingredient is a diagonal Fisher information estimate:

def compute_fisher(model, dataloader):
    # Estimate the diagonal Fisher information as the average squared
    # gradient of the loss over the given data.
    model.eval()
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        out = model(x)
        loss = F.cross_entropy(out, y)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.data.pow(2)
    for n in fisher:
        fisher[n] /= len(dataloader)  # average over batches
    return fisher

To apply EWC, we compute the Fisher Information Matrix for the model parameters after Task A training. This is done by accumulating the squared gradients of the loss with respect to each parameter, averaged over samples from Task A. The Fisher matrix serves as a proxy for parameter importance—those parameters with large entries are assumed to play a critical role in preserving Task A performance. When training on Task B, an additional term is added to the loss function that penalizes the squared deviation of each parameter from its Task A value, weighted by the corresponding Fisher value. This constrains the optimizer to adjust the model in a way that minimally disrupts the structure needed for the first task.
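
In code, this amounts to something like the sketch below. The snapshot dict old_params is a name I’m introducing here, and the default lambda of 100.0 is an assumed starting point, not a tuned value:

# Fresh run for EWC: re-initialize, retrain on Task A, then snapshot.
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_task(task_a_train)

fisher = compute_fisher(model, task_a_train)
old_params = {n: p.detach().clone() for n, p in model.named_parameters()}

def ewc_penalty(lam=100.0):  # lambda is assumed; it sets the tradeoff below
    # (lambda / 2) * sum_i F_i * (theta_i - theta_A_i)^2 over all parameters
    penalty = sum((fisher[n] * (p - old_params[n]).pow(2)).sum()
                  for n, p in model.named_parameters() if n in fisher)
    return (lam / 2.0) * penalty

train_task(task_b_train, extra_loss=ewc_penalty)
print("Task A accuracy with EWC:", evaluate(task_a_test))
print("Task B accuracy with EWC:", evaluate(task_b_test))

The ewc_penalty closure captures fisher and old_params, so it drops straight into the Task B training loop through the extra_loss hook without changing the loop itself.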

Empirical evaluation demonstrates that with EWC, the model retains significantly more performance on Task A while still acquiring Task B effectively. Without EWC, Task A accuracy drops from 94 percent to under 50 percent. With EWC, Task A accuracy remains above 88 percent, while Task B accuracy only slightly decreases compared to the unconstrained case. The exact tradeoff can be tuned using the lambda regularization hyperparameter in the EWC loss.

This experiment highlights both the limitations and the flexibility of gradient-based learning in sequential settings. While deep neural networks do not inherently preserve older knowledge, PyTorch provides the low-level control necessary to implement constraint-aware training procedures like EWC. These mechanisms loosely mirror the consolidation processes thought to operate in biological brains (same deep maybe as before) and provide a path forward for building agents that learn continuously over time.

Future directions could include applying generative replay, using dynamic architectures that grow with tasks, or experimenting with online Fisher matrix approximations to scale to longer task sequences. While Elastic Weight Consolidation is only one tool in the broader field of continual learning, it serves as a useful reference implementation for those investigating ways to mitigate the brittleness of static deep learning pipelines.

Why the hell does this matter? Beyond classification accuracy and standard benchmarks, the structure of learning itself remains an open frontier, one where tools like PyTorch allow morons and nerds like me to probe and control the dynamics of plasticity and stability in artificial systems.