class ScaleApplyPatch(torch.nn.Module):
    """Paste a patch into the top-left corner of an image at a scaled size.

    The patch is resized to ``scale`` times the image's spatial size and
    written over the top-left region of a copy of the image. This is useful
    for evaluating patch effectiveness.
    """

    def __init__(self, scale=0.25, preserve_aspect_ratio=True):
        """Store the scaling configuration.

        Args:
            scale (float, optional): Fraction of the image height/width that
                the resized patch should occupy. Defaults to 0.25.
            preserve_aspect_ratio (bool, optional): Intended to keep the
                patch's aspect ratio during scaling. NOTE: not implemented
                yet; the flag is stored but currently has no effect.
                Defaults to True.
        """
        super().__init__()
        self.scale = scale
        self.preserve_aspect_ratio = preserve_aspect_ratio

    def forward(
        self,
        x: torch.Tensor,
        patch: torch.Tensor,
        y: "dict | None" = None,
    ) -> "tuple[torch.Tensor, dict | None]":
        """Apply the scaled patch to ``x`` and optionally tidy up targets ``y``.

        Args:
            x: Input image tensor whose last two dimensions are (H, W),
                typically of shape (C, H, W).
            patch: Patch tensor; it is resized with ``F.resize`` before being
                written into the image.
            y: Optional target annotations dictionary, with keys such as
                "boxes" and "labels".

        Returns:
            A tuple ``(image, targets)``:
            - ``image``: a copy of ``x`` with the resized patch written over
              its top-left corner. ``x`` itself is not modified.
            - ``targets``: ``None`` if ``y`` is ``None``; otherwise a shallow
              copy of ``y`` whose "boxes" entry is a clamped copy of the
              original (the caller's tensor is left untouched). Missing
              "boxes" / "labels" entries are filled with empty tensors of
              shape (0, 4) float32 and (0,) int64 respectively.
        """
        x_copy = x.clone()

        # Scale the patch to a fixed size
        if self.preserve_aspect_ratio:
            # TODO implement aspect ratio preservation
            pass

        # Target size is (H, W) * scale. Negative indexing keeps this correct
        # when ``x`` carries leading batch dimensions as well.
        size = (
            torch.round(torch.tensor(x.shape[-2:]) * self.scale)
            .to(torch.int32)
            .tolist()
        )
        patch = F.resize(patch, size=size)
        x_copy[..., : patch.shape[-2], : patch.shape[-1]] = patch

        if y is None:
            return x_copy, None

        # TODO pull out or find built-in for this
        # ``dict.copy`` is shallow: any tensor we modify must be cloned first,
        # otherwise the caller's annotations would be changed in place.
        y_copy = y.copy()
        if "boxes" in y_copy:
            height, width = x_copy.shape[-2], x_copy.shape[-1]
            # Assumes boxes are (x1, y1, x2, y2) in pixel coordinates (the
            # torchvision detection convention) -- TODO confirm. Under that
            # layout x is bounded by the image width and y by its height.
            boxes = y_copy["boxes"].clone()
            boxes[:, 0] = torch.clamp(boxes[:, 0], min=0)
            boxes[:, 1] = torch.clamp(boxes[:, 1], min=0)
            boxes[:, 2] = torch.clamp(boxes[:, 2], max=width)
            boxes[:, 3] = torch.clamp(boxes[:, 3], max=height)
            y_copy["boxes"] = boxes
        else:
            y_copy["boxes"] = torch.zeros((0, 4), dtype=torch.float32)

        if "labels" not in y_copy:
            y_copy["labels"] = torch.zeros((0,), dtype=torch.int64)

        return x_copy, y_copy