Pytorch数据读取DataLoader和Dataset
1. DataLoader
DataLoader包括Sampler(用于生成索引)和Dataset(根据索引读取图片,标签)
DataLoader用于构建可迭代的数据装载器
1 | torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None) |
2. Dataset
Dataset是抽象类,所有自定义的Dataset均需要继承该类,并且重写__getitem()方法。\getitem()__方法的作用是接收一个索引,返回索引对应的样本和标签,这是我们自己需要实现的逻辑。
3. 读取数据过程
首先在for循环中使用DataLoader,根据是否需要多进程读取数据选择DataLoaderIter,然后通过Sampler读取index列表,在DatasetFetcher中调用Dataset的getitem方法,根据索引index从硬盘读取(图片img,标签label)列表,再通过collate_fn将列表分为batch,batch由图片img列表和标签label列表组成。
详细过程如下描述:
首先进入for循环
判断单进程还是多进程,程序中使用的单进程,因此进入_SingleProcessDataLoaderIter
初始化_SingleProcessDataLoaderIter类之后,跳转到了DataLoader的__next__()方法。
跳转到_SingleProcessDataLoaderIter类的_next_data()方法,其中_next_index()方法会调用Sampler类的__iter__()方法,用来读取index列表。
将index列表输入到DatasetFetcher类中,通过dataset[idx]获取图片和标签,dataset[idx]会调用Dataset的__getitem__()方法。我们在RMBDataset中已经写好了从硬盘读取图片和标签的代码。data是由一个batch的(图片,标签)列表组成。
data在经过collate_fn处理后,转变为两个列表,分别是图片列表和标签列表