AI's Dilemma: When to Retrain and When to Unlearn?
A high-level overview of retraining and machine unlearning that includes sample code, a comparison of when to use each approach, and a conclusion.
A Growing Need for Data Privacy Solutions
In recent times, data privacy has become a central focus, with laws such as the General Data Protection Regulation (GDPR) and the California Consumer Privacy Act (CCPA) playing a key role.
Organizations are under increasing pressure to comply with user data deletion requests. One significant requirement is the right to data deletion, under which users can request that their personal information be removed from a company's databases.
Traditionally, when data was removed from a dataset, companies retrained their machine learning models to ensure compliance. Machine unlearning offers an innovative alternative: selectively and efficiently removing the influence of unwanted data without retraining from scratch.
What Is Retraining?
Retraining is the traditional approach: after the data changes, the machine learning model is rebuilt from scratch on the modified dataset, using only the remaining valid data.
How Retraining Works
- Remove the unwanted data from the dataset.
- Re-initialize the model’s parameters.
- Train the model from scratch on the updated dataset.
Why Retraining Works
- High accuracy: Because the model is trained from scratch on the remaining data, it is fully optimized for the current dataset.
- Data integrity: Deleted data has no residual influence on the model's predictions.
Challenges in Retraining
- Expensive: For larger models, the computational cost increases steeply.
- Time-consuming: Training a complex neural network (NN) can take hours or days (see the rough timing sketch below).
- Resource heavy: At scale, high energy and memory requirements make frequent retraining inefficient.
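To put a rough number on this cost, the benchmark below is my own illustration (the dataset sizes and model choice are arbitrary assumptions): it times a full retrain at increasing dataset sizes.
import time
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
# Illustrative only: full-retrain wall-clock time grows with dataset size
for n_samples in (1_000, 10_000, 100_000):
    X_bench, y_bench = make_classification(n_samples=n_samples, n_features=20, random_state=42)
    start = time.perf_counter()
    LogisticRegression(max_iter=200).fit(X_bench, y_bench)
    print("n={:>7}: {:.3f}s to retrain".format(n_samples, time.perf_counter() - start))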
Code Example for Retraining
Sample code for retraining: first, train the model, then remove the data and retrain the model.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
# Original Dataset
X = np.array([[1, 2], [2, 3], [4, 5], [6, 7], [8, 9]])
y = np.array([0, 0, 1, 1, 1])
# split train test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, stratify=y, random_state=42)
# feature scaling
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Before Retraining
model_before = LogisticRegression(solver='lbfgs', max_iter=200, class_weight='balanced')
model_before.fit(X_train_scaled, y_train)
# Simulate a deletion request: keep at most two samples per class, ensuring at least one sample of each class remains
X_train_reduced = []
y_train_reduced = []
# Track the number of samples per class
class_counts = {0: 0, 1: 0}
for i in range(len(y_train)):
    if class_counts[y_train[i]] < 2:
        X_train_reduced.append(X_train_scaled[i])
        y_train_reduced.append(y_train[i])
        class_counts[y_train[i]] += 1
X_train_reduced = np.array(X_train_reduced)
y_train_reduced = np.array(y_train_reduced)
# Retraining with reduced dataset
model_after = LogisticRegression(solver='lbfgs', max_iter=200, class_weight='balanced')
model_after.fit(X_train_reduced, y_train_reduced)
# Accuracies for both
print(f"Initial Model Accuracy: {model_before.score(X_test_scaled, y_test):.2f}")
print(f"Model Accuracy After Retraining: {model_after.score(X_test_scaled, y_test):.2f}")
Output:
Initial Model Accuracy: 0.67
Model Accuracy After Retraining: 0.67
What Is Machine Unlearning?
Machine unlearning is a technique in machine learning that focuses on selectively removing the influence of specific data points from the trained model.
This process ensures the model forgets specific data points without the need to completely retrain from scratch.
How It Works
- Selective data removal: Identify the data points that need to be forgotten and removed from the model.
- Gradient reversal: Undo the effect of specific training examples by reversing their contributions to the model's parameter updates (a minimal sketch follows this list).
- Efficient updates: Instead of retraining the entire model from scratch, use specialized techniques to update only the parts of the model affected by the removed data.
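To make the gradient-reversal idea concrete, here is a minimal sketch; the helper name, learning rate, and step count are my own assumptions, and this is a simpler technique than the influence-function update used in the full example later in this article.
import torch
# Approximate forgetting by gradient *ascent* on the forget set,
# partially reversing its contribution to the trained parameters
def gradient_ascent_unlearn(model, criterion, x_forget, y_forget, lr=0.05, steps=5):
    for _ in range(steps):
        model.zero_grad()
        loss = criterion(model(x_forget), y_forget)
        loss.backward()
        with torch.no_grad():
            for p in model.parameters():
                p += lr * p.grad  # ascend: increase loss on the forgotten points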
Why It Works
- Efficient compliance: Avoids retraining the entire model.
- Data privacy: Helps organizations meet privacy requirements by removing traces of specific user data.
- Cost-efficient: Time and infrastructure costs are reduced because the model is not retrained from scratch.
Challenges in Machine Unlearning
- Complexity: Implementing unlearning algorithms can be complex, especially for deep learning models.
- Performance: If the unlearning process is not managed properly, model performance may degrade.
- Data dependencies: It is challenging to remove interconnected data points without affecting accuracy.
- Security concerns: Residual traces of the removed data, if not fully eliminated, can still leak information (a simple probe is sketched after this list).
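One simple, admittedly weak probe for residual traces is to compare the model's loss on a removed point before and after unlearning; if the loss stays unusually low, the point's influence may not be fully gone. The sketch below assumes a PyTorch classifier, and the function name is hypothetical.
import torch
def forgetting_score(model, criterion, x_removed, y_removed):
    """Loss on the removed point; a higher value after unlearning is
    weak evidence that the point's influence has been reduced."""
    with torch.no_grad():
        return criterion(model(x_removed), y_removed).item()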
Code Example for Machine Unlearning (Selective Data Forgetting)
This code is an approximate demonstration: there is no production-ready “machine unlearning” library to plug in, so it illustrates the concept (here, an influence-function-style parameter update) rather than a complete solution for all scenarios.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Helper: flatten all model parameters into a single vector
def get_flat_params(model):
    """Flatten all the model parameters into a single vector."""
    parameters = []
    for parameter in model.parameters():
        parameters.append(parameter.view(-1))
    return torch.cat(parameters)
# Helper: write a flattened vector back into the model parameters
def set_flat_params(model, flat_params):
    """Set model parameters from a flattened vector."""
    offset = 0
    for parameter in model.parameters():
        numel = parameter.numel()
        parameter.data.copy_(flat_params[offset:offset + numel].view(parameter.size()))
        offset += numel
# Logistic regression model
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))
# Datasets
X_np = np.array([[1, 2],[2, 3],[3, 4],[4, 5],[5, 6],[6, 7],[7, 8]], dtype=np.float32)
y_np = np.array([0, 1, 0, 1, 0, 1, 0], dtype=np.float32).reshape(-1, 1)
X = torch.tensor(X_np)
y = torch.tensor(y_np)
# Initialization and Training
model = LogisticRegressionModel(input_dim=2)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
num_epochs = 1000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
# Before unlearning
with torch.no_grad():
    preds = (model(X) > 0.5).float()
    accuracy_before = (preds == y).float().mean().item()
print("Accuracy before unlearning: {:.2f}".format(accuracy_before))
# Hessian and Its Inverse
loss_full = criterion(model(X), y)
grad_params = torch.autograd.grad(loss_full, model.parameters(), create_graph=True)
grad_flat = torch.cat([g.view(-1) for g in grad_params])
n_params = grad_flat.numel()
Hessian = torch.zeros((n_params, n_params))
for i in range(n_params):
    grad2 = torch.autograd.grad(grad_flat[i], model.parameters(), retain_graph=True)
    grad2_flat = torch.cat([g.contiguous().view(-1) for g in grad2])
    Hessian[i] = grad2_flat
Hessian_inv = torch.inverse(Hessian)
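# Note (added for context): forming and inverting the full Hessian costs
# O(p^2) memory and O(p^3) time, so it is feasible only for tiny models
# like this one; larger systems typically approximate H^{-1}v iteratively.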
# Influence for a Specific Data Point
index_to_remove = 2 # Data point [3, 4] with label 0
x_remove = X[index_to_remove:index_to_remove+1]
y_remove = y[index_to_remove:index_to_remove+1]
loss_remove = criterion(model(x_remove), y_remove)
grad_remove = torch.autograd.grad(loss_remove, model.parameters())
grad_remove_flat = torch.cat([g.view(-1) for g in grad_remove])
n = X.size(0)
# Influence-function estimate for removing one training point:
# theta_new ≈ theta + (1/n) * H^{-1} * grad L(z_remove)
delta_theta = (1.0 / n) * torch.matmul(Hessian_inv, grad_remove_flat)
# Update the Model Parameters
flat_params_new = get_flat_params(model) + delta_theta
set_flat_params(model, flat_params_new)
# Evaluate after unlearning
with torch.no_grad():
    preds_after = (model(X) > 0.5).float()
    accuracy_after = (preds_after == y).float().mean().item()
print("Accuracy after unlearning: {:.2f}".format(accuracy_after))
Output:
Accuracy before unlearning: 0.57
Accuracy after unlearning: 0.57
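A useful sanity check for approximate unlearning, not part of the walkthrough above, is to compare the unlearned model against one retrained from scratch without the removed point; a small parameter distance suggests the influence-function update tracks exact retraining. The sketch below reuses model, X, y, index_to_remove, criterion, optim, LogisticRegressionModel, and get_flat_params from the example above; since both models are only approximately converged, a small but nonzero distance is expected.
# Sanity check: retrain without the removed point and compare parameters
mask = torch.ones(X.size(0), dtype=torch.bool)
mask[index_to_remove] = False
retrained = LogisticRegressionModel(input_dim=2)
retrain_opt = optim.SGD(retrained.parameters(), lr=0.1)
for _ in range(1000):
    retrain_opt.zero_grad()
    criterion(retrained(X[mask]), y[mask]).backward()
    retrain_opt.step()
distance = torch.norm(get_flat_params(model) - get_flat_params(retrained))
print("Parameter distance to exact retraining: {:.4f}".format(distance.item()))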
Comparison
Feature | Machine Unlearning | Retraining
--- | --- | ---
Purpose | Removes the influence of specific data points quickly | Rebuilds the model from scratch
Speed | Fast | Slow
Accuracy | Lower than retraining (approximate) | Higher than unlearning (exact)
Resource Usage | Low | High
Conclusion
Machine unlearning is useful for quickly removing the influence of a specific set of data points, particularly for real-time adjustments and privacy compliance. Retraining ensures accuracy and is the better choice when the change to the dataset is significant. The choice between the two depends on the goal: unlearning for quicker adaptability and compliance, retraining for long-term accuracy.
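As a rough rule of thumb (my own heuristic, with illustrative thresholds rather than established standards), the decision can be sketched as a simple function:
def choose_strategy(fraction_deleted, accuracy_critical, retrain_cost_hours):
    """Illustrative decision rule: unlearn for small deletions on expensive
    models; retrain when much data changes or accuracy is critical."""
    if fraction_deleted > 0.10 or accuracy_critical:
        return "retrain"
    if retrain_cost_hours > 1:
        return "unlearn"
    return "retrain"  # cheap enough to simply rebuild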