How do you load images into a PyTorch DataLoader?
Note: this page is a translation of a popular Stack Overflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must follow the same license, cite the original URL, and attribute it to the original authors (not me): Stack Overflow
Original: http://stackoverflow.com/questions/50052295/
Asked by Terry
The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loading of images?
Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
My data:
I have the MNIST dataset as image files in the following folder structure. (I know I can just use the dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)
The folder name is the label and the images are 28x28 greyscale PNGs; no transformations are required.
data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...
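For a layout like the one above, where each subfolder name is the label, the core of any loader is a scan that pairs every image path with a class index. A minimal sketch of that scan follows, assuming class indices are assigned by sorting the folder names as strings (which is how torchvision's ImageFolder assigns them). The function name `list_samples` and the demo file names are made up for illustration.

```python
import os
import tempfile


def list_samples(root):
    """Return (samples, class_to_idx) for a folder-per-label image tree."""
    # Each immediate subfolder of root is one class; sort names for a stable order.
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}

    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(".png"):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples, class_to_idx


# Demo on a throwaway tree with empty placeholder files (hypothetical names).
demo_root = tempfile.mkdtemp()
for label, fname in [("0", "3.png"), ("0", "5.png"), ("1", "10.png"), ("2", "4.png")]:
    os.makedirs(os.path.join(demo_root, label), exist_ok=True)
    open(os.path.join(demo_root, label, fname), "w").close()

samples, class_to_idx = list_samples(demo_root)
print(class_to_idx)
for path, label in samples:
    print(os.path.relpath(path, demo_root), label)
```

With folders named 0 through 9 the index equals the digit. With names such as 10 or 11, string sorting would place 10 before 2, so the index would no longer match the folder name.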
Answered by Duane
Here's what I did for PyTorch 0.4.1 (should still work in 1.3):
import torch
import torchvision


def load_dataset():
    data_path = 'data/train/'
    # ImageFolder treats each subfolder of data_path as one class
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader


for batch_idx, (data, target) in enumerate(load_dataset()):
    # train network
    pass
Answered by Ari K
If you're using MNIST, there's already a preset in PyTorch via torchvision. You could do:
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd

# MNIST images have a single channel, so Normalize takes one mean and one std
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                               shuffle=True, num_workers=2)
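The Normalize transform in the snippet above computes (x - mean) / std for each channel, after ToTensor has scaled 8-bit pixel values from [0, 255] to [0.0, 1.0]. MNIST images have a single channel, so one mean and one std value is enough. The plain-Python sketch below only illustrates the arithmetic; the helper name `normalize` is made up and is not a torchvision function.

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel normalization: (value - mean) / std."""
    return (value - mean) / std


# ToTensor scales 8-bit pixel values from [0, 255] to [0.0, 1.0] first.
for raw in (0, 128, 255):
    scaled = raw / 255
    print(raw, round(scaled, 3), round(normalize(scaled), 3))
```

With mean 0.5 and std 0.5, the [0, 1] range is mapped to [-1, 1].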
If you want to generalize to a directory of images (same imports as above, plus os and PIL's Image), you could do:
import os
from PIL import Image

# transformMnistm is not defined in the original answer; this definition is an
# assumption (MNIST-M images are RGB, so three means and three stds)
transformMnistm = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

class mnistmTrainingDataset(torch.utils.data.Dataset):
    def __init__(self, text_file, root_dir, transform=transformMnistm):
        """
        Args:
            text_file(string): path to text file
            root_dir(string): directory with all train images
        """
        # column 0 holds the file names, column 1 the labels;
        # header=None assumes the text file has no header row
        self.name_frame = pd.read_csv(text_file, sep=" ", header=None,
                                      usecols=range(1))
        self.label_frame = pd.read_csv(text_file, sep=" ", header=None,
                                       usecols=range(1, 2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        #labels = labels.reshape(-1, 2)
        sample = {'image': image, 'labels': labels}
        return sample

mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir='Downloads/mnist_m/mnist_m_train')
mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16,
                                                shuffle=True, num_workers=2)
You can then iterate over it like:
for i_batch, sample_batched in enumerate(mnistmTrainLoader, 0):
    print("training sample for mnist-m")
    print(i_batch, sample_batched['image'], sample_batched['labels'])
There are several ways to generalize image dataset loading in PyTorch; the method that I know of is subclassing torch.utils.data.Dataset.