# TODO random patch augs
# TODO disable grad for x? does that speed up training?
import torch
from torchvision import transforms as v2
[docs]
class ApplyPatch(torch.nn.Module):
"""Module to apply a patch to an image."""
def __init__(
self,
scale_range: tuple[float, float] = (0.1, 0.4),
location_range: tuple[float, float] = (0.0, 1.0),
patch_crop_range: tuple[float, float] = (0.8, 1.0),
rotation_probs: tuple[float, float, float, float] = (0.25, 0.25, 0.25, 0.25),
flip_probability: float = 0.5,
):
"""
Initializes the transformation utility with configurable ranges and distributions for scaling, location, cropping, rotation, flipping, and color jitter.
Args:
scale_range (tuple[float, float], optional): Range for scaling patches. Defaults to (0.1, 0.4).
location_range (tuple[float, float], optional): Range for selecting patch locations. Defaults to (0.0, 1.0).
patch_crop_range (tuple[float, float], optional): Range for cropping patches. Defaults to (0.8, 1.0).
rotation_probs (tuple[float, float, float, float], optional): Probabilities for selecting rotation angles. Defaults to (0.25, 0.25, 0.25, 0.25).
flip_probability (float, optional): Probability of flipping the patch. Defaults to 0.5.
Attributes:
scale_range (tuple[float, float]): Range for scaling patches.
location_range (tuple[float, float]): Range for selecting patch locations.
location_distribution (torch.distributions.uniform.Uniform): Uniform distribution for patch location.
patch_scale_distribution (torch.distributions.uniform.Uniform): Uniform distribution for patch scale.
patch_crop_range (tuple[float, float]): Range for cropping patches.
patch_crop_distribution (torch.distributions.uniform.Uniform): Uniform distribution for patch cropping.
input_dims (tuple[int]): Input dimensions, default is (2,).
rotation_distribution (torch.distributions.categorical.Categorical): Categorical distribution for rotation.
flip_distribution (torch.distributions.bernoulli.Bernoulli): Bernoulli distribution for flipping.
color_jitter (v2.ColorJitter): Color jitter transformation for brightness, contrast, saturation, and hue.
"""
super().__init__()
self.scale_range = scale_range
# TODO adjust start end location with patch scale range
# self.start_distribution = torch.distributions.half_normal.HalfNormal(
# loc=location_range[0], scale=(location_range[1] - location_range[0]) / 2
# )
# TODO change to half normal distribution
self.location_range = location_range
self.location_distribution = torch.distributions.uniform.Uniform(low=location_range[0], high=location_range[1])
# TODO change to half normal distribution
self.patch_scale_distribution = torch.distributions.uniform.Uniform(low=scale_range[0], high=scale_range[1])
self.patch_crop_range = patch_crop_range
self.patch_crop_distribution = torch.distributions.uniform.Uniform(
low=patch_crop_range[0], high=patch_crop_range[1]
)
self.input_dims = (2,) # NOTE could update to handle different shapes than images
# self.rotation_distribution = torch.distributions.uniform.Uniform(
# low=rotation_probs[0], high=rotation_probs[1]
# )
# ic(rotation_probs)
self.rotation_distribution = torch.distributions.categorical.Categorical(probs=torch.tensor(rotation_probs))
self.flip_distribution = torch.distributions.bernoulli.Bernoulli(probs=flip_probability)
# TODO use color jitter
self.color_jitter = v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
# TODO switch to crop then resize? use resized_crop?
[docs]
def forward(self, x: torch.Tensor, patch: torch.Tensor, y: torch.Tensor = None) -> torch.Tensor:
"""
Forward method.
The patch is randomly rotated, resized, and placed at a location determined by a distribution.
The function ensures the patch fits within the image boundaries and updates the target tensor
`y` if provided.
Args:
x (torch.Tensor): The input image tensor of shape (..., H, W).
patch (torch.Tensor): The patch tensor to be inserted, typically of shape (..., h, w).
y (torch.Tensor, optional): Target tensor containing annotations (e.g., bounding boxes and labels).
Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- The transformed image tensor with the patch inserted.
- The updated target tensor (if provided), otherwise None.
"""
x_copy = x.clone()
# NOTE do the rotation before computing and using the sizes
patch = torch.rot90(
patch,
k=self.rotation_distribution.sample().item(),
dims=(-2, -1), # Rotate around the height and width dimensions
)
patch_scale = self.patch_scale_distribution.sample(self.input_dims)
# TODO scale patch to a random size maybe? or keep to image ratio?
# size = torch.round(torch.tensor(patch.shape[1:]) * patch_scale).to(torch.int32).tolist()
scaled_shape = torch.tensor(x_copy.shape[-2:]) * patch_scale
rounded_shape = torch.round(scaled_shape)
rounded_size = rounded_shape.to(torch.int32)
size = rounded_size
# size = rounded_size.tolist()
# size = torch.round(torch.tensor(x_copy.shape[1:]) * patch_scale).to(torch.int32).tolist()
# patch = F.resize(
# patch,
# size=size,
# )
# TODO switch to functional resize to see if it fixes vmap
resized_crop = v2.RandomResizedCrop(
size=size,
# scale=(self.scale_range[0], self.scale_range[1]),
# ratio=self.patch_crop_range,
)
patch = resized_crop(patch)
# pad_top = torch.ceil(patch.shape[1] / 2).to(torch.int32).item()
# pad_bottom = torch.floor(patch.shape[1] / 2).to(torch.int32).item()
# pad_left = torch.ceil(patch.shape[2] / 2).to(torch.int32).item()
# pad out the image to allow patch to be placed at edges of image
# left_right_pad = torch.tensor(patch.shape[1])
# top_bottom_pad = torch.tensor(patch.shape[2])
# left_pad = torch.ceil(left_right_pad).to(torch.int32).item()
# top_pad = torch.ceil(top_bottom_pad).to(torch.int32).item()
# # right_pad = 0
# # bottom_pad = 0
# right_pad = torch.floor(left_right_pad).to(torch.int32).item()
# bottom_pad = torch.floor(top_bottom_pad).to(torch.int32).item()
# top_bottom_pad = patch.shape[1]
# # TODO if i'm already cropping the patch, does it make sense to pad the image for cropping?
# left_right_pad = patch.shape[2]
# x_copy = F.pad(
# x_copy,
# # padding=(left_pad, right_pad, top_pad, bottom_pad),
# # padding=(x.shape[2], x.shape[2], x.shape[1], x.shape[1]),
# # padding=(left_pad, top_pad, right_pad, bottom_pad),
# padding=(left_right_pad, top_bottom_pad)
# )
# ic((x_copy.shape[1] - x.shape[1]) / 2)
# ic((x_copy.shape[2] - x.shape[2]) / 2)
# assert x_copy.shape[1] - (patch.shape[1]*2) == x.shape[1], "something is off"
# assert x_copy.shape[2] - (patch.shape[2]*2) == x.shape[2], "something is off"
# ic("post padding x_copy.shape", x_copy.shape)
# TODO update to be between like -1, 2 so the patch can start outside the image
location_scale = self.location_distribution.sample(self.input_dims)
# x_1, y_1 = torch.round(torch.tensor(x.shape[1:]) * location_scale).to(torch.int32)
# TODO break up patch transforms to other transforms for more flexibility
# for example make each transform not always used
# the location doesn't make sense to put the patch at the right/bottom padded area
# so we need to adjust the location
# NOTE the patch is already cropped so we shouldn't worry about handling placing the patch off the edges
# max_size = torch.tensor(x_copy.shape[-2:]) - torch.tensor(patch.shape[-2:])
max_size = torch.tensor(x_copy.shape[-2:]) - torch.tensor(patch.shape[-2:])
# xy_1 = max_size * location_scale
x_1, y_1 = torch.round(max_size * location_scale).to(torch.int32)
# patch_crop_scale = self.patch_crop_distribution.sample(self.input_dims)
# patch_crop_x = torch.round(patch.shape[1] * patch_crop_scale[0]).to(torch.int32)
# patch_crop_y = torch.round(patch.shape[2] * patch_crop_scale[1]).to(torch.int32)
# left = patch_crop_y
# top = patch_crop_x
# height = min(patch.shape[1] - patch_crop_x, x_copy.shape[1] - x_1)
# width = min(patch.shape[2] - patch_crop_y, x_copy.shape[2] - y_1)
# patch = F.crop(
# patch,
# top=top,
# left=left,
# height=height,
# width=width,
# )
# TODO update to take any rotation
# patch = F.rotate(
# patch,
# angle=self.rotation_distribution.sample(self.sample_size).item(),
# expand=True, # Expand the image to fit the rotated patch
# )
# handle patch going off the edges of the image
x_2 = x_1 + patch.shape[-2]
y_2 = y_1 + patch.shape[-1]
if x_2 > x_copy.shape[-2]:
raise ValueError(f"Patch exceeds image width: {x_2} > {x_copy.shape[-2]}")
if y_2 > x_copy.shape[-1]:
raise ValueError(f"Patch exceeds image height: {y_2} > {x_copy.shape[-1]}")
# patch_x_1 = max(0, x_1)
# height = y_2-y_1
# width = x_2-x_1
x_copy[..., x_1:x_2, y_1:y_2] = patch
# crop back down
# x_copy = x_copy[:, top_bottom_pad:-top_bottom_pad, left_right_pad:-left_right_pad]
# x_copy = x_copy[:, top_bottom_pad:-top_bottom_pad, left_right_pad:-left_right_pad]
# x_copy = x_copy[:, left_pad:-right_pad, top_pad:-bottom_pad]
# filter target boxes and labels
# y_copy = y.copy() if y is not None else None
# TODO adjust y?
y_copy = y
# if y_copy is not None:
# if "boxes" in y:
# # Adjust boxes to account for the patch location
# y_copy["boxes"][:, 0] = torch.clamp(y_copy["boxes"][:, 0] + x_1, min=0)
# y_copy["boxes"][:, 1] = torch.clamp(y_copy["boxes"][:, 1] + y_1, min=0)
# y_copy["boxes"][:, 2] = torch.clamp(y_copy["boxes"][:, 2] + x_1, max=x_copy.shape[1])
# y_copy["boxes"][:, 3] = torch.clamp(y_copy["boxes"][:, 3] + y_1, max=x_copy.shape[2])
# else:
# y_copy["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
# if "labels" not in y_copy:
# y_copy["labels"] = torch.zeros((0,), dtype=torch.int64)
return x_copy, y_copy