Python: How do you load images into a PyTorch DataLoader?

Notice: this page is a Chinese-English translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must do so under the same license, cite the original URL and authors, and attribute it to the original authors (not me). StackOverflow original: http://stackoverflow.com/questions/50052295/

Date: 2020-08-19 19:22:28  Source: igfitidea

How do you load images into a PyTorch DataLoader?

python pytorch

Asked by Terry

The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple way of loading images?

Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

My Data:

I have the MNIST dataset as PNGs in the following folder structure. (I know I can just use the built-in dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)

The folder name is the label, and the images are 28x28 greyscale PNGs. No transformations are required.

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

Answered by Duane

Here's what I did for PyTorch 0.4.1 (it should still work in 1.3):

import torch
import torchvision


def load_dataset():
    data_path = 'data/train/'
    # ImageFolder treats each subfolder name as a class label.
    # Its default loader converts every image to RGB, so tensors are 3 x H x W.
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader


for batch_idx, (data, target) in enumerate(load_dataset()):
    pass  # train network
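
To make the "folder name is the label" convention concrete, here is a rough stdlib-only sketch of how a root/label/image layout can be turned into (path, label) pairs. ImageFolder assigns class indices from the sorted subfolder names in a similar way, but `scan_image_folder` and its `extensions` parameter are hypothetical names for this illustration, not part of torchvision.

```python
from pathlib import Path


def scan_image_folder(root, extensions=(".png", ".jpg", ".jpeg")):
    """Return (samples, class_to_idx) for a root/<label>/<image> layout.

    Hypothetical helper for illustration only; not a torchvision function.
    """
    root = Path(root)
    # Subfolder names, sorted as strings, define the class order.
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    # Collect every image file together with the index of its folder.
    samples = [
        (str(path), class_to_idx[name])
        for name in classes
        for path in sorted((root / name).iterdir())
        if path.suffix.lower() in extensions
    ]
    return samples, class_to_idx
```

Because the folder names are sorted as strings, a folder called 10 would come before a folder called 2. With the single-digit folders 0 to 9 from the question, the class index matches the digit.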

Answered by Ari K

If you're using MNIST, torchvision already ships it as a built-in dataset. You could do:

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd

# MNIST images have a single channel, so Normalize takes one mean and one std.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                      shuffle=True, num_workers=2)

If you want to generalize to a directory of images (same imports as above), you could do

# Needed in addition to the imports above.
import os
from PIL import Image

# Assumed default transform: convert PIL images to tensors so the DataLoader can batch them.
transformMnistm = transforms.ToTensor()


class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self, text_file, root_dir, transform=transformMnistm):
        """
        Args:
            text_file(string): path to text file with one "<filename> <label>" pair per line
            root_dir(string): directory with all train images
        """
        # header=None assumes the label file has no header row.
        self.name_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=range(1))
        self.label_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=range(1, 2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        # Each sample is a dict, so batches are dicts with the same keys.
        sample = {'image': image, 'labels': labels}

        return sample


mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir='Downloads/mnist_m/mnist_m_train')

mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16, shuffle=True, num_workers=2)

You can then iterate over it like:

for i_batch, sample_batched in enumerate(mnistmTrainLoader, 0):
    print("training sample for mnist-m")
    print(i_batch, sample_batched['image'], sample_batched['labels'])

There are a bunch of ways to load image datasets in PyTorch; the one I know of is subclassing torch.utils.data.Dataset.