The relationship between DataLoader, Sampler, and DataSet

- Sampler: provides the indices of elements in the dataset
- DataSet: retrieves data according to the indices provided by the Sampler
- DataLoader: loads data in batches for subsequent training and testing
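To make the division of labor concrete, here is a minimal, self-contained sketch (the toy TensorDataset and the printed shapes are illustrative additions, not from the original post):

import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy DataSet: 10 samples with 3 features each, plus integer labels.
features = torch.randn(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)  # index -> (feature, label)

# With shuffle=True the DataLoader builds a RandomSampler to produce indices,
# fetches dataset[i] for each index, and collates the results into batches.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch_features, batch_labels in loader:
    print(batch_features.shape, batch_labels)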
Sampler
class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an ``__iter__`` method, providing a
    way to iterate over indices of dataset elements, and a ``__len__`` method
    that returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
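Subclassing is straightforward: only __iter__ and __len__ need to be implemented. As an illustration (the ReverseSampler name is hypothetical, not part of PyTorch), here is a sampler that yields indices back to front:

from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Yields dataset indices from last to first (illustrative only)."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Walk the index range in reverse.
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

print(list(ReverseSampler(range(5))))  # [4, 3, 2, 1, 0]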
The PyTorch library already provides several Sampler implementations:
SequentialSampler
If shuffle=False and no sampler is specified, the DataLoader uses this by default:
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
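A quick check of the behavior (using a plain range as a stand-in for a dataset):

from torch.utils.data.sampler import SequentialSampler

print(list(SequentialSampler(range(5))))  # [0, 1, 2, 3, 4]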
RandomSampler
If shuffle=True and no sampler is specified, the DataLoader uses this by default:
import torch

class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a
    shuffled dataset. If with replacement, then user can specify
    ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``,
            default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
            This argument is supposed to be specified only when `replacement`
            is ``True``.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        n = len(self.data_source)
        # Draw a full random permutation of the indices.
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        # This simplified excerpt never sets num_samples, so the length is
        # simply the dataset length.
        return len(self.data_source)
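Again a quick check (the fixed seed is only there to make the permutation reproducible):

import torch
from torch.utils.data.sampler import RandomSampler

torch.manual_seed(0)
print(list(RandomSampler(range(5))))  # some permutation of 0..4, e.g. [2, 4, 0, 3, 1]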
BatchSampler

Like a sampler, but returns a batch of indices at a time. In the DataLoader, the batch_sampler argument is mutually exclusive with batch_size, shuffle, sampler, and drop_last.
- When batch_sampler=batch_sampler is passed to the DataLoader, the four parameters above must all keep their default values. This is easy to understand: each sampling step already returns a whole batch of indices, so batch_size necessarily stays at its default of 1.
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
- Note that constructing a BatchSampler instance requires passing a sampler as an argument, as shown in the sketch below.
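Putting the pieces together, a minimal end-to-end sketch of wiring a BatchSampler into a DataLoader (the toy dataset is an assumption of mine, not from the original post):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SequentialSampler, BatchSampler

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))

# Wrap a plain sampler so that it yields whole batches of indices.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False)

# batch_size, shuffle, sampler, and drop_last must keep their defaults here.
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (batch,) in loader:
    print(batch.squeeze(1))  # three batches of 3 elements, then one of size 1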
Best practice
I recently came across a post sharing a trick that speeds up model training by 20%: BlockShuffle. I forked the original author's code and implemented a custom batch_sampler; the source is at TransformersWsz/BlockShuffleTest.
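For flavor, here is my own rough sketch of the BlockShuffle idea (a paraphrase, not the author's code; the class name, the lengths parameter, and the block_factor default are all assumptions): sort samples by length within large blocks so each batch holds similar-length sequences and padding is minimized, then shuffle the order of the batches.

import random
from torch.utils.data.sampler import Sampler

class BlockShuffleBatchSampler(Sampler):
    """Sketch of a BlockShuffle-style batch sampler (hypothetical)."""
    def __init__(self, lengths, batch_size, block_factor=100):
        self.lengths = lengths          # per-sample lengths, assumed precomputed
        self.batch_size = batch_size
        self.block_size = batch_size * block_factor

    def __iter__(self):
        indices = list(range(len(self.lengths)))
        random.shuffle(indices)
        # Sort by length within each block (not globally) to keep randomness.
        for start in range(0, len(indices), self.block_size):
            block = indices[start:start + self.block_size]
            block.sort(key=lambda i: self.lengths[i])
            indices[start:start + self.block_size] = block
        # Cut into batches, then shuffle the batch order.
        batches = [indices[i:i + self.batch_size]
                   for i in range(0, len(indices), self.batch_size)]
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size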
References:

- A trick that speeds up model training by 20%: BlockShuffle
- PyTorch DataLoader explained in detail
- torch.utils.data — PyTorch 1.10.1 documentation
- Custom batch_sampler for an MNIST DataLoader in PyTorch (codeleading.com)
- Implementing a custom PyTorch DataLoader where every batch is class-balanced (tqwba.com)
- Understanding the relationship between PyTorch's DataLoader, DataSet, and Sampler in one article