Python 如何在 PyTorch 中显示单个图像?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/53623472/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
How do I display a single image in PyTorch?
提问by Tom Hale
I want to display a single image. It was loaded using a ImageLoader
and is stored in a PyTorch Tensor
.
我想显示单个图像。它使用 a 加载ImageLoader
并存储在 PyTorch 中Tensor
。
When I try to display it via plt.imshow(image)
, I get:
当我尝试通过 显示它时plt.imshow(image)
,我得到:
TypeError: Invalid dimensions for image data
The .shape
of the tensor is:
该.shape
张量是:
torch.Size([3, 244, 244])
How do I display the image contained in this PyTorch tensor?
如何显示此 PyTorch 张量中包含的图像?
回答by Tom Hale
Given a Tensor
representing the image, use .permute()
to put the channels as the last dimension:
给定Tensor
表示图像的 a,用于.permute()
将通道作为最后一个维度:
plt.imshow( tensor_image.permute(1, 2, 0) )
Note: permute
does not copy or allocate memory, and from_numpy()
doesn't either.
回答by trsvchn
As you can see matplotlib
works fine even without conversion to numpy
array. But PyTorch Tensors ("Image tensors") are channel first, so to use them with matplotlib
you need to reshape it:
如您所见,matplotlib
即使不转换为numpy
数组也能正常工作。但是 PyTorch 张量(“图像张量”)首先是通道,因此要使用它们,matplotlib
您需要对其进行重塑:
Code:
代码:
from scipy.misc import face
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)
# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)
# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Output:
输出:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
回答by Tom Hale
A complete example given an image pathname img_path
:
给定图像路径名的完整示例img_path
:
from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")
Note that transforms.*
return a function, which is why the funky bracketing.
请注意,transforms.*
返回一个函数,这就是时髦括号的原因。
回答by Tom Hale
Given the image is loaded as described and stored in the variable image
:
鉴于图像按描述加载并存储在变量中image
:
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
The matplotlib
image tutorialsays:
Bicubic interpolation is often used when blowing up photos - people tend to prefer blurry over pixelated.
放大照片时经常使用双三次插值 - 人们往往更喜欢模糊而不是像素化。
Or as Soumith suggested:
或者像Soumith 建议的那样:
%matplotlib inline
def show(img):
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
Or, to open the image in a popup window:
或者,在弹出窗口中打开图像:
transforms.ToPILImage()(image).show()