Understanding the Difference Between ‘self’ and ‘ctx’ in PyTorch

Differences Between Ctx And Self

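In Python's object-oriented programming, ‘self’ is the conventional name for the first parameter of an instance method. It is not a reserved keyword. It refers to the instance the method was called on, which gives the method access to the attributes and methods defined in the class. When a method is called on an instance, ‘self’ lets that method read and modify the instance’s data. ‘ctx’, by contrast, is the context object that PyTorch passes as the first argument to the static forward and backward methods of a custom torch.autograd.Function.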
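‘self’ is simply the instance that Python passes in as the first argument. The minimal sketch below uses a hypothetical Counter class (not part of PyTorch) to show that calling a method on an instance is equivalent to calling it on the class and passing the instance explicitly.

```python
class Counter:
    def __init__(self, start=0):
        # self is the newly created instance; store state on it
        self.count = start

    def increment(self, step=1):
        # Modify the data of whichever instance was passed in as self
        self.count += step
        return self.count


c = Counter()
# These two calls are equivalent: in both, Python binds c to self
c.increment()
Counter.increment(c)
print(c.count)  # 2
```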

Let’s look at both self and ctx in further detail.


The Role of ‘self’ in Object-Oriented Programming

‘self’ stands for the instance of the class and gives its methods access to the attributes and other methods defined in the class. Let us see this in Python code.

import torch

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Store a Linear layer as an attribute of this instance (self)
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        # Access the `linear` attribute through self (part of the class instance)
        return self.linear(x)

# Create an instance of the LinearModule class
model = LinearModule(5, 3)

# Calling the instance runs forward() through nn.Module.__call__,
# which also runs any registered hooks (preferred over model.forward(...))
output = model(torch.randn(2, 5))

print(output)

In the code above, self refers to a specific instance of the LinearModule class (here, model). When forward runs, self gives it access to the linear attribute, a torch.nn.Linear layer that applies the affine transformation y = xAᵀ + b to its input. Let us look at the output.

Pytorch Self Output
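Assigning a module to an attribute of self does more than store a reference: nn.Module registers it as a submodule, so its parameters become visible through the instance. The sketch below repeats the LinearModule class so that it runs on its own, then lists the parameters reachable through model.

```python
import torch

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Assigning a Module to an attribute of self registers it as a submodule
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

model = LinearModule(5, 3)

# The submodule's parameters are discovered through the instance
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# linear.weight (3, 5)
# linear.bias (3,)

output = model(torch.randn(2, 5))
print(output.shape)  # torch.Size([2, 3])
```

Because the parameters are registered on the instance, a call such as torch.optim.SGD(model.parameters(), lr=0.1) can find and update them without any extra bookkeeping.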

Now, let us move on to ctx and see what it does.

The Purpose of ‘ctx’ in Custom Autograd Functions

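‘ctx’ is short for ‘context’ and is used when defining custom autograd functions by subclassing torch.autograd.Function. PyTorch's autograd builds the computation graph dynamically while the code runs, so a model can contain branches and loops whose lengths are known only at runtime. The operations that actually execute are recorded and then traced backward to compute gradients. A custom autograd function lets you insert your own operation, together with its gradient, into that graph.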
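The following small sketch (an illustrative example, not taken from the original article) shows this define-by-run behavior with built-in operations only. The number of loop iterations depends on the tensor's value, so autograd records only the multiplications that actually ran.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

# The loop length depends on the data, so it is only known at runtime;
# autograd records whichever operations actually execute.
y = x
steps = 0
while y.item() < 100:
    y = y * x
    steps += 1

y.backward()
print(steps)   # 4, so y = x**5
print(x.grad)  # tensor(405.), since d(x**5)/dx = 5 * x**4 = 405 at x = 3
```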

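The ctx object is a context that PyTorch creates for each call of the function and passes as the first argument to both forward and backward. Whatever forward stores on it (tensors through ctx.save_for_backward, other values as ordinary attributes) is available again in backward, where it is needed to compute the gradients.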
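Besides tensors, ctx can carry plain Python values from forward to backward. In the minimal sketch below, ScaleBy is a hypothetical function that multiplies a tensor by a float. It keeps the factor as an attribute on ctx and uses ctx.needs_input_grad to skip gradients that are not required.

```python
import torch

class ScaleBy(torch.autograd.Function):
    # Hypothetical example: multiply a tensor by a plain Python float

    @staticmethod
    def forward(ctx, x, factor):
        # Non-tensor values can be stored as ordinary attributes on ctx;
        # tensors should go through ctx.save_for_backward instead
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # Return one gradient per forward input; None for the non-tensor factor
        grad_x = grad_output * ctx.factor if ctx.needs_input_grad[0] else None
        return grad_x, None

x = torch.ones(3, requires_grad=True)
y = ScaleBy.apply(x, 2.5)
y.sum().backward()
print(x.grad)  # tensor([2.5000, 2.5000, 2.5000])
```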

import torch

class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        # Save the input on the context object so backward() can access it
        # (the gradient of 2x does not depend on x; this only demonstrates ctx)
        ctx.save_for_backward(input_tensor)
        # Forward pass: y = 2x
        output = input_tensor * 2
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the tensor saved during the forward pass
        input_tensor, = ctx.saved_tensors
        # Chain rule: gradient w.r.t. the input = grad_output * d(2x)/dx
        grad_input = grad_output * 2
        return grad_input

# Create a tensor that tracks gradients
x = torch.tensor([1.0], requires_grad=True)

# Apply the custom function through .apply (forward is not called directly)
y = CustomFunction.apply(x)

# Perform an additional built-in operation: z = 3y = 6x
z = y * 3

# Backpropagate; z has a single element, so no explicit gradient is needed
z.backward()

# Access the gradient of the input tensor: dz/dx = 3 * 2 = 6
print(x.grad)  # tensor([6.])

In the code above, we created a custom autograd function with PyTorch and ‘ctx’. Custom autograd functions let us define both an operation and its gradient explicitly. Let us look at the output.

PyTorch Ctx Output

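In this example, CustomFunction is a custom autograd function that multiplies its input by 2 during the forward pass. The backward method receives grad_output, the gradient of the final result with respect to the function’s output, and returns the gradient with respect to the function’s input. Since y = 2x and z = 3y, x.grad comes out as 6.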
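In CustomFunction the saved input is never actually used, because the derivative of 2x is a constant. The sketch below uses a hypothetical Square function whose gradient does depend on the saved tensor, and checks the hand-written backward against finite differences with torch.autograd.gradcheck, which expects double-precision inputs that require grad.

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # The gradient of x**2 depends on x, so x must be saved for backward
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Chain rule: gradient w.r.t. x = grad_output * 2x
        return grad_output * 2 * x

# gradcheck compares the analytical backward with numerical estimates
inp = torch.randn(4, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Square.apply, (inp,)))  # True
```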

‘self’ vs ‘ctx’

| ‘self’ | ‘ctx’ |
|---|---|
| Refers to the instance of a class and provides access to its attributes and methods | Used specifically when defining custom autograd functions (subclasses of torch.autograd.Function) in PyTorch |
| Allows methods to access and manipulate the instance’s data | Stores and retrieves information between the forward and backward passes of a custom autograd function |
| In an nn.Module, holds parameters, submodules, and constants (such as weights) that persist across many calls to forward | Passed by PyTorch as the first argument to the static forward and backward methods |
| Lives as long as the object does; the same instance is reused on every call | A new context is created for each call to Function.apply; it commonly holds input tensors (via ctx.save_for_backward) or other values needed by that call’s backward pass |
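The two often appear together: a module keeps long-lived state on self and calls a custom function whose ctx holds per-call data. In the sketch below, ScaleLayer and ElementwiseScale are hypothetical names chosen for illustration. The layer stores a learnable weight on self, while the function saves its inputs on ctx and reduces the weight gradient over the broadcast batch dimension.

```python
import torch

class ElementwiseScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        # ctx: per-call storage holding exactly what backward will need
        ctx.save_for_backward(x, weight)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output * weight
        # weight was broadcast over the batch dimension, so sum its gradient over it
        grad_weight = (grad_output * x).sum(dim=0)
        return grad_x, grad_weight

class ScaleLayer(torch.nn.Module):
    def __init__(self, features):
        super().__init__()
        # self: long-lived state, a learnable parameter reused on every call
        self.weight = torch.nn.Parameter(torch.ones(features))

    def forward(self, x):
        return ElementwiseScale.apply(x, self.weight)

layer = ScaleLayer(3)
out = layer(torch.randn(2, 3))
out.sum().backward()
print(layer.weight.grad.shape)  # torch.Size([3])

# Verify the hand-written backward numerically (double precision required)
xd = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
wd = torch.randn(3, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(ElementwiseScale.apply, (xd, wd)))  # True
```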

Conclusion

Here you go! Now you know what ctx and self are in PyTorch. In this article, we briefly covered ctx and self in PyTorch and saw how to use them in Python. We also looked at their differences and at how ctx carries information from the forward pass to the backward pass in custom autograd functions.

We hope you enjoyed reading it!
