dataset
需要实现
def __getitem__(self, index) -> T_co:
    """Return the training pair ``(x, y)`` stored at ``index``.

    The base implementation only raises; every Dataset subclass has to
    override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
def __len__(self):
    """Return the total number of samples held by the dataset.

    This variant reads the size of dimension 0 of the first stored tensor.
    A dataset backed by an annotation table would instead return
    ``len(self.img_labels)``.
    """
    return self.tensors[0].size(0)
import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
# Reads training data from disk; __getitem__ returns the single sample that matches idx.
class CustomImageDataset(Dataset):
    """Map-style image dataset described by a CSV annotation file.

    Column 0 of the annotation table holds each image's file name (relative
    to ``img_dir``) and column 1 holds its label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Load the annotation table and remember where the images live.

        Args:
            annotations_file: path of the CSV file with file names and labels.
            img_dir: directory that contains the image files.
            transform: optional callable applied to every loaded image.
            target_transform: optional callable applied to every label.
        """
        # NOTE: read_csv is called with its defaults, so the first row of the
        # file is consumed as the header rather than as a sample.
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx`` of the annotations."""
        # The file name comes from row idx, column 0; join it onto the image directory.
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)  # load the image into memory
        label = self.img_labels.iloc[idx, 1]  # row idx, column 1 stores the label
        if self.transform:
            # Image preprocessing such as resizing, cropping or normalisation.
            image = self.transform(image)
        if self.target_transform:
            # Label preprocessing such as one-hot encoding.
            label = self.target_transform(label)
        return image, label
DataLoader
将单个样本组织成一个批次,用于神经网络的训练
Preparing your data for training with DataLoaders
The `Dataset` retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s `multiprocessing` to speed up data retrieval.
`DataLoader` is an iterable that abstracts this complexity for us in an easy API.
每次加载一个minibatch的训练样本;设置shuffle=True时,每个epoch都会重新打乱样本顺序
# Attribute declarations of DataLoader, annotated with what each option controls.
dataset: Dataset[T_co] # an instantiated Dataset object
batch_size: Optional[int] # defaults to 1
shuffle: Optional[bool] = None # whether to reshuffle the data at every epoch
sampler: Union[Sampler, Iterable, None] = None # strategy used to draw sample indices from the dataset
num_workers: int # defaults to 0, i.e. data is loaded in the main process
pin_memory: bool # when True, fetched tensors are copied into pinned (page-locked) memory before being returned
drop_last: bool # when True, the final smaller batch is dropped if the dataset size is not a multiple of batch_size
collate_fn: Optional[_collate_fn_t] = None # post-processes the samples drawn for one batch (batch in, batch out), e.g. padding when samples carry different numbers of labels
timeout: float
sampler: Union[Sampler, Iterable]
pin_memory_device: str
prefetch_factor: Optional[int]
_iterator : Optional['_BaseDataLoaderIter']
__initialized = False
# The constructor does three main things: it builds the per-sample sampler,
# builds the batch_sampler that groups sample indices into batches, and
# selects the collate function.
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
             shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
             batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
             num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
             pin_memory: bool = False, drop_last: bool = False,
             timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
             multiprocessing_context=None, generator=None,
             *, prefetch_factor: Optional[int] = None,
             persistent_workers: bool = False,
             pin_memory_device: str = ""):
    # Only the opening statement of the constructor body is quoted in this excerpt.
    torch._C._log_api_usage_once("python.data_loader")
from torch.utils.data import DataLoader
# Wrap each dataset in a DataLoader that yields batches of 64 samples.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True) # shuffling is usually unnecessary for the test set
Iterate through the DataLoader
We have loaded that dataset into the `DataLoader` and can iterate through the dataset as needed. Each iteration below returns a batch of `train_features` and `train_labels` (containing `batch_size=64` features and labels respectively). Because we specified `shuffle=True`, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).
# Display image and label.
# Pull the first batch from the loader, then show its first image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # Feature batch shape: torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {train_labels.size()}")  # Labels batch shape: torch.Size([64])
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")  # Label: 5
# Excerpt from the body of DataLoader.__init__: sampler and batch_sampler setup.

# Sample-level sampling: a custom sampler and shuffle cannot be combined.
if sampler is not None and shuffle:
    raise ValueError('sampler option is mutually exclusive with '
                     'shuffle')

# Batch-level sampling: when a batch_sampler is supplied, batch_size, shuffle,
# sampler and drop_last must all be left at their defaults.
if batch_sampler is not None:
    # auto_collation with custom batch_sampler
    if batch_size != 1 or shuffle or sampler is not None or drop_last:
        raise ValueError('batch_sampler option is mutually exclusive '
                         'with batch_size, shuffle, sampler, and '
                         'drop_last')

# Decide in which order samples are drawn from the dataset.
if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            # shuffle=True: use the built-in RandomSampler, which yields indices in random order.
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
        else:
            # Otherwise use SequentialSampler, which keeps the dataset's original order.
            sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

# Build the default batch_sampler when none was supplied.
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
RandomSampler
def __iter__(self) -> Iterator[int]:
    """Yield ``self.num_samples`` random indices into ``self.data_source``.

    With ``self.replacement`` set, indices are drawn independently from
    ``[0, n)``. Otherwise whole random permutations of ``range(n)`` are
    emitted, followed by a truncated permutation for the remainder.
    """
    n = len(self.data_source)  # size of the dataset
    if self.generator is None:
        # No generator was supplied: create one and seed it with a random seed.
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        generator = torch.Generator()
        generator.manual_seed(seed)
    else:
        generator = self.generator

    if self.replacement:
        # Indices are drawn in chunks of 32, then the leftover count.
        for _ in range(self.num_samples // 32):
            yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).numpy())
        final_samples = torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator)
        yield from map(int, final_samples.numpy())
    else:
        for _ in range(self.num_samples // n):
            yield from map(int, torch.randperm(n, generator=generator).numpy())
        # A random ordering of 0..n-1, e.g. [0, 1, 2] => [1, 0, 2], cut to the remainder.
        yield from map(int, torch.randperm(n, generator=generator)[:self.num_samples % n].numpy())
SequentialSampler
def __iter__(self) -> Iterator[int]:
    """Iterate over the indices ``0 .. len(self.data_source) - 1`` in order."""
    return iter(range(len(self.data_source)))
batch_sampler:把sampler逐个产生的样本索引按batch_size拼成一个批次,每次返回该批次的索引列表。
def __iter__(self) -> Iterator[List[int]]:
    """Group the indices produced by ``self.sampler`` into lists of ``self.batch_size``.

    With ``self.drop_last`` set, a trailing batch that cannot be filled is
    discarded. Otherwise it is yielded with however many indices remain.
    """
    # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
    if self.drop_last:
        sampler_iter = iter(self.sampler)
        while True:
            try:
                # next() raises StopIteration as soon as the sampler runs dry,
                # so a partially filled batch is never yielded.
                batch = [next(sampler_iter) for _ in range(self.batch_size)]
                yield batch
            except StopIteration:
                break
    else:
        batch = [0] * self.batch_size  # preallocated buffer for one batch
        idx_in_batch = 0
        for idx in self.sampler:  # take sample indices from the sampler one by one
            batch[idx_in_batch] = idx
            idx_in_batch += 1
            if idx_in_batch == self.batch_size:  # buffer is full
                yield batch
                idx_in_batch = 0  # restart filling from position 0
                batch = [0] * self.batch_size  # fresh buffer, the yielded list is not reused
        if idx_in_batch > 0:
            # Leftover indices form a final, shorter batch.
            yield batch[:idx_in_batch]
collate_fn
# Excerpt: choose a default collate function when the caller supplied none.
if collate_fn is None:
    if self._auto_collation:  # assume this branch is taken
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert
#############################################################
@property
def _auto_collation(self):
    """Return True exactly when ``self.batch_sampler`` is not None."""
    return self.batch_sampler is not None
default_collate
def default_collate(batch):
    """Collate ``batch`` via ``collate`` using the ``default_collate_fn_map`` registry."""
    return collate(batch, collate_fn_map=default_collate_fn_map)
def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Recursively merge a list of samples into one batched structure.

    The first element of ``batch`` decides how the batch is handled:

    * a type found in ``collate_fn_map`` (exact match first, then
      ``isinstance``) is delegated to the mapped function;
    * a mapping is collated key by key;
    * a namedtuple is collated field by field;
    * any other sequence is transposed and collated position by position.

    Args:
        batch: list of samples to merge.
        collate_fn_map: optional registry mapping a type (or tuple of types)
            to the function that collates elements of that type.

    Raises:
        RuntimeError: if the sequences in ``batch`` differ in length.
        TypeError: if the element type is handled by none of the cases above.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)

        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
        else:
            try:
                return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                # Fall back to a plain list of the re-grouped elements.
                return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
train_dataloader为何能成为迭代器?
# Display image and label.
train_features, train_labels = next(iter(train_dataloader)) # why is this iterable? iter() dispatches to the DataLoader.__iter__ logic shown below
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
dataloader里
def _get_iterator(self) -> '_BaseDataLoaderIter':
    """Create a fresh iterator that matches the configured number of workers.

    ``num_workers == 0`` gives a single-process iterator. Any other value
    first checks the worker count and then gives a multi-process iterator.
    """
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
def __iter__(self) -> '_BaseDataLoaderIter':
    """Return the iterator that ``iter(dataloader)`` hands out.

    Defining this method is what makes a DataLoader instance usable with
    the built-in ``iter()``. With persistent workers (and at least one
    worker) a single iterator is created once, cached on ``self._iterator``
    and reset on every later call. Otherwise each call builds a new one.
    """
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()
欲知 iter(train_dataloader) 是干啥的?
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """DataLoader iterator that fetches every batch in the calling process."""

    def __init__(self, loader):
        """Copy the loader settings via the base class and build the dataset fetcher."""
        super().__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Taking care of distributed sharding
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            # For BC, use default SHARDING_PRIORITIES
            torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)

        # Create the fetcher that is later used to pull data out of the dataset.
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        """Fetch and return the data for the next index produced by the sampler."""
        index = self._next_index()  # next index (or list of indices) from the sampler
        data = self._dataset_fetcher.fetch(index)  # fetch the data for that index
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data
class _BaseDataLoaderIter:
    """Base class of the iterators returned by ``iter(DataLoader)``.

    It snapshots the loader's configuration, owns the iterator over the
    index sampler, and implements ``__next__`` on top of the
    ``_next_data`` hook that subclasses must provide.
    """

    def __init__(self, loader: DataLoader) -> None:
        """Copy the settings this iterator needs from ``loader``."""
        self._dataset = loader.dataset
        self._shared_seed = None
        self._pg = None
        if isinstance(self._dataset, IterDataPipe):
            if dist.is_available() and dist.is_initialized():
                self._pg = dist.new_group(backend="gloo")
            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        ws, rank = _get_distributed_settings()
        self._world_size = ws
        self._rank = rank
        if (len(loader.pin_memory_device) == 0):
            # No explicit device: pinning additionally requires CUDA to be available.
            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
            self._pin_memory_device = None
        else:
            if not loader.pin_memory:
                warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used, "
                            "please set pin_memory to true, if you need to use the device pin memory")
                warnings.warn(warn_msg)

            self._pin_memory = loader.pin_memory
            self._pin_memory_device = loader.pin_memory_device
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"

    def __iter__(self) -> '_BaseDataLoaderIter':
        """Return ``self``; the object is its own iterator."""
        return self

    def _reset(self, loader, first_iter=False):
        """Restart iteration: new sampler iterator, yielded-count back to zero."""
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        if isinstance(self._dataset, IterDataPipe):
            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)

    def _next_index(self):
        """Return the next item of the index sampler."""
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        """Produce the next batch."""
        raise NotImplementedError  # must be implemented by subclasses

    def __next__(self) -> Any:
        """Return the next batch obtained from ``_next_data``.

        Emits a warning when an iterable-style dataset yields more batches
        than the length that was previously reported for it.
        """
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                # NOTE(review): _reset() as declared in this class requires a
                # `loader` argument, so this call would fail if reached.
                self._reset()  # type: ignore[call-arg]
            data = self._next_data()  # the subclass hook supplies the data
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data  # finally hand the data back to the caller

    def __len__(self) -> int:
        """Return the length of the index sampler."""
        return len(self._index_sampler)

    def __getstate__(self):
        """Refuse pickling.

        Raises:
            NotImplementedError: always, naming the concrete iterator class.
        """
        # The message is formatted here; previously the template and the class
        # name were passed as two separate exception arguments.
        raise NotImplementedError(f"{self.__class__.__name__} cannot be pickled")
所以,iter(train_dataloader) 调用的是 DataLoader 的 __iter__,得到一个 _BaseDataLoaderIter 的子类实例(num_workers 为 0 时是 _SingleProcessDataLoaderIter)。对该实例调用 next() 时,执行的是基类的 __next__ 方法,__next__ 里又调用 _next_data() 得到 data。_SingleProcessDataLoaderIter 内虽然没有实现 __next__ 方法,但是它实现了 _next_data();基类的 __next__ 调用 self._next_data() 之后再把结果返回,于是 next(iter(train_dataloader)) 就可以得到一个 minibatch 的数据。
@property
def _index_sampler(self):
    """Return the sampler that drives iteration.

    That is ``self.batch_sampler`` when ``self._auto_collation`` is true,
    and ``self.sampler`` otherwise.
    """
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler
def __len__(self) -> int:
    """Return how many items iterating the loader is expected to produce.

    For an iterable-style dataset the dataset length is recorded on
    ``self._IterableDataset_len_called`` and, when ``batch_size`` is set,
    converted to a batch count (floored with ``drop_last``, otherwise
    rounded up). For every other dataset kind it is the length of the
    index sampler.
    """
    if self._dataset_kind == _DatasetKind.Iterable:
        length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
        if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
            from math import ceil
            if self.drop_last:
                length = length // self.batch_size
            else:
                length = ceil(length / self.batch_size)
        return length
    else:
        return len(self._index_sampler)
自动批处理时 _index_sampler 返回的是 batch_sampler,dataloader 的长度就是该采样器的长度。next(iter(train_dataloader)) 的 next() 能调用多少次?由 dataloader 的长度决定,也就是 batch_sampler 能产生的批次数。
还不快抢沙发