AI's Dilemma: When to Retrain and When to Unlearn?

A high-level overview of retraining and machine unlearning that includes sample code, a comparison of when to use each approach, and a conclusion.

By Prasannakumar Patil · May 09, 2025 · Analysis

A Growing Need for Data Privacy Solutions

Data privacy has become a central focus in recent years, with laws such as the General Data Protection Regulation (GDPR) and the California Consumer Privacy Act (CCPA) playing a key role.

Organizations are under increasing pressure to comply with user data deletion requests. One significant requirement is the right to data deletion, under which users can request that their personal information be removed from a company's databases.

Traditionally, when data was removed from a dataset, companies retrained their machine learning models to ensure compliance. Machine unlearning offers an alternative: selectively and efficiently removing the influence of unwanted data without retraining from scratch.

What Is Retraining?

Retraining is the traditional method of handling data changes: the machine learning model is rebuilt from scratch on the modified dataset, so the entire model is reconstructed using only the remaining valid data.

How Retraining Works

  1. Remove the unwanted data from the dataset.
  2. Re-initialize the model's parameters.
  3. Train the model from scratch on the updated dataset.

Why Retraining Works

  1. High accuracy: Since the model is trained from scratch on the entire remaining dataset, it is fully optimized for the current data.
  2. Data integrity: Deleted data has no remaining influence on the model's predictions.

Challenges in Retraining

  1. Expensive: Computational cost grows quickly with model and dataset size.
  2. Time-consuming: Training a complex neural network can take hours or days (a rough timing sketch follows this list).
  3. Resource-heavy: At scale, high energy and memory requirements make frequent retraining inefficient.
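
To make the time cost concrete, here is a minimal timing sketch (my own illustration, not from the original example): it retrains a small scikit-learn neural network from scratch at increasing dataset sizes. The dataset sizes, feature count, and architecture are all illustrative assumptions.

Python
 
import time
import numpy as np
from sklearn.neural_network import MLPClassifier

# Illustrative only: measure how full retraining time grows with dataset size
rng = np.random.default_rng(42)
for n_samples in (1_000, 10_000, 50_000):
    X = rng.normal(size=(n_samples, 20))
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    model = MLPClassifier(hidden_layer_sizes=(64, 64), max_iter=50, random_state=42)
    start = time.perf_counter()
    model.fit(X, y)  # full retrain from scratch
    print(f"n={n_samples:>6}: retrain took {time.perf_counter() - start:.2f}s")

Even on this toy setup, cost grows with dataset size; for production-scale deep networks, the same from-scratch retrain is measured in hours or days.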

Code Example for Retraining

Below is sample code for retraining: first train a model, then remove data and retrain the model from scratch.

Python
 
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Original dataset
X = np.array([[1, 2], [2, 3], [4, 5], [6, 7], [8, 9]])
y = np.array([0, 0, 1, 1, 1])

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, stratify=y, random_state=42)

# Feature scaling
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train the initial model, before any data removal
model_before = LogisticRegression(solver='lbfgs', max_iter=200, class_weight='balanced')
model_before.fit(X_train_scaled, y_train)

# Simulate data removal: keep at most two samples per class,
# ensuring at least one sample remains for each class
X_train_reduced = []
y_train_reduced = []
class_counts = {0: 0, 1: 0}  # number of kept samples per class
for i in range(len(y_train)):
    if class_counts[y_train[i]] < 2:
        X_train_reduced.append(X_train_scaled[i])
        y_train_reduced.append(y_train[i])
        class_counts[y_train[i]] += 1

X_train_reduced = np.array(X_train_reduced)
y_train_reduced = np.array(y_train_reduced)

# Retrain from scratch on the reduced dataset
model_after = LogisticRegression(solver='lbfgs', max_iter=200, class_weight='balanced')
model_after.fit(X_train_reduced, y_train_reduced)

# Compare accuracies before and after retraining
print(f"Initial Model Accuracy: {model_before.score(X_test_scaled, y_test):.2f}")
print(f"Model Accuracy After Retraining: {model_after.score(X_test_scaled, y_test):.2f}")
Markdown
 
Initial Model Accuracy: 0.67 
Model Accuracy After Retraining: 0.67
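
On this toy dataset, the accuracy happens to be unchanged after retraining; with realistic datasets, removing a meaningful slice of data will usually shift the retrained model's accuracy.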


What Is Machine Unlearning?

Machine unlearning is a technique that focuses on selectively removing the influence of specific data points from a trained model.

This process lets the model forget those data points without being completely retrained from scratch.

How It Works

  1. Selective data removal: Identify the data points that need to be forgotten/removed from the model.
  2. Gradient reversal: Undo the effect of specific training examples by reversing their contributions during model updates (a minimal sketch of this idea follows this list).
  3. Efficient updates: Instead of retraining the entire model from scratch, use specialized techniques to update only the parts of the model affected by the removed data.
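
To make the gradient-reversal idea concrete, here is a minimal sketch (my own illustration, not a production unlearning algorithm): after normal training, a few gradient ascent steps on the forget point's loss approximately cancel the descent steps that point contributed. The model, data, learning rate, and step count are all illustrative assumptions.

Python
 
import torch
import torch.nn as nn
import torch.optim as optim

# Toy model and data (illustrative assumptions)
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
criterion = nn.BCELoss()
X = torch.tensor([[1., 2.], [2., 3.], [3., 4.], [4., 5.]])
y = torch.tensor([[0.], [1.], [0.], [1.]])

# Normal training on the full dataset
optimizer = optim.SGD(model.parameters(), lr=0.1)
for _ in range(500):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()

# "Gradient reversal": ascend the loss on the point to forget,
# approximately undoing its earlier descent contributions
x_forget, y_forget = X[2:3], y[2:3]
unlearn_lr, unlearn_steps = 0.05, 10
for _ in range(unlearn_steps):
    model.zero_grad()
    loss_forget = criterion(model(x_forget), y_forget)
    loss_forget.backward()
    with torch.no_grad():
        for p in model.parameters():
            p += unlearn_lr * p.grad  # ascend instead of descend

print(f"Confidence on forgotten point: {model(x_forget).item():.3f}")

In practice, the number of reversal steps and the learning rate must be tuned so the model forgets the target point without degrading accuracy on the retained data.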

Why It Works

  1. Efficient compliance: Avoids retraining the entire model for every deletion request.
  2. Data privacy: Helps organizations meet privacy requirements by removing the traces of specific user data.
  3. Cost-efficient: Time and infrastructure costs are reduced because the model is not retrained from scratch.

Challenges in Machine Unlearning

  1. Complexity: Implementing unlearning algorithms can be complex, especially for deep learning models.
  2. Performance: If unlearning is not managed properly, model performance may degrade.
  3. Data dependencies: It is challenging to remove interconnected data points without affecting accuracy.
  4. Security concerns: Residual traces of the removed data, if not fully eliminated, can still leak information (a simple check is sketched after this list).
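
A simple, membership-inference-style sanity check for that last concern (my own sketch, not a rigorous audit) is to compare the model's confidence on a removed point before and after unlearning; if it barely moves, traces of that data may remain. The threshold below is an illustrative assumption.

Python
 
import torch

def residual_trace_check(model, x_removed, conf_before, threshold=0.05):
    """Flag possible residual influence of a removed data point.

    conf_before is the model's confidence on x_removed prior to unlearning;
    threshold is an illustrative assumption to tune per application.
    """
    with torch.no_grad():
        conf_after = model(x_removed).item()
    if abs(conf_before - conf_after) < threshold:
        print("Warning: confidence barely changed; data traces may remain.")
    return conf_after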

Code Example for Machine Unlearning (Selective Data Forgetting)

This code is an approximate demonstration: since there is no standard, production-ready "machine unlearning" library to plug in, it illustrates the concept (an influence-function-style parameter update) rather than a complete solution for all scenarios.

Python
 
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# helper function get
def get_flat_params(model):
    """Flatten all the model parameters into a single vector"""
    parameters = []
    for parameter in model.parameters():
        parameters.append(parameter.view(-1))
    return torch.cat(parameters)

# helper function for set
def set_flat_params(model, flat_params):
    """Set model parameters from a flattened vector"""
    offset = 0
    for parameter in model.parameters():
        numel = parameter.numel()
        parameter.data.copy_(flat_params[offset:offset+numel].view(parameter.size()))
        offset += numel

# LR Model
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        return torch.sigmoid(self.linear(x))

# Datasets
X_np = np.array([[1, 2],[2, 3],[3, 4],[4, 5],[5, 6],[6, 7],[7, 8]], dtype=np.float32)
y_np = np.array([0, 1, 0, 1, 0, 1, 0], dtype=np.float32).reshape(-1, 1)

X = torch.tensor(X_np)
y = torch.tensor(y_np)

# Initialization and Training
model = LogisticRegressionModel(input_dim=2)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
num_epochs = 1000
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

# Before unlearning
with torch.no_grad():
    preds = (model(X) > 0.5).float()
    accuracy_before = (preds == y).float().mean().item()
    print("Accuracy before unlearning: {:.2f}".format(accuracy_before))

# Hessian of the full training loss w.r.t. the model parameters, and its inverse
loss_full = criterion(model(X), y)
grad_params = torch.autograd.grad(loss_full, model.parameters(), create_graph=True)
grad_flat = torch.cat([g.view(-1) for g in grad_params])
n_params = grad_flat.numel()
Hessian = torch.zeros((n_params, n_params))
for i in range(n_params):
    grad2 = torch.autograd.grad(grad_flat[i], model.parameters(), retain_graph=True)
    grad2_flat = torch.cat([g.contiguous().view(-1) for g in grad2])
    Hessian[i] = grad2_flat
Hessian_inv = torch.inverse(Hessian)

# Influence for a Specific Data Point
index_to_remove = 2  # Data point [3, 4] with label 0
x_remove = X[index_to_remove:index_to_remove+1]
y_remove = y[index_to_remove:index_to_remove+1]
loss_remove = criterion(model(x_remove), y_remove)
grad_remove = torch.autograd.grad(loss_remove, model.parameters())
grad_remove_flat = torch.cat([g.view(-1) for g in grad_remove])
n = X.size(0)
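# Approximate influence-function step: a single Newton-style update intended
# to cancel this point's (1/n-weighted) contribution to the trained parameters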
delta_theta = - (1.0/n) * torch.matmul(Hessian_inv, grad_remove_flat)

# Update the Model Parameters
flat_params_new = get_flat_params(model) + delta_theta
set_flat_params(model, flat_params_new)

# Evaluate after unlearning
with torch.no_grad():
    preds_after = (model(X) > 0.5).float()
    accuracy_after = (preds_after == y).float().mean().item()
    print("Accuracy after unlearning: {:.2f}".format(accuracy_after))
Markdown
 
Accuracy before unlearning: 0.57
Accuracy after unlearning: 0.57
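
Here, too, the toy example's overall accuracy is unchanged after the single influence step; in practice, you would also verify that the model's behavior on the removed point itself has changed (see the residual-trace check sketched earlier).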


Comparison

Feature          Machine Unlearning                     Retraining
Purpose          Removes specific data points quickly   Rebuilds the model from scratch
Speed            Fast                                   Slow
Accuracy         Lower than retraining                  Higher than machine unlearning
Resource usage   Low                                    High
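
As a rough illustration of how this trade-off might be operationalized, here is a small decision heuristic (entirely my own sketch; the threshold values are illustrative assumptions, not established guidance) that picks a strategy based on how much data is being removed and whether a full retrain fits the deletion deadline.

Python
 
def choose_strategy(n_removed, n_total, deadline_hours, retrain_hours,
                    removed_fraction_limit=0.01):
    """Illustrative heuristic: prefer unlearning for small, urgent deletions."""
    removed_fraction = n_removed / n_total
    if removed_fraction <= removed_fraction_limit and retrain_hours > deadline_hours:
        return "unlearn"  # small change and retraining won't meet the deadline
    return "retrain"      # large change, or enough time to favor accuracy

# Example: 50 of 1,000,000 records must go within 24h, but retraining takes 48h
print(choose_strategy(50, 1_000_000, deadline_hours=24, retrain_hours=48))  # unlearn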


Conclusion

Machine unlearning is useful for quickly removing a specific set of data points, particularly for real-time adjustments and privacy compliance. Retraining ensures accuracy and is preferable when the change to the dataset is significant. The choice between the two depends on the goal: unlearning for quick adaptability and compliance, retraining for long-term accuracy.


Opinions expressed by DZone contributors are their own.
