[P] torchdata: Implement map, cache, filter etc. within PyTorch’s Datasets (like TensorFlow’s tf.data and more)
What is torchdata
I would like to present a new open-source, PyTorch-based project, torchdata, which extends the capabilities of torch.utils.data.Dataset by bringing map, cache and other operations known from tensorflow.data.Dataset (and actually a little more than that).
All that with a single line of code: super().__init__()
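To give an intuition for why one super().__init__() call can be enough, here is a minimal pure-Python sketch of the idea. It is not torchdata's actual implementation: SketchDataset, Squares, get, _maps and _cache are hypothetical names chosen for illustration, and the sketch uses a separate get method instead of hooking into __getitem__ to stay short.

```python
class SketchDataset:
    """Hypothetical stand-in for a base class offering map() and cache()."""

    def __init__(self):
        # State that subclasses get for free by calling super().__init__()
        self._maps = []      # functions applied, in order, to every sample
        self._cache = None   # dict of computed samples once cache() is called

    def map(self, function):
        # Register the function lazily; nothing is computed yet
        self._maps.append(function)
        return self

    def cache(self):
        # From now on, remember each sample after its first computation
        self._cache = {}
        return self

    def get(self, index):
        # Return the cached value if present, else compute it through the maps
        if self._cache is not None and index in self._cache:
            return self._cache[index]
        sample = self[index]
        for function in self._maps:
            sample = function(sample)
        if self._cache is not None:
            self._cache[index] = sample
        return sample


class Squares(SketchDataset):
    def __init__(self, n):
        super().__init__()  # the single extra line
        self.n = n

    def __getitem__(self, index):
        return index * index

    def __len__(self):
        return self.n


data = Squares(4).map(lambda x: x + 1).cache()
values = [data.get(i) for i in range(len(data))]
```

Without the super().__init__() call, _maps and _cache would never be created, so map and cache would fail with an AttributeError in this sketch.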
For more, check the documentation or the GitHub repository.
Functionalities Overview
- Use map, apply, reduce or filter
- cache data in RAM or on disk (even partial caching, say the first 20% in RAM and the rest on disk)
- Full support for PyTorch’s Dataset and IterableDataset (including torchvision)
- General torchdata.maps like Flatten or Select
- Concrete torchdata.datasets designed for file reading and other general tasks
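The partial caching mentioned in the list can be pictured with a small pure-Python sketch. This is only an assumption about how such a split might work, not torchdata's API: PartialCache, ram_fraction and the one-pickle-file-per-sample layout are hypothetical.

```python
import pathlib
import pickle
import tempfile


class PartialCache:
    """Hypothetical cache: first `ram_fraction` of indices in RAM, rest on disk."""

    def __init__(self, length, ram_fraction, directory):
        self.ram_limit = int(length * ram_fraction)  # indices below this stay in RAM
        self.memory = {}
        self.directory = pathlib.Path(directory)

    def _path(self, index):
        # One pickle file per sample stored on disk
        return self.directory / f"{index}.pkl"

    def __setitem__(self, index, sample):
        if index < self.ram_limit:
            self.memory[index] = sample
        else:
            with open(self._path(index), "wb") as file:
                pickle.dump(sample, file)

    def __getitem__(self, index):
        if index < self.ram_limit:
            return self.memory[index]
        with open(self._path(index), "rb") as file:
            return pickle.load(file)


with tempfile.TemporaryDirectory() as directory:
    cache = PartialCache(length=10, ram_fraction=0.2, directory=directory)
    for i in range(10):
        cache[i] = i * 10
    in_ram = sorted(cache.memory)
    on_disk = sorted(int(p.stem) for p in pathlib.Path(directory).glob("*.pkl"))
    restored = [cache[i] for i in range(10)]
```

With ten samples and a fraction of 0.2, the first two samples stay in the in-memory dict and the remaining eight are pickled to the temporary directory; reads go through the same indexing interface either way.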
Example
- Create image reading dataset

      import pathlib

      import torchdata
      import torchvision
      from PIL import Image


      class Images(torchdata.Dataset):  # Different inheritance
          def __init__(self, path: str):
              super().__init__()  # This is the only change
              self.files = [file for file in pathlib.Path(path).glob("*")]

          def __getitem__(self, index):
              return Image.open(self.files[index])

          def __len__(self):
              return len(self.files)

- map each element to torch.Tensor and cache() everything in memory:

      images = Images("./data").map(torchvision.transforms.ToTensor()).cache()

- concatenate with labels (another torchdata.Dataset instance) and iterate over:

      for data, label in images | labels:
          # Do whatever you want with your data
          ...
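Judging from the loop above, images | labels yields one (data, label) pair per index. A minimal pure-Python sketch of that pairing behaviour follows; it is an assumption inferred from the example rather than torchdata's real code, and Paired and Items are hypothetical names.

```python
class Paired:
    """Hypothetical result of `a | b`: sample i is the tuple (a[i], b[i])."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError(index)  # lets a plain `for` loop terminate
        return tuple(dataset[index] for dataset in self.datasets)

    def __len__(self):
        # Stop at the shortest dataset, like zip()
        return min(len(dataset) for dataset in self.datasets)


class Items:
    """Hypothetical list-backed dataset supporting the `|` operator."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

    def __or__(self, other):
        return Paired(self, other)


images = Items(["img0", "img1", "img2"])
labels = Items([0, 1, 0])
pairs = [(data, label) for data, label in images | labels]
```

Overloading __or__ keeps the call site as short as in the post, while the pairing logic lives in one small wrapper class.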
Installation
pip is the easiest of course:

    pip install torchdata

You can also use nightly releases (torchdata-nightly) or GPU/CPU Docker-based images (check the documentation). Hopefully a conda package will be released soon as well, stay tuned.
BTW, you can also check out torchfunc; I plan to make a separate post about that in a week or so.
Thanks for checking the above; any input would be welcome (either here or on GitHub).
submitted by /u/szymonmaszke