Python TensorFlow:使用张量索引另一个张量

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/35842598/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-19 17:03:10  来源:igfitidea点击:

TensorFlow: using a tensor to index another tensor

pythonnumpytensorflow

提问by user200340

I have a basic question about how to do indexing in TensorFlow.

我有一个关于如何在 TensorFlow 中建立索引的基本问题。

In numpy:

在 numpy 中:

x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
#numpy 
print x * e[x]

I can get

我可以得到

[1 0 3 3 0 5 0 7 1 3]

How can I do this in TensorFlow?

我怎样才能在 TensorFlow 中做到这一点?

x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
x_t = tf.constant(x)
e_t = tf.constant(e)
with tf.Session():
    ????

Thanks!

谢谢!

回答by mrry

Fortunately, the exact case you're asking about is supported in TensorFlow by tf.gather():

幸运的是,TensorFlow 通过tf.gather()以下方式支持您所询问的确切情况:

result = x_t * tf.gather(e_t, x_t)

with tf.Session() as sess:
    print sess.run(result)  # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])'

The tf.gather()op is less powerful than NumPy's advanced indexing: it only supports extracting full slices of a tensor on its 0th dimension. Support for more general indexing has been requested, and is being tracked in this GitHub issue.

tf.gather()运算是小于强大NumPy的先进索引:它仅支持其零维提取张量的全片。已请求支持更一般的索引,并在此 GitHub 问题 中进行跟踪。