Custom Datasets, Aug

Generally, a custom Dataset class may look like this:

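from torch.utils.data import Dataset

from antelope import ml

class MyDataset(Dataset):
    def __init__(self, train=True):
        self.classes = ['dog', 'cat']
        ...  # Load or index your data here

    def __getitem__(self, index):
        ...  # Look up the observation and label at this index

        return obs, label

    def __len__(self):
        ...  # Return the number of samples

ml(dataset=MyDataset)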

For more info, see PyTorch’s tutorial on map-style Datasets.

Any PyTorch Dataset will do, with one caveat: __getitem__ should return either an (obs, label) pair or a dictionary of “datums”, e.g. {'obs': obs, 'label': label}.
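
As a minimal sketch of the dictionary form, the hypothetical ToyDataset below returns each sample as a dictionary of datums rather than a pair. The class name and the toy data are illustrative assumptions, and the import fallback is there only so the sketch also runs where PyTorch is not installed.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # Fallback so this sketch also runs without PyTorch
    Dataset = object


class ToyDataset(Dataset):
    """Hypothetical map-style dataset that returns a dictionary of datums."""

    def __init__(self):
        self.classes = ['dog', 'cat']
        # Toy data: two features per observation, labels index into self.classes
        self.observations = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        self.labels = [0, 1, 0]

    def __getitem__(self, index):
        # Dictionary form instead of an (obs, label) pair
        return {'obs': self.observations[index], 'label': self.labels[index]}

    def __len__(self):
        return len(self.observations)


sample = ToyDataset()[1]
print(sample)  # {'obs': [0.3, 0.4], 'label': 1}
```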


Since the default task is task=classify, the above script will learn to classify MyDataset.

If you define your own Dataset for the classify task, include a .classes attribute listing its classes. Otherwise, The Antelope will count the unique classes automatically, and that count may differ between the training and test sets.
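
The following sketch illustrates why an explicit .classes attribute is safer than inferring classes from the data. It uses made-up label lists and plain Python; it is not the library's actual counting logic, only a picture of how per-split counts can disagree.

```python
# Hypothetical label lists for a training split and a smaller test split
train_labels = [0, 1, 2, 1, 0, 2]
test_labels = [0, 1, 1, 0]  # Class 2 happens to be absent here

# Inferring the class count from the unique labels in each split
inferred_train = len(set(train_labels))
inferred_test = len(set(test_labels))

# Declaring the classes explicitly gives the same count for every split
classes = ['dog', 'cat', 'bird']
declared = len(classes)

print(inferred_train, inferred_test, declared)  # 3 2 3
```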

Test datasets

You can add a train= boolean argument to your custom Dataset to define different behavior for training and testing, or supply a separate custom test Dataset via test_dataset=.
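
A minimal sketch of the train= approach, assuming a hypothetical SplitDataset that serves a different slice of some toy samples depending on the flag. The class name, the data, and the 4/1 split are all illustrative assumptions; the import fallback only keeps the sketch runnable without PyTorch.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # Fallback so this sketch also runs without PyTorch
    Dataset = object


class SplitDataset(Dataset):
    """Hypothetical dataset whose contents depend on the train= flag."""

    def __init__(self, train=True):
        self.classes = ['dog', 'cat']
        samples = [([0.0], 0), ([1.0], 1), ([2.0], 0), ([3.0], 1), ([4.0], 0)]
        # Assumed split: first four samples for training, the rest for testing
        self.samples = samples[:4] if train else samples[4:]

    def __getitem__(self, index):
        obs, label = self.samples[index]
        return obs, label

    def __len__(self):
        return len(self.samples)


train_set = SplitDataset(train=True)
test_set = SplitDataset(train=False)
print(len(train_set), len(test_set))  # 4 1
```

Both splits declare the same .classes, so the class list stays consistent even though the test slice here contains only one class.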

Transforms & augmentations

Transform once, before training begins:

from torchvision.transforms import Resize

ml(dataset={'_target_': MyDataset, 'transform': Resize([302,170])})

Transform on CPU at runtime:

ml(dataset=MyDataset, transform=Resize([302,170]))

Transform on GPU at runtime:

from antelope.Agents.Blocks.Augmentations import RandomShiftsAug

ml(dataset=MyDataset, aug=RandomShiftsAug)

All passed-in Datasets support the dataset.transform= argument, which transforms (pre-compiles) the dataset once, before training begins. It is distinct from transform=, which runs a transform on CPU at runtime, and from aug=, which runs a batch-vectorized augmentation on GPU at runtime. One-time operations like Resize are most efficient as dataset.transform=.
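
To make the efficiency argument concrete, the sketch below counts how often a deterministic transform runs when it is applied once up front versus on every fetch. It is a plain-Python illustration under assumed names (make_transform, the toy data, four epochs), not a description of the library's internals.

```python
calls = {'precompiled': 0, 'runtime': 0}


def make_transform(key):
    """Return a stand-in transform that counts how often it is applied."""
    def transform(obs):
        calls[key] += 1
        return [2 * x for x in obs]  # Placeholder for a deterministic op like Resize
    return transform


raw = [[1, 2], [3, 4], [5, 6]]
epochs = 4

# Pre-compiled (dataset.transform= style): applied once per sample, before training
precompile = make_transform('precompiled')
compiled = [precompile(obs) for obs in raw]
for _ in range(epochs):
    for obs in compiled:
        pass  # Training would consume the already-transformed sample here

# Runtime (transform= style): applied every time a sample is fetched
runtime = make_transform('runtime')
for _ in range(epochs):
    for obs in raw:
        runtime(obs)

print(calls)  # {'precompiled': 3, 'runtime': 12}
```

Random augmentations are the opposite case: they usually need to run at runtime so that each epoch sees different variations of the data.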

Two further kinds of transform/augmentation cover more nuanced cases: env.transform= transforms an online stream from a rollout at runtime, and dataset.aug= pre-compiles a dataset with a batch-vectorized augmentation on GPU. Together, these five options should cover most data-processing pipelines.

Standardization & normalization

Stats for standardization and normalization are computed automatically and saved in the corresponding Memory card.yaml in World/ReplayBuffer. Disable standardization with standardize=false, in which case normalization is used instead. Disable both with standardize=false norm=false. You can learn more about the differences at GeeksforGeeks. By default, an agent loaded from a checkpoint reuses the stats tabulated from the data it was originally trained on, even when it is evaluated or trained further on a new dataset, to keep conditions consistent.
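
For readers unfamiliar with the distinction, here is a small sketch using the common textbook definitions: standardization rescales to zero mean and unit standard deviation, while min-max normalization rescales to a fixed range. The [0, 1] target range and the toy values are assumptions; the exact stats and ranges the library uses are not specified here.

```python
import statistics

values = [2.0, 4.0, 6.0, 8.0]

# Standardization: subtract the mean, divide by the standard deviation
mean = statistics.fmean(values)
std = statistics.pstdev(values)
standardized = [(v - mean) / std for v in values]

# Min-max normalization: rescale to the range [0, 1] (assumed range)
low, high = min(values), max(values)
normalized = [(v - low) / (high - low) for v in values]

print(standardized)
print(normalized)
```

Standardization depends on the mean and standard deviation of the data, normalization on its minimum and maximum, which is why both sets of stats need to be tabulated.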


You can restrict a Dataset to a subset of its classes with the dataset.subset= keyword, e.g. dataset.subset='[0, 5, 2]'. In this example, only classes 0, 5, and 2 of the given Dataset are used for training and evaluation.

ml(dataset={'_target_': MyDataset, 'subset': [0, 5, 2]})
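
Conceptually, selecting a class subset amounts to keeping only the samples whose label is in the requested list. The sketch below shows that filtering on made-up samples in plain Python; it is an illustration of the idea, and whether the library also relabels the remaining classes is not covered here.

```python
# Hypothetical (obs, label) samples covering classes 0 through 5
samples = [('a', 0), ('b', 1), ('c', 2), ('d', 3), ('e', 5), ('f', 2)]
subset = [0, 5, 2]

# Keep only the samples whose label belongs to the requested classes
kept = [(obs, label) for obs, label in samples if label in subset]

print(kept)  # [('a', 0), ('c', 2), ('e', 5), ('f', 2)]
```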

Built-In Datasets

All TorchVision datasets are supported by default and can be passed in by name (e.g. dataset=MNIST). TinyImageNet is also supported and is provided as an example custom dataset.

For an iterable-style Dataset, use an Environment.