PyTorch 101: Understanding Hooks

Updated on April 28, 2025

By Ayoosh Kathuria and Shaoni Mukherjee

In this article, we will explore PyTorch hooks, a powerful feature that lets you inspect, and even alter, what happens inside your model during the forward and backward passes. Hooks let you debug your training process, visualize activations, and modify gradients without altering your model’s architecture. We cover hooks because they are essential tools for diagnosing issues like vanishing gradients, understanding the behavior of intermediate layers, and gaining fine-grained control over training dynamics. This tutorial is designed for beginner to intermediate PyTorch users who already build and train models and now want deeper insight into model behavior, as well as for researchers and developers looking to optimize and customize their networks. By the end of this tutorial, you’ll have the tools to open the “black box” of deep learning and better understand what happens inside your models during training.

Introduction to PyTorch Hooks

Hooks in PyTorch are severely under-documented for the functionality they bring to the table.

One of the main reasons hooks matter in PyTorch is that they let you interact with your model during backpropagation. Think of a hook as a clever device, like the kind heroes plant inside a villain’s base to secretly gather information. In PyTorch, a hook is simply a function that you attach to either a Tensor or an nn.Module, and it is executed automatically when the forward or backward pass happens.

Now, when I say “forward,” I don’t just mean the forward() method you define inside an nn.Module class. I’m also referring to the forward operations of torch.autograd.Function, the underlying mechanism that builds computation graphs and performs automatic differentiation. Every tensor that results from an operation (like addition or multiplication) on tensors that require gradients has a grad_fn attached to it. This grad_fn is the backward node of the operation that created that tensor.
For example, if you compute tens = tens1 + tens2, the resulting tensor tens will have a grad_fn of type AddBackward0. This means the system internally knows how to compute the gradients when backpropagating through this addition operation.
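You can check this directly. The short sketch below (tens1 and tens2 are just illustrative names matching the sentence above) prints the class of the grad_fn, shows that user-created leaf tensors have none, and follows next_functions one step back through the graph.

```python
import torch

tens1 = torch.ones(3, requires_grad=True)
tens2 = torch.ones(3, requires_grad=True)
tens = tens1 + tens2

# The node that created `tens`; its class name identifies the operation.
print(type(tens.grad_fn).__name__)        # AddBackward0

# Leaf tensors created directly by the user have no grad_fn.
print(tens1.grad_fn)                      # None

# next_functions links a node to the nodes that produced its inputs.
# For leaves that require grad, these are AccumulateGrad nodes.
print([type(fn).__name__ for fn, _ in tens.grad_fn.next_functions])
# ['AccumulateGrad', 'AccumulateGrad']
```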

If this explanation feels a bit confusing, I highly recommend reviewing our earlier article on computation graphs in PyTorch. However, if you just want the quick version, remember that every tensor created through operations (not manually) in PyTorch tracks how it was created through its grad_fn.

Now, here’s something important: nn.Module objects, like nn.Linear, can be composed of multiple operations internally. For instance, depending on the input shape, a Linear layer may compute its output with two operations: a matrix multiplication followed by an addition (y = x W^T + b). That means, at the autograd level, there can be several forward operations happening (one for the multiplication and one for the addition), not just a single forward() call.
If you use module hooks without keeping this in mind, a backward hook may end up attached to just one of these operations rather than to the full layer, which leads to confusing gradients. We’ll dive deeper into how to handle this properly later in the tutorial.
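To see how many autograd operations a layer actually records in your PyTorch version, you can walk the grad_fn graph of its output. The helper node_names below is hypothetical (written for this illustration); it prints the nodes rather than assuming their names, because the exact nodes for 1-D versus batched 2-D input can differ between versions.

```python
import torch
import torch.nn as nn

def node_names(grad_fn, depth=0):
    """Hypothetical helper: print and return the autograd nodes reachable from grad_fn."""
    if grad_fn is None:           # inputs that do not require grad appear as None
        return []
    names = [type(grad_fn).__name__]
    print("  " * depth + names[0])
    for next_fn, _ in grad_fn.next_functions:
        names += node_names(next_fn, depth + 1)
    return names

layer = nn.Linear(160, 5)

print("1-D input:")
names_1d = node_names(layer(torch.randn(160)).grad_fn)

print("2-D input:")
names_2d = node_names(layer(torch.randn(1, 160)).grad_fn)
```

In both cases the weight and the bias show up as AccumulateGrad leaves; what sits between them and the output is what a module backward hook may or may not see.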

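PyTorch provides two main types of hooks.

  1. The Forward Hook
  2. The Backward Hook

A forward hook is executed during the forward pass, while the backward hook is, well, you guessed it, executed when the backward function is called. Keep in mind where each one attaches: a module forward hook runs every time the module’s forward() is called, whereas the legacy module backward hook (register_backward_hook) is attached to the backward function of the last autograd.Function that produced the module’s output.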

Prerequisites

  • Basic Python Knowledge: Familiarity with Python syntax, data structures (like lists, dictionaries), and control flow (such as loops and conditionals).
  • Foundational Knowledge of PyTorch: Understanding of core PyTorch concepts such as tensors, autograd, and basic operations.
  • Understanding of Neural Networks: Basic understanding of neural network architecture, layers, and backpropagation.
  • Familiarity with Training Workflows: Experience in setting up a simple training loop, including model definition, loss computation, and optimization in PyTorch.

These prerequisites will help you understand and effectively utilize hooks in PyTorch for debugging, modifying, or extracting information during model training.

Hooks for Tensors

A hook is basically a function with a very specific signature. When we say a hook is executed, in reality, we are talking about this function being executed.

For tensors, the signature of a backward hook is:

hook(grad) -> Tensor or None

There is no forward hook for a tensor.

grad is the gradient of the loss with respect to the tensor, i.e. the value that would be stored in the tensor’s grad attribute after backward is called. The function should not modify its argument in place. It must either return None or a Tensor, which will be used in place of grad for further gradient computation. We provide an example below.

import torch
a = torch.ones(5)
a.requires_grad = True

b = 2*a

b.retain_grad()   # b is a non-leaf tensor, so its grad would be freed otherwise.

c = b.mean()

c.backward()

print(a.grad, b.grad)   # a.grad is 0.4 and b.grad is 0.2 in every entry

# Redo the experiment but with a hook that multiplies b's grad by 2.
a = torch.ones(5)

a.requires_grad = True

b = 2*a

b.retain_grad()

b.register_hook(lambda x: x * 2)   # the returned tensor replaces b's gradient

b.mean().backward()


print(a.grad, b.grad)   # a.grad is now 0.8 in every entry, since it uses the doubled gradient

This functionality has several uses.

  1. You can print the value of the gradient for debugging, or log it. This is especially useful with non-leaf tensors, whose gradients are freed unless you call retain_grad on them. Doing the latter keeps those gradients in memory. Hooks provide a much cleaner way to aggregate these values.
  2. You can modify gradients during the backward pass. This is very important. While you can still access the grad attribute of a tensor in a network, you can only do so after the entire backward pass has finished. For example, consider what we did above. We multiplied b’s gradient by 2, and the subsequent gradient calculations, like those of a (or any tensor that depends on b for its gradient), used 2 * grad(b) instead of grad(b). In contrast, had we modified the gradients after the backward pass, we’d have to multiply b.grad as well as a.grad (or, in fact, the gradients of all tensors that depend on b) by 2.
a = torch.ones(5)

a.requires_grad = True
b = 2*a

b.retain_grad()


b.mean().backward() 


print(a.grad, b.grad)

b.grad *= 2

print(a.grad, b.grad)       # a's gradient is unchanged; it would have to be updated manually
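As a sketch of the “cleaner way to aggregate these values” mentioned in point 1, a hook can record gradients of a non-leaf tensor without retain_grad. The names grad_log and make_logger are hypothetical. The example also uses the handle returned by register_hook, whose remove() method detaches the hook once you no longer need it.

```python
import torch

grad_log = {}   # hypothetical store: tensor name -> list of gradients seen

def make_logger(name):
    """Return a tensor hook that records a copy of each incoming gradient."""
    def hook(grad):
        grad_log.setdefault(name, []).append(grad.detach().clone())
        # Returning None (implicitly) leaves the gradient unchanged.
    return hook

a = torch.ones(5, requires_grad=True)
b = 2 * a                                   # non-leaf; no retain_grad() needed
handle = b.register_hook(make_logger("b"))  # register_hook returns a removable handle

b.mean().backward(retain_graph=True)
print(grad_log["b"][0])                     # tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])

handle.remove()                             # the hook will not run again
b.mean().backward()
print(len(grad_log["b"]))                   # 1
```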

Hooks for nn.Module objects

For nn.Module objects, the signature of the hook function is

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

for the backward hook, and

hook(module, input, output) -> None or modified output

for the forward hook.

Before we begin, let me make it clear that I’m not a fan of using hooks on nn.Module objects. First, because they force us to break abstraction. An nn.Module is supposed to be a modularised object representing a layer. However, the legacy backward hook is attached to a single autograd operation, of which there can be an arbitrary number inside an nn.Module object. Making sense of it requires knowing the internal structure of the modularised object.

For example, an nn.Linear layer may perform two operations during its execution: a matrix multiplication followed by an addition (y = x W^T + b). In PyTorch’s internal computation graph, each of these operations is a separate node, and the legacy backward hook only sees the gradients of the last one. Forward hooks behave more simply: the input passed to your hook function is always a tuple holding the positional arguments given to the module’s forward() (even if there is only one), and the output is whatever forward() returns.

It’s important to understand this because, depending on where you attach your hook (at the module level vs operation level), you might capture the entire module’s behavior or just part of it. Forward hooks help you monitor the flow of activations through the model, while backward hooks allow you to inspect or modify the gradients during backpropagation. Keeping this in mind ensures you interpret the inputs and outputs in your hook functions correctly.
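A quick way to check what a forward hook receives is to register one on a standalone layer. In the sketch below, the dict seen and the hook name forward_hook are hypothetical. It records the length of the input tuple and the shapes involved, then returns a replacement output to show that a forward hook may also change what the module returns.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
seen = {}   # hypothetical record of what the hook received

def forward_hook(module, inputs, output):
    # `inputs` is a tuple of the positional arguments passed to forward().
    seen["n_inputs"] = len(inputs)
    seen["input_shape"] = tuple(inputs[0].shape)
    seen["output_shape"] = tuple(output.shape)
    # Returning a tensor replaces the module's output; returning None keeps it.
    return output * 0

handle = layer.register_forward_hook(forward_hook)
out = layer(torch.randn(3, 4))
handle.remove()

print(seen)                     # {'n_inputs': 1, 'input_shape': (3, 4), 'output_shape': (3, 2)}
print(out.abs().sum().item())   # 0.0
```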

grad_input is meant to be the gradient of the loss with respect to the inputs of the nn.Module object (dL/dx, dL/dw, dL/db), and grad_output the gradient of the loss with respect to the module’s output. With the legacy hook, however, these are the gradients of the last operation inside the module, which makes them pretty ambiguous when an nn.Module object performs multiple operations.

Consider the following code.

import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    return self.fc1(self.flatten(x))
  

net = myNet()

def hook_fn(m, i, o):
  print(m)
  print("------------Input Grad------------")

  for grad in i:
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")

  print("------------Output Grad------------")
  for grad in o:  
    try:
      print(grad.shape)
    except AttributeError: 
      print ("None found for Gradient")
  print("\n")
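# register_backward_hook is the legacy module backward hook; newer PyTorch
# versions deprecate it in favor of register_full_backward_hook and may warn here.
net.conv.register_backward_hook(hook_fn)
net.fc1.register_backward_hook(hook_fn)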
inp = torch.randn(1,3,8,8)
out = net(inp)

(1 - out.mean()).backward()

The output produced is:

Linear(in_features=160, out_features=5, bias=True)
------------Input Grad------------
torch.Size([5])
torch.Size([5])
------------Output Grad------------
torch.Size([5])


Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
------------Input Grad------------
None found for Gradient
torch.Size([10, 3, 2, 2])
torch.Size([10])
------------Output Grad------------
torch.Size([1, 10, 4, 4])

In the code above, I use a hook to print the shapes of grad_input and grad_output.

In Conv2d, you can often make educated guesses about the gradient shapes. For example, the grad_input of size [10, 3, 2, 2] corresponds to the gradient of the weights, while the grad_input of size [10] likely corresponds to the gradient of the bias. But what about the gradient of the input feature maps? Is there none? Here it is None because inp was created without requires_grad=True, so no gradient is computed for it.
Additionally, Conv2d typically uses techniques like im2col (or similar) to flatten an image so that the convolution can be performed as a matrix multiplication, avoiding the need for loops. So, were there backward calls involved in that process? To compute the gradient with respect to the input x, you’d need to use the grad_output from the layer behind it.

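Now, when it comes to Linear, things get trickier. Both grad_inputs have a size of [5], but shouldn’t the gradient of the linear layer’s weight matrix be of size [5, 160]? What the hook actually received are most likely the gradients with respect to the two inputs of the final addition (the matrix product and the bias), not with respect to the layer’s input and weights. This kind of confusion can make it hard to reason about gradients correctly.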

This is why I’m not particularly fond of using hooks with nn.Modules. While they can work for simpler layers like ReLU, they can quickly become confusing and difficult to manage with more complex layers like Conv2d or Linear.
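Newer PyTorch releases deprecate register_backward_hook in favor of register_full_backward_hook, whose grad_input and grad_output are defined with respect to the module’s own inputs and outputs rather than its last internal operation. The sketch below assumes a version that provides this method and uses a standalone nn.Linear(160, 5); the names shapes and full_backward_hook are illustrative. Note that parameter gradients are not part of grad_input: they still end up in layer.weight.grad and layer.bias.grad.

```python
import torch
import torch.nn as nn

layer = nn.Linear(160, 5)
shapes = {}   # hypothetical record of the gradient shapes the hook sees

def full_backward_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module's positional inputs (here just x).
    # grad_output: gradients w.r.t. the module's outputs.
    shapes["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    shapes["grad_output"] = [tuple(g.shape) for g in grad_output]
    # Returning None keeps grad_input unchanged.

layer.register_full_backward_hook(full_backward_hook)

x = torch.randn(1, 160, requires_grad=True)
layer(x).mean().backward()

print(shapes)                   # {'grad_input': [(1, 160)], 'grad_output': [(1, 5)]}
print(layer.weight.grad.shape)  # torch.Size([5, 160])
```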

Proper Way of Using Hooks: An Opinion

So, I’m all up for using hooks on Tensors. Using the named_parameters function, I’ve been able to accomplish all my gradient modifying / clipping needs in PyTorch. named_parameters gives us much more control over which gradients to tinker with. Let’s say I want to do two things.

  1. Turn gradients of linear biases into zero while backpropagating.
  2. Ensure that no gradient going to the conv layer is less than 0.
import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
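    # Clamp the gradient flowing back into the conv block's output so that
    # no value below 0 is backpropagated.
    x.register_hook(lambda grad : torch.clamp(grad, min = 0))

    # Tensor hooks run in registration order, so this one sees the clamped
    # gradient and prints whether any negative value is left.
    x.register_hook(lambda grad: print("Gradients less than zero:", bool((grad < 0).any())))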
    return self.fc1(self.flatten(x))
  

net = myNet()

for name, param in net.named_parameters():
  # if the param is from a linear and is a bias
  if "fc" in name and "bias" in name:
    param.register_hook(lambda grad: torch.zeros_like(grad))   # zeros with grad's shape, dtype and device


out = net(torch.randn(1,3,8,8)) 

(1 - out).mean().backward()

print("The biases are", net.fc1.bias.grad)     #bias grads are zero

The output produced is:

Gradients less than zero: False
The biases are tensor([0., 0., 0., 0., 0.])
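The same named_parameters pattern covers gradient clipping. The sketch below registers one hook per parameter that logs the gradient norm and clips the gradient element-wise before it is written to .grad. The model, the grad_norms dict, and the CLIP threshold are illustrative assumptions. One pitfall worth noting: a hook defined in a loop should bind the loop variable as a default argument; otherwise every hook sees the last value of name.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
grad_norms = {}   # hypothetical log: parameter name -> gradient norm
CLIP = 0.01       # illustrative clipping threshold

for name, param in model.named_parameters():
    # Bind `name` as a default argument so each hook keeps its own name.
    def hook(grad, name=name):
        grad_norms[name] = grad.norm().item()
        return grad.clamp(-CLIP, CLIP)   # element-wise clipping of this gradient
    param.register_hook(hook)

model(torch.randn(8, 4)).sum().backward()

print(sorted(grad_norms))   # ['0.bias', '0.weight', '2.bias', '2.weight']
print(max(p.grad.abs().max().item() for p in model.parameters()) <= CLIP)   # True
```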

The Forward Hook for Visualising Activations

If you noticed, a Tensor doesn’t have a forward hook, while nn.Module has one, which is executed when its forward is called. Notwithstanding the issues I already highlighted with attaching hooks to nn.Module objects, I’ve seen many people use forward hooks to save intermediate feature maps by storing them in a Python variable external to the hook function. Something like this.

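import torch

visualisation = {}

inp = torch.randn(1,3,8,8)

def hook_fn(m, i, o):
  visualisation[m] = o   # o is the module's output for this forward call

net = myNet()   # myNet as defined in the previous example

for name, layer in net.named_children():
  layer.register_forward_hook(hook_fn)

out = net(inp)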

The output passed to a forward hook is simply whatever the module’s forward() returns. However, the above functionality can be safely replicated without hooks: simply append the intermediate outputs to a list inside the forward function of the nn.Module object. It can be a bit problematic, though, to capture the intermediate activations of modules inside nn.Sequential this way, because you don’t write the Sequential’s forward yourself. To get past this, we register a hook on the children modules of the Sequential, but not on the Sequential itself.

import torch 
import torch.nn as nn

class myNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(3,10,2, stride = 2)
    self.relu = nn.ReLU()
    self.flatten = lambda x: x.view(-1)
    self.fc1 = nn.Linear(160,5)
    self.seq = nn.Sequential(nn.Linear(5,3), nn.Linear(3,2))
    
   
  
  def forward(self, x):
    x = self.relu(self.conv(x))
    x = self.fc1(self.flatten(x))
    x = self.seq(x)
    return x

net = myNet()
visualisation = {}

def hook_fn(m, i, o):
  visualisation[m] = o 

def get_all_layers(net):
  for name, layer in net.named_children():
    #If it is a sequential, don't register a hook on it
    # but recursively register hook on all its module children
    if isinstance(layer, nn.Sequential):
      get_all_layers(layer)
    else:
      # it's a non sequential. Register a hook
      layer.register_forward_hook(hook_fn)

get_all_layers(net)

  
out = net(torch.randn(1,3,8,8))

# Just to check whether we got all layers
print(visualisation.keys())   # the layers inside the Sequential appear; the Sequential itself does not

Finally, you can turn these tensors into numpy arrays and plot activations.
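The conversion step can be sketched as follows, using a standalone nn.Conv2d with the same configuration as myNet’s conv layer; the activations dict and save_activation hook are illustrative names. The stored output is detached because .numpy() refuses tensors that require grad, and .cpu() is needed if the model lives on a GPU.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 10, 2, stride=2)
activations = {}   # hypothetical store for captured feature maps

def save_activation(module, inputs, output):
    # Detach so the stored copy does not keep the autograd graph alive.
    activations["conv"] = output.detach()

conv.register_forward_hook(save_activation)
conv(torch.randn(1, 3, 8, 8))

# First sample of the batch: 10 feature maps of size 4x4.
feature_maps = activations["conv"].cpu().numpy()[0]
print(feature_maps.shape)   # (10, 4, 4)
```

Each feature_maps[k] is then a 2-D array that a plotting library (for example matplotlib’s imshow) can display as an image.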

FAQ

What are hooks in PyTorch, and why are they useful?
Hooks are functions that can be registered to a tensor or nn.Module to execute during the forward or backward pass. They are useful for debugging, monitoring activations, visualizing intermediate outputs, and modifying gradients during training.

What’s the difference between a forward hook and a backward hook in PyTorch?
Forward hooks execute during the forward pass and provide access to input/output tensors. Backward hooks execute during the backward pass and allow you to inspect or modify gradients.

Can you attach hooks to both Tensor and nn.Module objects in PyTorch?
Yes, hooks can be attached to both Tensors (for monitoring gradients) and nn.Module objects (for monitoring layer activations or gradients).

How do you register a backward hook on a tensor?
You register a backward hook using tensor.register_hook(), which takes a function to execute when gradients are computed for that tensor.

When should I use retain_grad() on a non-leaf tensor?
You should call retain_grad() on non-leaf tensors if you need to retain their gradients for debugging or visualization, as PyTorch does not store gradients for non-leaf tensors by default.

Can I modify gradients during backpropagation using hooks?
Yes, backward hooks allow you to modify gradients during backpropagation, which can be useful for gradient clipping or custom optimization.

What’s the proper use case for tensor hooks vs module hooks?
Tensor hooks are useful for modifying or debugging gradients for individual tensors, while module hooks are better for inspecting or modifying the outputs of entire layers or modules.

Conclusion

In this tutorial, we’ve explored the feature of hooks in PyTorch, which allows you to gain deeper insights into the internals of your model during both the forward and backward passes. Hooks provide an invaluable tool for debugging, visualizing activations, inspecting gradients, and even modifying backpropagation to suit your specific needs. While using hooks with simple layers like ReLU is straightforward, more complex modules such as nn.Linear or nn.Conv2d can lead to confusion, especially when dealing with multiple operations and forward calls. By understanding the fundamentals of how hooks work and carefully selecting where to apply them, you can gain fine-grained control over your model’s training process.

As you continue to experiment with deep learning models in PyTorch, you may find the cloud-based GPU instances from DigitalOcean to be an excellent choice for scaling your computations. With DigitalOcean’s GPU Droplets, you can accelerate training and debugging tasks, including those that require intensive visualization and gradient analysis, without the need to invest in expensive hardware. DigitalOcean offers an easy-to-use platform, ideal for deep learning practitioners, researchers, and engineers who want to experiment with state-of-the-art AI models in a flexible and cost-effective environment.

This wraps up our discussion on PyTorch hooks. We hope you now have a solid understanding of how to effectively utilize them in your deep learning workflows.
