Source code for shok.utils.transforms.scale_grad_transform
import torch
from shok.utils import functions
[docs]
class ScaleGradTransform(torch.nn.Module):
    """Module wrapper that routes its input through ``functions.ScaleGrad``.

    The transform creates no parameters or state of its own; it only adapts
    ``functions.ScaleGrad`` to the ``(x, y)`` calling convention of
    :meth:`forward`.
    """

    def __init__(self):
        """Run the base ``torch.nn.Module`` setup; nothing else is stored."""
        super().__init__()

    def forward(self, x, y=None):
        """Apply ``functions.ScaleGrad`` to ``x`` and hand ``y`` back untouched.

        Args:
            x: Value forwarded to ``functions.ScaleGrad.apply``. Presumably a
                tensor whose gradient is to be scaled; the scaling rule itself
                lives in ``shok.utils.functions`` (confirm there).
            y: Optional companion value. It is never inspected or modified
                here, only returned alongside the result.

        Returns:
            A 2-tuple ``(functions.ScaleGrad.apply(x), y)``.
        """
        # Only ``x`` goes through the gradient-scaling function.
        scaled = functions.ScaleGrad.apply(x)
        return scaled, y