将 PyTorch 张量转换为 python 列表
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/53903373/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
Convert PyTorch tensor to python list
提问by Tom Hale
How do I convert a PyTorch Tensor
into a python list?
如何将 PyTorchTensor
转换为 Python 列表?
My current use case is to convert a tensor of size [1, 2048, 1, 1]
into a list of 2048 elements.
我当前的用例是将大小的张量转换[1, 2048, 1, 1]
为 2048 个元素的列表。
My tensor has floating point values. Is there a solution which also accounts for int and possibly other data types?
我的张量有浮点值。是否有解决方案也考虑 int 和可能的其他数据类型?
回答by Tom Hale
I found Tensor.tolist()
which gives the following usage example:
我发现Tensor.tolist()
它给出了以下用法示例:
>>> import torch
>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
[-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803
So, to answer the question, use a.squeeze().tolist()
to remove all dimensions of size 1
.
因此,要回答这个问题,请使用a.squeeze().tolist()
删除 size 的所有维度1
。
Also consider .flatten()
if a list of lists is not desired.
还要考虑.flatten()
是否不需要列表列表。
Before I came across .tolist()
, I was using:
在我遇到之前.tolist()
,我正在使用:
list = [element.item() for element in tensor.flatten()]
This flattens the tensor into a single dimension then calls .item()
to convert each element into a Python number.
这会将张量展平为单个维度,然后调用.item()
将每个元素转换为 Python 数字。