Python 在 Tensorflow 中,如何将 tf.gather() 用于最后一个维度?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/36764791/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
In Tensorflow, how to use tf.gather() for the last dimension?
提问by YW P Kwon
I am trying to gather slices of a tensor in terms of the last dimension for partial connection between layers. Because the output tensor's shape is [batch_size, h, w, depth]
, I want to select slices based on the last dimension, such as
我试图根据层之间的部分连接的最后一个维度来收集张量的切片。因为输出张量的形状是[batch_size, h, w, depth]
,所以我想根据最后一个维度选择切片,例如
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
However, tf.gather(L, [0, 2,3,8])
seems to only work for the first dimension (right?) Can anyone tell me how to do it?
然而,tf.gather(L, [0, 2,3,8])
似乎只适用于第一维(对吧?)谁能告诉我怎么做?
采纳答案by Yaroslav Bulatov
There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206
这里有一个跟踪错误来支持这个用例:https: //github.com/tensorflow/tensorflow/issues/206
For now you can:
现在你可以:
transpose your matrix so that dimension to gather is first (transpose is expensive)
reshape your tensor into 1d (reshape is cheap) and turn your gather column indices into a list of individual element indices at linear indexing, then reshape back
- use
gather_nd
. Will still need to turn your column indices into list of individual element indices.
转置您的矩阵,以便首先收集维度(转置很昂贵)
将您的张量重塑为 1d(重塑很便宜)并将您的聚集列索引转换为线性索引处的单个元素索引列表,然后重新整形
- 使用
gather_nd
. 仍然需要将您的列索引转换为单个元素索引的列表。
回答by rryan
As of TensorFlow 1.3 tf.gather
has an axis
parameter, so the various workarounds here are no longer necessary.
由于 TensorFlow 1.3tf.gather
有一个axis
参数,因此不再需要这里的各种解决方法。
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gatherhttps://github.com/tensorflow/tensorflow/issues/11223
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
回答by Andrei Pokrovsky
With gather_nd you can now do this as follows:
使用gather_nd,您现在可以按如下方式执行此操作:
cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)
Also, as reported by user Nova in a thread referenced by @Yaroslav Bulatov's:
此外,正如用户 Nova 在@Yaroslav Bulatov 引用的线程中所报告的那样:
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]
The gist is flatten the tensor and use strided 1D addressing with tf.gather(...).
要点是展平张量并使用带 tf.gather(...) 的跨步一维寻址。
回答by Yunseong Hwang
Yet another solution using tf.unstack(...), tf.gather(...) and tf.stack(..)
使用 tf.unstack(...)、tf.gather(...) 和 tf.stack(..) 的另一种解决方案
Code:
代码:
import tensorflow as tf
import numpy as np
shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)
indices = [0, 2, 3, 8]
axis = -1 # last dimension
def gather_axis(params, indices, axis=0):
return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
print(L)
with tf.Session() as sess:
partL = sess.run(gather_axis(L, indices, axis))
print(partL)
Result:
结果:
L =
[[[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]]
[[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]]]
[[[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]]
[[60 61 62 63 64 65 66 67 68 69]
[70 71 72 73 74 75 76 77 78 79]]]]
partL =
[[[[ 0 2 3 8]
[10 12 13 18]]
[[20 22 23 28]
[30 32 33 38]]]
[[[40 42 43 48]
[50 52 53 58]]
[[60 62 63 68]
[70 72 73 78]]]]
回答by Edward Hughes
A correct version of @Andrei's answer would read
@Andrei 答案的正确版本应该是
cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
回答by Lerner Zhang
You can try this way, for instance(in most cases in NLP at the least),
例如,您可以尝试这种方式(至少在 NLP 的大多数情况下),
The parameter is of shape [batch_size, depth]
and the indices are [i, j, k, n, m] of which the length is batch_size. Then gather_nd
can be helpful.
参数为形状[batch_size, depth]
,索引为 [i, j, k, n, m],长度为 batch_size。然后gather_nd
可以有帮助。
parameters = tf.constant([
[11, 12, 13],
[21, 22, 23],
[31, 32, 33],
[41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number
items = tf.gather_nd(parameters, indices)
# which is what we want: [13, 22, 31, 42]
This snippet first find the fist dimension through the batch_num and then fetch the item along that dimension by the target number.
此代码段首先通过 batch_num 找到第一个维度,然后通过目标编号沿该维度获取项目。
回答by Adwin Jahn
Tensor doesn't have attribute shape, but get_shape() method. Below is runnable by Python 2.7
Tensor 没有属性 shape,而是 get_shape() 方法。下面是可由 Python 2.7 运行的
import tensorflow as tf
import numpy as np
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]
回答by Sven Dorkenwald
Implementing 2. from @Yaroslav Bulatov's:
实施 2. 来自@Yaroslav Bulatov 的:
#Your indices
indices = [0, 2, 3, 8]
#Remember for final reshaping
n_indices = tf.shape(indices)[0]
flattened_L = tf.reshape(L, [-1])
#Walk strided over the flattened array
offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1)
flattened_indices = tf.reshape(tf.reshape(indices, [-1])+offset, [-1])
selected_rows = tf.gather(flattened_L, flattened_indices)
#Final reshape
partL = tf.reshape(selected_rows, tf.concat(0, [tf.shape(L)[:-1], [n_indices]]))
Credit to How to select rows from a 3-D Tensor in TensorFlow?