pytorch训练时使用多个dataloader

在训练神经网络时我们可能会同时使用多个dataloader,则需要在原来的enumerate(dataloader)上加入zip函数:

for i, data in enumerate(zip(dataloader1, dataloader2)):
    pass

此时,data是一个(2,2)的元组,第一行是dataloader1的data和label,第二行是dataloader2的data和label。

另外,dataloader1和dataloader2的大小很有可能不一样,即len(dataloader1) != len(dataloader2),则它会以数量最少的那个dataloader为标准停止,例如len(dataloader1)=85len(dataloader2)=80,则最终的i就是80。并且dataloader2的最后一个batch的大小可能不够一个batch size。

这时,我们可以调用:

from itertools import cycle
for i, data in enumerate(zip(dataloader1, cycle(dataloader2))):
    pass

dataloader2就又会从头开始循环,直到将dataloader1也循环完。
但是你有没有发现,这是另一个死循环啊~!dataloader1的最后一个batch的数据数量也不一定等于batch size。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容