
Integrated Gradients: AI Explainability for Regulated Industries

Learn how Integrated Gradients helps identify which input features contribute most to a model's predictions, making its decisions more transparent.

By Vamsi Kavuri · DZone Core · Nov. 06, 24 · Tutorial

Using AI models in highly regulated sectors such as finance and healthcare comes with a critical responsibility: explainability. It is not enough for your model's predictions to be accurate; you should also be able to explain why the model made a specific prediction. For example, if we are developing a tumor detection model based on brain MRI scans, we should be able to explain what information the model uses and how it processes that information to identify a tumor. Regulators and doctors need these details to ensure that the results are unbiased and accurate. So how do you explain a model's decisions? Explaining them manually is hard because there is no simple "if-else" logic: deep learning models often have millions of parameters interacting in non-linear ways, which makes it practically impossible to trace the path from input to output by hand.

One technique I have hands-on experience with for addressing this need is Integrated Gradients. Introduced in 2017 by researchers at Google, it computes attributions by integrating the model's gradients along a straight-line path from a baseline input to the actual input. In this article, I will walk you through an image classification use case and show how Integrated Gradients helps us understand which image pixels are most important to the model's decision. I will use the Captum library to calculate attributions and a pre-trained ResNet model to classify images.
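For reference, the original Integrated Gradients formulation defines the attribution for input feature i as follows, where F is the model's output for the target class, x is the input, x' is the baseline, and α moves along the path from baseline (α = 0) to input (α = 1). Because the integral generally has no closed form for a deep network, it is approximated in practice with a sum over m interpolation steps. The attributions also satisfy a "completeness" property: they add up to the difference between the model's output at the input and at the baseline.

```latex
\mathrm{IG}_i(x) = (x_i - x'_i) \int_{0}^{1} \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha

\mathrm{IG}_i(x) \approx (x_i - x'_i) \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\big(x' + \tfrac{k}{m}\,(x - x')\big)}{\partial x_i}

\sum_{i} \mathrm{IG}_i(x) = F(x) - F(x')
```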

Environment Setup

  • Python 3.10 or higher is installed.
  • Install the necessary packages mentioned below.
Shell
 
pip install captum torch torchvision matplotlib numpy pillow


  • Download a sample image and save it as image.jpg (any image will do).

Next, we will load a ResNet model pre-trained on ImageNet, use it to classify the image, compute attributions with Integrated Gradients, and visualize which pixels of the image are most important for the model's prediction.
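To make the computation concrete before Captum takes over, here is a minimal from-scratch sketch in NumPy. It is not Captum's implementation: it assumes a toy logistic-regression "model" with made-up weights (`W`, `B`) whose gradient can be written analytically, and it approximates the path integral with a right Riemann sum over `steps` points. The last lines check completeness: the attributions should add up to the change in model output between the baseline and the input.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Toy differentiable "model" (an assumption for illustration): F(x) = sigmoid(W . x + B).
# A real model would be a neural network whose gradients come from autograd.
W = np.array([1.5, -2.0, 0.5])
B = 0.1


def model(x):
    return sigmoid(W @ x + B)


def model_grad(x):
    # Analytic gradient of the toy model: dF/dx = F(x) * (1 - F(x)) * W
    f = model(x)
    return f * (1.0 - f) * W


def integrated_gradients(x, baseline, steps=300):
    """Approximate IG with a right Riemann sum over `steps` points on the straight path."""
    alphas = np.arange(1, steps + 1) / steps
    # Average the gradient at each interpolated point baseline + alpha * (x - baseline)
    avg_grad = np.mean([model_grad(baseline + a * (x - baseline)) for a in alphas], axis=0)
    # Scale the averaged gradient by the input difference, feature by feature
    return (x - baseline) * avg_grad


x = np.array([1.0, 0.5, 2.0])
baseline = np.zeros_like(x)
attributions = integrated_gradients(x, baseline)

# Completeness check: the attributions should sum to F(x) - F(baseline), up to approximation error
delta = attributions.sum() - (model(x) - model(baseline))
print("attributions:", attributions)
print("convergence delta:", delta)
```

Captum's `IntegratedGradients` does the same job for any PyTorch model, obtaining gradients through autograd; by default it approximates the integral with a Gauss-Legendre rule (`method='gausslegendre'`) rather than a plain Riemann sum.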

Below is the complete implementation with detailed comments.

Python
 
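import torch
import torchvision
import torchvision.transforms as transforms
from captum.attr import IntegratedGradients
from PIL import Image
import matplotlib.pyplot as plt

# Load the image that we just downloaded. Change the file name if you saved it under a different name.
# convert('RGB') ensures three color channels, even for grayscale or RGBA (e.g., PNG) images.
input_image = Image.open('image.jpg').convert('RGB')

# Resize the image to the 224x224 size ResNet expects and convert it to a tensor with values in [0, 1].
# unsqueeze(0) adds a batch dimension, giving a tensor of shape (1, 3, 224, 224).
transformed_image = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])(input_image).unsqueeze(0)

# Download the pre-trained ResNet-18 weights (torchvision 0.13+; older versions use pretrained=True).
weights = torchvision.models.ResNet18_Weights.DEFAULT
resnet = torchvision.models.resnet18(weights=weights)

# The pre-trained weights expect ImageNet-normalized input. Normalizing inside the model keeps the
# input in [0, 1], so the all-zeros baseline below corresponds to a black image.
model = torch.nn.Sequential(
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    resnet,
)
model.eval()

# Make the prediction and use argmax to find the class with the highest probability.
with torch.no_grad():
    predicted_image_class = torch.nn.functional.softmax(model(transformed_image)[0], dim=0).argmax().item()
print('Predicted class:', weights.meta['categories'][predicted_image_class])

# Create the IntegratedGradients object for the model.
integrated_gradients = IntegratedGradients(model)

# Baseline: the reference input from which the IG path starts (here, an all-black image).
baseline_image = torch.zeros_like(transformed_image)

# Compute attributions for the predicted class by integrating gradients from the baseline to the image.
# n_steps is the number of points used to approximate the integral (50 is Captum's default).
# delta is the approximation error: sum of attributions minus (model output at the image - at the baseline).
computed_attributions, delta = integrated_gradients.attribute(
    transformed_image,
    baselines=baseline_image,
    target=predicted_image_class,
    n_steps=50,
    return_convergence_delta=True,
)
print('Convergence delta:', delta.item())

# Collapse the (3, 224, 224) attributions into one value per pixel by summing absolute values over
# the color channels, then scale to [0, 1] so a colormap can be applied.
attribution_map = computed_attributions.squeeze(0).abs().sum(dim=0).detach().cpu().numpy()
attribution_map = attribution_map / (attribution_map.max() + 1e-12)

# Visualize the resized input image and the attribution heat map side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(input_image.resize((224, 224)))
axes[0].axis('off')
axes[0].set_title('Downloaded image')
axes[1].imshow(attribution_map, cmap='magma')
axes[1].axis('off')
axes[1].set_title('Image with attributions')
plt.show()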


Here is the output for the sample image I downloaded. The highlighted areas show which pixels contributed most to the model's prediction.

Downloaded image vs. image with attributions
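If you would rather have a list of the most influential pixels than a picture, a small helper can rank them. This is a sketch under stated assumptions: it expects a 2-D per-pixel importance map (for example, the absolute attributions summed over the three color channels), and the `top_k_pixels` helper and the 4x4 `demo_map` are hypothetical stand-ins for the real 224x224 map.

```python
import numpy as np


def top_k_pixels(attribution_map, k=5):
    """Hypothetical helper: return (row, col) of the k pixels with the largest attribution, highest first."""
    flat = attribution_map.ravel()
    # argsort sorts ascending; take the last k indices and reverse them for descending order
    top_flat = np.argsort(flat)[-k:][::-1]
    rows, cols = np.unravel_index(top_flat, attribution_map.shape)
    return list(zip(rows.tolist(), cols.tolist()))


# Hypothetical 4x4 attribution map standing in for a real 224x224 one
demo_map = np.array([
    [0.0, 0.1, 0.0, 0.0],
    [0.2, 0.9, 0.7, 0.0],
    [0.0, 0.8, 0.3, 0.0],
    [0.0, 0.0, 0.0, 0.05],
])
print(top_k_pixels(demo_map, k=3))  # [(1, 1), (2, 1), (1, 2)]
```

Ranking pixels this way makes it easier to compare attributions across images or baselines numerically instead of relying only on how the heat map looks.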

Conclusion

In this tutorial, we explored how Integrated Gradients can explain the predictions of a deep learning model by revealing which pixels of an image matter most for its prediction. It is not the only explainability technique: feature importance methods and Shapley Additive Explanations (SHAP) can also provide insight into model behavior.


