Python Pytorch Tensor 如何获取特定值的索引

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/47863001/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-19 18:25:01  来源:igfitidea点击:

How Pytorch Tensor get the index of specific value

pythonpytorch

提问by Han Bing

In python list, we can use list.index(somevalue). How can pytorch do this?
For example:

在 python 列表中,我们可以使用list.index(somevalue). pytorch 如何做到这一点?
例如:

    a=[1,2,3]
    print(a.index(2))

Then, 1will be output. How can a pytorch tensor do this without converting it to a python list?

然后,1将被输出。pytorch 张量如何在不将其转换为 python 列表的情况下执行此操作?

回答by Manuel Lagunas

I think there is no direct translation from list.index()to a pytorch function. However, you can achieve similar results using tensor==numberand then the nonzero()function. For example:

我认为没有直接转换list.index()为 pytorch 函数。但是,您可以使用tensor==numberand thennonzero()函数获得类似的结果。例如:

t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero())

This piece of code returns

这段代码返回

1

[torch.LongTensor of size 1x1]

1

[大小为1x1的torch.LongTensor]

回答by vlad

Can be done by converting to numpy as follows

可以通过转换为 numpy 来完成,如下所示

import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1.,  2.,  3.,  4.]) 
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2

回答by Giang Nguy?n

For floating point tensors, I use this to get the index of the element in the tensor.

对于浮点张量,我使用它来获取张量中元素的索引。

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

Here I want to get the index of max_value in the float tensor, you can also put your value like this to get the index of any elements in tensor.

这里我想获取浮点张量中max_value的索引,你也可以像这样放置你的值来获取张量中任何元素的索引。

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())

回答by Mohanraj

    import torch
    x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
    print(x_data.data[0])
    >>tensor([1.])