Custom Datasets & Augmentations#
Generally, a custom Dataset class may look like this:
    from torch.utils.data import Dataset
    from antelope import ml

    class MyDataset(Dataset):
        def __init__(self, train=True):
            self.classes = ['dog', 'cat']
            ...

        def __getitem__(self, index):
            ...
            return obs, label

        def __len__(self):
            ...

    ml(dataset=MyDataset)
For more info, see PyTorch's tutorial on map-style Datasets.
Any PyTorch Dataset will do, with one caveat: __getitem__ should output either an (obs, label) pair OR a dictionary of "datums", e.g. {'obs': obs, 'label': label}.
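For example, the dictionary form might look like this (a minimal sketch; the random placeholder data and the MyDictDataset name are illustrative assumptions, not part of the library):

    import torch
    from torch.utils.data import Dataset

    class MyDictDataset(Dataset):
        def __init__(self, train=True):
            self.classes = ['dog', 'cat']
            # Placeholder data for illustration: 100 random 3x32x32 "images" with labels
            self.data = torch.randn(100, 3, 32, 32)
            self.labels = torch.randint(len(self.classes), (100,))

        def __getitem__(self, index):
            # A dictionary of "datums" instead of an (obs, label) pair
            return {'obs': self.data[index], 'label': self.labels[index]}

        def __len__(self):
            return len(self.data)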
Classification
Since the default task is task=classify, the above script will learn
to classify MyDataset.
If you define your own classify Dataset, include a .classes
attribute listing the classes in your Dataset. Otherwise, The Antelope will
automatically count the unique classes it encounters, and that count may
differ between the training and test sets.
Test datasets
You can include a train= boolean arg to your custom Dataset to
define different behaviors for training and testing, or use a different
custom test Dataset via test_dataset=.
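For example (a sketch; MyTestDataset is a hypothetical second Dataset class, and the dict form for test_dataset= assumes it accepts the same arguments as dataset=):

    # Reuse one Dataset class, flipping its train= arg for evaluation
    ml(dataset={'_target_': MyDataset, 'train': True},
       test_dataset={'_target_': MyDataset, 'train': False})

    # Or evaluate on an entirely different Dataset
    ml(dataset=MyDataset, test_dataset=MyTestDataset)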
Transforms & augmentations
Pre-compile:

    from torchvision.transforms import Resize

    ml(dataset={'_target_': MyDataset, 'transform': Resize([302, 170])})

Transform on CPU at runtime:

    ml(dataset=MyDataset, transform=Resize([302, 170]))

Augment on GPU at runtime:

    from antelope.Agents.Blocks.Augmentations import RandomShiftsAug

    ml(dataset=MyDataset, aug=RandomShiftsAug)
All passed-in Datasets support the dataset.transform= argument.
dataset.transform= is distinct from transform= and aug=:
transform= runs a transform on CPU at runtime and aug= runs a
batch-vectorized augmentation on GPU at runtime, whereas
dataset.transform= transforms/pre-compiles the dataset once before
training begins. One-time operations like Resize are most efficient
there.
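These can also be combined. For instance (a sketch, assuming the keywords compose as described above):

    from torchvision.transforms import Resize
    from antelope.Agents.Blocks.Augmentations import RandomShiftsAug

    # Resize is applied once, before training begins;
    # RandomShiftsAug runs per-batch on GPU at runtime
    ml(dataset={'_target_': MyDataset, 'transform': Resize([302, 170])},
       aug=RandomShiftsAug)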
There are also two additional kinds of transform/augmentation for nuanced cases: env.transform= can transform an online stream from a rollout at runtime, and dataset.aug= can pre-compile a dataset with a batch-vectorized augmentation on GPU. Between them, pretty much every transform/aug need of a data processing pipeline is met.
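For instance, dataset.aug= presumably follows the same dict convention as dataset.transform= (a sketch; this composition is an assumption based on the examples above, not a confirmed call):

    # Pre-compile the dataset once with a GPU batch-vectorized augmentation
    ml(dataset={'_target_': MyDataset, 'aug': RandomShiftsAug})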
Standardization & normalization
Stats will automatically be computed for standardization and
normalization, and saved in the corresponding Memory card.yaml in
World/ReplayBuffer. Disable standardization with
standardize=false; normalization will then be used instead.
Disable both with standardize=false norm=false. You may learn more
about the differences at
GeeksforGeeks.
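In brief, the two differ as follows (a minimal illustration of the operations themselves, not the library's internal code):

    import torch

    x = torch.tensor([1., 2., 3., 4.])

    standardized = (x - x.mean()) / x.std()           # zero mean, unit variance
    normalized = (x - x.min()) / (x.max() - x.min())  # rescaled to [0, 1]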
By default, an agent loaded from a checkpoint will reuse the stats
originally tabulated from the data it was trained on, even when evaluated
or further trained on a new dataset, to keep conditions consistent.
Subsets
Training on a subset of classes is possible with the dataset.subset='[0, 5, 2]'
keyword. In this example, only classes 0, 5, and 2 of the
given Dataset will be used for training and evaluation.

    ml(dataset={'_target_': MyDataset, 'subset': [0, 5, 2]})
Built-In Datasets#
All TorchVision datasets are supported by default and can be passed in by name (e.g. dataset=MNIST), as is TinyImageNet, which is provided as an example custom dataset.
For an iterable-style Dataset, use an Environment.