Python TensorFlow:使用张量索引另一个张量
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/35842598/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
TensorFlow: using a tensor to index another tensor
提问by user200340
I have a basic question about how to do indexing in TensorFlow.
我有一个关于如何在 TensorFlow 中建立索引的基本问题。
In numpy:
在 numpy 中:
x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
#numpy
print x * e[x]
I can get
我可以得到
[1 0 3 3 0 5 0 7 1 3]
How can I do this in TensorFlow?
我怎样才能在 TensorFlow 中做到这一点?
x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
x_t = tf.constant(x)
e_t = tf.constant(e)
with tf.Session():
????
Thanks!
谢谢!
回答by mrry
Fortunately, the exact case you're asking about is supported in TensorFlow by tf.gather()
:
幸运的是,TensorFlow 通过tf.gather()
以下方式支持您所询问的确切情况:
result = x_t * tf.gather(e_t, x_t)
with tf.Session() as sess:
print sess.run(result) # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])'
The tf.gather()
op is less powerful than NumPy's advanced indexing: it only supports extracting full slices of a tensor on its 0th dimension. Support for more general indexing has been requested, and is being tracked in this GitHub issue.
该tf.gather()
运算是小于强大NumPy的先进索引:它仅支持其零维提取张量的全片。已请求支持更一般的索引,并在此 GitHub 问题 中进行跟踪。