shok.data.datasets package

Submodules

shok.data.datasets.coco module

class shok.data.datasets.coco.CocoDataModule(batch_size: int = 2, sample_size: int = 1)

Bases: LightningDataModule

CocoDataModule is a PyTorch Lightning DataModule for loading and managing the COCO 2017 dataset for object detection tasks.

This module provides functionality to:
  • Download and prepare the COCO 2017 validation dataset using FiftyOne’s zoo loader.

  • Load images and annotations using torchvision’s CocoDetection.

  • Apply a sequence of transforms for preprocessing, including image conversion, bounding box validation, and dtype conversion.

  • Split the dataset into training and validation subsets.

  • Construct mappings from category IDs to class names for downstream use.

  • Provide PyTorch DataLoaders for training, validation, and (optionally) testing.

Attributes:

train_dataset_repeat (Optional): Repeated training dataset (not implemented).

wandb_classes (Optional): List of class names for use with Weights & Biases.

fiftyone_dataset: FiftyOne dataset object loaded in prepare_data.

Methods:

__init__(batch_size: int = 2, sample_size: int = 1): Initializes the data module with batch and sample sizes.

prepare_data(): Downloads and prepares the COCO 2017 validation dataset using FiftyOne.

setup(stage=None): Loads and preprocesses the dataset, splits it into train/val subsets, and constructs class mappings.

train_dataloader(): Returns a DataLoader for the training dataset.

val_dataloader(): Returns DataLoaders for both the training and validation datasets for evaluation.

test_dataloader(): Not implemented; raises NotImplementedError.

  • The module is designed for use with PyTorch Lightning.

  • Data loading and splitting logic is customizable via TODOs.

  • The test dataloader is not implemented.

prepare_data()

Loads the COCO 2017 validation dataset using FiftyOne’s zoo loader and assigns it to self.fiftyone_dataset.

The dataset is loaded with detection labels.

Returns:

None
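A minimal sketch of the loading step described above, assuming FiftyOne's `fiftyone.zoo.load_zoo_dataset` function and the `coco-2017` zoo dataset name. The helper name `load_coco_val_detections` is hypothetical, and the module's actual call may pass different arguments.

```python
def load_coco_val_detections():
    """Load the COCO 2017 validation split with detection labels via FiftyOne."""
    # Imported lazily so this helper can be defined without FiftyOne installed.
    import fiftyone.zoo as foz

    # Downloads the split if it is not already present in the local zoo directory.
    return foz.load_zoo_dataset(
        "coco-2017",
        split="validation",
        label_types=["detections"],
    )
```

In a data module, the returned object would typically be stored on the instance (here, as `self.fiftyone_dataset`) so that later stages can reuse it.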

setup(stage=None)

Sets up the COCO dataset for training and validation.

Loads the COCO 2017 validation images and annotations using torchvision’s CocoDetection. Applies a sequence of transforms to the dataset, including image conversion, bounding box validation, and dtype conversion. Wraps the dataset for compatibility with torchvision’s transforms v2 and splits it into training and validation subsets. Also constructs a mapping from category IDs to class names.

Args:

stage (str, optional): Stage of setup (e.g., ‘fit’, ‘test’). Defaults to None.

Attributes:

base_dataset (torchvision.datasets.CocoDetection): The base COCO detection dataset.

train_dataset (torch.utils.data.Dataset): Training subset of the dataset.

val_dataset (torch.utils.data.Dataset): Validation subset of the dataset.

idx_to_class (dict): Mapping from category IDs to class names.
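As an illustration of the idx_to_class attribute, the sketch below builds such a mapping from category records shaped like the "categories" entries of a COCO annotation file. The helper name `build_idx_to_class` and the three sample records are assumptions for illustration; how the module actually reads the categories is not shown here.

```python
def build_idx_to_class(categories):
    """Map each COCO category ID to its class name."""
    return {category["id"]: category["name"] for category in categories}


# Sample records in the shape used by COCO annotation files.
# COCO category IDs are not contiguous (for example, there is no ID 12).
categories = [
    {"id": 1, "name": "person", "supercategory": "person"},
    {"id": 2, "name": "bicycle", "supercategory": "vehicle"},
    {"id": 13, "name": "stop sign", "supercategory": "outdoor"},
]

idx_to_class = build_idx_to_class(categories)
```

Because the IDs have gaps, a dictionary lookup is safer than indexing into a list of names.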

test_dataloader()

Placeholder for a method that creates and returns a DataLoader for the test dataset.

Returns:

torch.utils.data.DataLoader: DataLoader instance for the test dataset (once implemented).

Note:

This method is not yet implemented; calling it raises NotImplementedError.

train_dataloader()

Creates and returns a DataLoader for the training dataset.

Returns:

torch.utils.data.DataLoader: DataLoader configured for training.

Notes:
  • Uses the training dataset (self.train_dataset).

  • Batch size is set from hyperparameters (self.hparams.batch_size).

  • Data is shuffled for training.

  • Number of worker processes is determined by available CPU cores (minimum 1, maximum 8).

  • Persistent workers and pinned memory are enabled for performance.

  • Uses a custom collate function to unpack batches.
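The worker-count rule in the notes above (at least 1, at most 8, based on available CPU cores) can be sketched as follows. The exact expression used by the module is not shown here, so the helper name `resolve_num_workers` and its `max_workers` parameter are hypothetical.

```python
import os


def resolve_num_workers(max_workers: int = 8) -> int:
    """Clamp the number of DataLoader workers to the range [1, max_workers]."""
    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1
    return max(1, min(max_workers, cpu_count))


num_workers = resolve_num_workers()
```

Keeping the lower bound at 1 matters when persistent workers are enabled, since PyTorch's DataLoader only accepts `persistent_workers=True` together with `num_workers > 0`.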

val_dataloader()

Creates and returns validation dataloaders for the training and validation datasets.

Returns:
dict: A dictionary containing two DataLoader objects:
  • “clean_train”: DataLoader for the training dataset (self.train_dataset) with validation settings.

  • “val”: DataLoader for the validation dataset (self.val_dataset).

Both DataLoaders use the following settings:
  • batch_size: Defined by self.hparams.batch_size.

  • shuffle: False (no shuffling).

  • num_workers: Number of worker processes, set to at least 1 and at most 8, based on available CPU cores.

  • persistent_workers: True (workers are kept alive between epochs).

  • pin_memory: True (enables faster data transfer to CUDA-enabled GPUs).

  • collate_fn: Function to unpack the dataset samples (tuple(zip(*x))).
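The collate expression tuple(zip(*x)) listed above turns a list of (image, target) samples into one tuple of images and one tuple of targets. A small sketch with placeholder values follows; the function name `unpack_batch` and the sample contents are assumptions for illustration.

```python
def unpack_batch(samples):
    """Transpose a list of (image, target) pairs into (images, targets)."""
    return tuple(zip(*samples))


# Placeholder samples: strings stand in for image tensors.
samples = [
    ("image_0", {"boxes": [[0, 0, 10, 10]], "labels": [1]}),
    ("image_1", {"boxes": [], "labels": []}),
]

images, targets = unpack_batch(samples)
# images  -> ("image_0", "image_1")
# targets -> the two target dicts, in the same order
```

This style of collation is common for object detection because each image can carry a different number of boxes, so targets cannot be stacked into a single tensor the way the default collate function would attempt.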

shok.data.datasets.utils module

class shok.data.datasets.utils.FiftyOneTorchDataset(fiftyone_dataset, transforms=None, gt_field='ground_truth', classes=None)

Bases: Dataset

A class to construct a PyTorch dataset from a FiftyOne dataset.

Args:

fiftyone_dataset: A FiftyOne dataset or view that will be used for training or testing.

transforms (None): A list of PyTorch transforms to apply to images and targets when loading.

gt_field (“ground_truth”): The name of the field in fiftyone_dataset that contains the desired labels to load.

classes (None): A list of class strings that are used to define the mapping between class names and indices. If None, all classes present in the given fiftyone_dataset are used.
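The role of the classes argument can be sketched in plain Python. The helper names `resolve_classes` and `build_class_to_index`, the sorted ordering of derived classes, and the zero-based indices are all assumptions for illustration; the class may order or offset its indices differently.

```python
def resolve_classes(classes, labels):
    """Return the explicit class list, or derive one from the observed labels."""
    if classes is not None:
        return list(classes)
    # Assumption: "all classes present" means the distinct labels,
    # sorted here so the resulting indices are deterministic.
    return sorted(set(labels))


def build_class_to_index(classes):
    """Map each class name to its position in the class list."""
    return {name: index for index, name in enumerate(classes)}


observed_labels = ["dog", "cat", "dog", "person"]

# classes=None: fall back to the labels found in the data.
derived = build_class_to_index(resolve_classes(None, observed_labels))

# Explicit classes: the caller controls both membership and order.
explicit = build_class_to_index(resolve_classes(["person", "dog"], observed_labels))
```

Passing an explicit list keeps indices stable across dataset views that happen to contain different subsets of labels.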

get_classes()

Returns the list of class labels associated with the dataset.

Returns:

list: A list containing the class labels.