Python 按列解压 NumPy 数组

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/27046533/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-19 01:22:05  来源:igfitidea点击:

Unpack NumPy array by column

pythonarraysnumpyargument-unpacking

提问by jeff_new

If I have a NumPy array, for example 5x3, is there a way to unpack it column by column all at once to pass to a function rather than like this: my_func(arr[:, 0], arr[:, 1], arr[:, 2])?

如果我有一个 NumPy 数组,例如 5x3,有没有办法一次一列一列地解压它以传递给一个函数,而不是像这样:my_func(arr[:, 0], arr[:, 1], arr[:, 2])

Kind of like *argsfor list unpacking but by column.

有点像*args列表解包,但按列。

采纳答案by Alex Riley

You can unpack the transpose of the array in order to use the columns for your function arguments:

您可以解压缩数组的转置,以便将列用于函数参数:

my_func(*arr.T)

Here's a simple example:

这是一个简单的例子:

>>> x = np.arange(15).reshape(5, 3)
array([[ 0,  5, 10],
       [ 1,  6, 11],
       [ 2,  7, 12],
       [ 3,  8, 13],
       [ 4,  9, 14]])

Let's write a function to add the columns together (normally done with x.sum(axis=1)in NumPy):

让我们编写一个函数来将列相加(通常x.sum(axis=1)在 NumPy 中完成):

def add_cols(a, b, c):
    return a+b+c

Then we have:

然后我们有:

>>> add_cols(*x.T)
array([15, 18, 21, 24, 27])

NumPy arrays will be unpacked along the first dimension, hence the need to transpose the array.

NumPy 数组将沿第一维解包,因此需要转置数组。

回答by Stephanie

numpy.splitsplits an array into multiple sub-arrays. In your case, indices_or_sectionsis 3 since you have 3 columns, and axis = 1since we're splitting by column.

numpy.split将数组拆分为多个子数组。在您的情况下,indices_or_sections是 3,因为您有 3 列,而且axis = 1我们按列拆分。

my_func(numpy.split(array, 3, 1))

回答by CookieMaster

I guess numpy.splitwill not suffice in the future. Instead, it should be

我想numpy.split将来是不够的。相反,它应该是

my_func(tuple(numpy.split(array, 3, 1)))

Currently, python prints the following warning:

目前,python 打印以下警告:

FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use arr[tuple(seq)]instead of arr[seq]. In the future this will be interpreted as an array index, arr[np.array(seq)], which will result either in an error or a different result.

FutureWarning:不推荐使用非元组序列进行多维索引;使用arr[tuple(seq)]代替arr[seq]. 将来,这将被解释为数组索引, arr[np.array(seq)],这将导致错误或不同的结果。