Source code for shok.utils.transforms.pass_round

import torch

from shok.utils import functions


[docs] class PassRound(torch.nn.Module): """ A custom torch.nn.Module that applies a soft rounding operation to the input tensor. Args: x (torch.Tensor): The input tensor to be rounded. y (optional): An optional secondary input, passed through unchanged. Returns: Tuple[torch.Tensor, Any]: A tuple containing the rounded tensor and the optional secondary input. Note: The actual rounding logic is implemented in `functions.PassRound.apply`. """
[docs] def forward(self, x: torch.Tensor, y=None) -> torch.Tensor: """ Applies a placeholder soft rounding operation to the input tensor. Args: x (torch.Tensor): Input tensor to be processed. y (optional): Additional input, currently unused. Returns: Tuple[torch.Tensor, Any]: A tuple containing the processed tensor and the second input (y). """ # Placeholder for soft rounding logic return functions.PassRound.apply(x), y