Dataset&DataLoader

Python,Torch,Daily life,Share 2024-03-14 84 次浏览 次点赞

dataset

需要实现

def __getitem__(self, index) -> T_co:  # Given an index, return one training sample (typically an (x, y) pair)
    """Return the sample at ``index``; the base implementation only raises.

    Raises:
        NotImplementedError: always, until a subclass overrides this method.
    """
    raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
 
def __len__(self):  # Return the total number of samples held by the dataset
    # Here the size is read from dimension 0 of the first stored tensor.
    # A dataset backed by a label table can return len(self.img_labels)
    # instead, as CustomImageDataset.__len__ does.
    return self.tensors[0].size(0)
import os
import pandas as pd
from torchvision.io import read_image

# Reads training data from disk; __getitem__ returns the single training
# sample that corresponds to a given idx.
class CustomImageDataset(Dataset):
    """Map-style image dataset driven by an annotations CSV.

    Column 0 of the annotations table is used as the image file name
    (relative to ``img_dir``) and column 1 as the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Remember where the data lives and which transforms to apply.

        Args:
            annotations_file: Path (or file-like object) passed to ``pd.read_csv``.
            img_dir: Directory that contains the image files.
            transform: Optional callable applied to each loaded image
                (e.g. resizing, cropping, normalisation).
            target_transform: Optional callable applied to each label
                (e.g. conversion to a one-hot encoding).
        """
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        # Annotation table: one row per sample.
        self.img_labels = pd.read_csv(annotations_file)

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair for row ``idx``."""
        annotations = self.img_labels
        # Row idx, column 0 names the image file; join it with the image root.
        file_name = annotations.iloc[idx, 0]
        image = read_image(os.path.join(self.img_dir, file_name))  # load the image into memory
        # Row idx, column 1 stores the label.
        label = annotations.iloc[idx, 1]
        # Apply the optional transforms only when they were provided.
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label

DataLoader

将单个样本组织成一个批次,用于神经网络的训练

Preparing your data for training with DataLoaders

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us in an easy API.

一次性加载一个minibatch的训练样本,并且在每个epoch重新打乱数据顺序(shuffle=True时)

    # Class-level attribute declarations of DataLoader (excerpt).
    dataset: Dataset[T_co]  # an instantiated Dataset object
    batch_size: Optional[int]  # defaults to 1 in __init__
    shuffle: Optional[bool] = None  # whether the data should be reshuffled for each training epoch
    sampler: Union[Sampler, Iterable, None] = None  # how individual samples are drawn from the dataset
    num_workers: int  # defaults to 0 in __init__; with 0 the single-process iterator is used
    # pin_memory: when enabled, each fetched batch is passed through
    # _utils.pin_memory.pin_memory (see _SingleProcessDataLoaderIter._next_data).
    # Per the PyTorch docs this copies tensors into pinned (page-locked) host
    # memory to speed up later host-to-device transfers; it does NOT keep the
    # tensors on the GPU.
    pin_memory: bool
    drop_last: bool  # if the dataset size is not a multiple of batch_size, True drops the final smaller batch
    # collate_fn post-processes the samples drawn for one batch (input: batch,
    # output: batch); e.g. padding is needed when samples carry a differing
    # number of labels.
    collate_fn: Optional[_collate_fn_t] = None
    timeout: float
    # NOTE(review): `sampler` is annotated a second time here (without a
    # default); the earlier annotated line above appears to be an annotation
    # added for explanation — confirm against the upstream source.
    sampler: Union[Sampler, Iterable]
    pin_memory_device: str
    prefetch_factor: Optional[int]
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False
    # The constructor mainly does three things: builds the per-sample sampler,
    # builds the batch_sampler that groups sample indices into batches, and
    # selects the collate function.
    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: Optional[int] = None,
                 persistent_workers: bool = False,
                 pin_memory_device: str = ""):
        # Only the first statement of the constructor body is quoted in this
        # excerpt; the sampler / batch_sampler / collate_fn setup is shown in
        # separate snippets further below.
        torch._C._log_api_usage_once("python.data_loader")
from torch.utils.data import DataLoader

# Wrap the training and test datasets in DataLoaders that yield batches of 64.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)  # shuffle usually does not need to be set for testing

Iterate through the DataLoader

We have loaded that dataset into the DataLoader and can iterate through the dataset as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).

# Display image and label.
# Pull a single batch from the training DataLoader.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # Feature batch shape: torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {train_labels.size()}")  # Labels batch shape: torch.Size([64])
# Take the first sample of the batch and drop its size-1 dimensions for plotting.
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")  # Label: 5
# Excerpt from DataLoader.__init__: choosing the sampler and batch_sampler.

# Sample-level sampling: a user-supplied sampler already decides the order,
# so it must not be combined with shuffle.
if sampler is not None and shuffle:
    raise ValueError('sampler option is mutually exclusive with '
                     'shuffle')
# Batch-level sampling: when a batch_sampler is supplied, batch_size, shuffle,
# sampler and drop_last must all be left at their defaults.
if batch_sampler is not None:
    # auto_collation with custom batch_sampler
    if batch_size != 1 or shuffle or sampler is not None or drop_last:
        raise ValueError('batch_sampler option is mutually exclusive '
                         'with batch_size, shuffle, sampler, and '
                         'drop_last')
# Decide the order in which samples are drawn from the dataset.
if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            # shuffle=True: use the built-in RandomSampler, which yields the
            # dataset indices in random order.
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
        else:
            # Otherwise use SequentialSampler: indices come out in the
            # dataset's original order when forming minibatches.
            sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

# Create the default batch_sampler when none was supplied.
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

RandomSampler

def __iter__(self) -> Iterator[int]:
    """Yield ``self.num_samples`` randomly chosen dataset indices as Python ints."""
    n = len(self.data_source)  # length of the data source, i.e. the size of the dataset
    if self.generator is None:  # no generator was passed in
        seed = int(torch.empty((), dtype=torch.int64).random_().item())  # draw a random seed
        generator = torch.Generator()  # build a torch Generator
        generator.manual_seed(seed)  # seed that generator
    else:
        generator = self.generator

    if self.replacement:
        # With replacement: draw random indices below n in chunks of 32,
        # then one final chunk holding the remaining num_samples % 32 indices.
        for _ in range(self.num_samples // 32):
            yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).numpy())
        final_samples = torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator)
        yield from map(int, final_samples.numpy())
    else:
        # Without replacement: emit one full random permutation of 0..n-1 per
        # complete pass, e.g. [0, 1, 2] => [1, 0, 2] ...
        for _ in range(self.num_samples // n):
            yield from map(int, torch.randperm(n, generator=generator).numpy())
        # ... followed by the first num_samples % n entries of one more permutation.
        yield from map(int, torch.randperm(n, generator=generator)[:self.num_samples % n].numpy())

SequentialSampler

def __iter__(self) -> Iterator[int]:
   """Iterate over the indices 0 .. len(data_source) - 1 in ascending order."""
   return iter(range(len(self.data_source)))  # indices in their original order

batch_sampler:按sampler给出的顺序依次取出样本的idx,每batch_size个拼成一个batch,然后返回该batch的idx索引列表。

def __iter__(self) -> Iterator[List[int]]:
    """Group the indices produced by ``self.sampler`` into lists of ``batch_size``.

    With ``drop_last`` an incomplete final batch is discarded; otherwise it
    is yielded as a shorter list.
    """
    # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
    if self.drop_last:
        sampler_iter = iter(self.sampler)
        while True:
            try:
                # next() raises StopIteration part-way through an incomplete
                # batch, so that partial batch is never yielded.
                batch = [next(sampler_iter) for _ in range(self.batch_size)]
                yield batch
            except StopIteration:
                break
    else:
        batch = [0] * self.batch_size  # preallocate a list with batch_size placeholder slots
        idx_in_batch = 0
        for idx in self.sampler:  # take sample indices from the sampler
            batch[idx_in_batch] = idx  # store the index in the next free slot
            idx_in_batch += 1
            if idx_in_batch == self.batch_size:  # every slot has been filled
                yield batch  # emit the full batch
                idx_in_batch = 0  # reset the slot counter
                batch = [0] * self.batch_size  # start a fresh list for the next batch
        if idx_in_batch > 0:
            # Leftover indices form a final, shorter batch.
            yield batch[:idx_in_batch]

collate_fn

# Excerpt from DataLoader.__init__: pick a default collate_fn when the caller
# did not supply one.
if collate_fn is None:
    if self._auto_collation:  # assume this branch is taken
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert
#############################################################
@property
def _auto_collation(self):
    """Return True when a batch_sampler is set, i.e. whether batch_sampler is not None."""
    return self.batch_sampler is not None

default_collate

def default_collate(batch):
    """Collate ``batch`` by delegating to ``collate`` with ``default_collate_fn_map``."""
    return collate(batch, collate_fn_map=default_collate_fn_map)
    
def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Combine a list of samples into one batch, dispatching on the first element's type.

    Args:
        batch: Sequence of samples; ``batch[0]`` decides which branch is used.
        collate_fn_map: Optional mapping from a type (or tuple of types) to the
            function that collates elements of that type.

    Raises:
        RuntimeError: if the samples are sequences of differing lengths.
        TypeError: if no branch can handle the element type.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # An exact type match takes priority ...
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)

        # ... otherwise use the first registered type the element is an instance of.
        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, collections.abc.Mapping):
        # Collate key by key, trying to rebuild the original mapping type.
        try:
            return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        # Collate field by field and rebuild the same namedtuple type.
        return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
        else:
            try:
                return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                # Re-assemble the collated positions into a plain list and return it as the batch.
                return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

train_dataloader为何能成为迭代器?

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))  # why can it act as an iterator? It goes through the __iter__ logic shown below
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

dataloader里

def _get_iterator(self) -> '_BaseDataLoaderIter':
    """Create the iterator object that matches the configured number of workers."""
    if self.num_workers == 0:
        # No worker processes: use the single-process iterator.
        return _SingleProcessDataLoaderIter(self)
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
      
def __iter__(self) -> '_BaseDataLoaderIter':  # implementing this lets iter() be called on an instance to obtain an iterator
    """Return a data-loader iterator, reusing a cached one for persistent workers."""
    if self.persistent_workers and self.num_workers > 0:
        # Persistent workers: build the iterator once, then reset and reuse it
        # on later calls instead of creating a new one.
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()

欲知iter(train_dataloader)是干啥的?

# Excerpt repeated from DataLoader._get_iterator: with num_workers == 0 the
# single-process iterator defined below is returned.
if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """DataLoader iterator used when ``num_workers == 0``.

    It does not define ``__next__`` itself; it only supplies ``_next_data``,
    which the base class's ``__next__`` calls.
    """

    def __init__(self, loader):
        super().__init__(loader)
        # This iterator only supports the no-timeout, no-worker configuration.
        assert self._timeout == 0
        assert self._num_workers == 0

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Taking care of distributed sharding
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            # For BC, use default SHARDING_PRIORITIES
            torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
        # Create the fetcher, which is used to pull values out of the dataset.
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        """Fetch and return the data for the next index from the sampler iterator."""
        index = self._next_index()  # get the next index (or batch of indices)
        data = self._dataset_fetcher.fetch(index)  # fetch the data for that index
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data  # finally return the data
class _BaseDataLoaderIter:
    """Base class of the iterators returned by ``DataLoader.__iter__``.

    It copies the relevant settings from the loader, owns the iterator over
    the index sampler, and implements ``__next__`` on top of ``_next_data``,
    which subclasses must provide.
    """

    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._shared_seed = None
        self._pg = None
        if isinstance(self._dataset, IterDataPipe):
            # For IterDataPipe datasets, obtain a shared seed (using a gloo
            # process group when torch.distributed is initialised) and apply
            # it to the dataset.
            if dist.is_available() and dist.is_initialized():
                self._pg = dist.new_group(backend="gloo")
            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
        # Mirror the loader's configuration on the iterator.
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        ws, rank = _get_distributed_settings()
        self._world_size = ws
        self._rank = rank
        if (len(loader.pin_memory_device) == 0):
            # No explicit device given: pin only if requested and CUDA is available.
            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
            self._pin_memory_device = None
        else:
            if not loader.pin_memory:
                warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
                            "please set pin_memory to true, if you need to use the device pin memory")
                warnings.warn(warn_msg)

            self._pin_memory = loader.pin_memory
            self._pin_memory_device = loader.pin_memory_device
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        # Iterator over the index sampler; _next_index() draws from it.
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0  # number of items returned by __next__ so far
        self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        """Restart iteration: new sampler iterator, zeroed counter, refreshed seed."""
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        if isinstance(self._dataset, IterDataPipe):
            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError  # must be implemented by subclasses

    def __next__(self) -> Any:
        """Return the next item from ``_next_data`` and track how many were yielded."""
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                # NOTE(review): _reset is declared with a required `loader`
                # parameter but is called without it here (hence the
                # type-ignore) — confirm against upstream before relying on
                # this path.
                self._reset()  # type: ignore[call-arg]
            data = self._next_data()  # this is where the subclass hook is called to obtain the data
            self._num_yielded += 1
            # Warn when an iterable-style dataset yields more items than the
            # length it reported earlier.
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data  # finally return the data

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        # Pickling is refused by raising.
        # NOTE(review): the "{}" placeholder is never formatted — the class
        # name is passed as a second exception argument instead.
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

所以,iter(train_dataloader)是调用的__iter__得到_BaseDataLoaderIter实例,在该实例中,又调用了__next__方法,

里边又调用了_next_data()得到data。_SingleProcessDataLoaderIter内虽然没有实现__next__方法,但是其实现了_next_data();因为在基类中,我们是调用_next_data(self)再返回到__next__方法中,于是就可以通过next(iter(train_dataloader))得到一个minibatch的数据。

@property
def _index_sampler(self):
    """Return the sampler that actually produces indices for the iterator.

    With auto-collation (a batch_sampler is set) this is ``batch_sampler``;
    otherwise it is the per-sample ``sampler``.
    """
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

def __len__(self) -> int:
    """Return the length of the DataLoader.

    For iterable-style datasets the dataset length is recorded in
    ``_IterableDataset_len_called`` and, when ``batch_size`` is set, converted
    to a batch count; otherwise the length of the index sampler is returned.
    """
    if self._dataset_kind == _DatasetKind.Iterable:
        length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
        if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
            from math import ceil
            if self.drop_last:
                # An incomplete final batch is dropped: round down.
                length = length // self.batch_size
            else:
                # An incomplete final batch is kept: round up.
                length = ceil(length / self.batch_size)
        return length
    else:
        return len(self._index_sampler)

返回batch_sampler,dataloader的长度就是基于该采样器的长度计算出来的。next(iter(train_dataloader))中的next()能调用多少次?这是由dataloader的长度(即采样器的大小)决定的。


本文由 fmujie 创作,采用 知识共享署名 3.0,可自由转载、引用,但需署名作者且注明文章出处。

还不快抢沙发

添加新评论

召唤看板娘