Python: find the indices of the k smallest values of a numpy array

Notice: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must follow the same CC BY-SA license, cite the original URL and author information, and attribute it to the original authors (not me). Original StackOverflow URL: http://stackoverflow.com/questions/34226400/

Time: 2020-08-19 14:39:00  Source: igfitidea

Find the index of the k smallest values of a numpy array

python numpy

Asked by Basj

In order to find the index of the smallest value, I can use argmin:

import numpy as np
A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
print(A.argmin())     # 4 because A[4] = 0.1

But how can I find the k smallest values?

I'm looking for something like:

print(A.argmin(numberofvalues=3))   # wished-for syntax; argmin has no such argument
# [4, 0, 7]  because A[4] <= A[0] <= A[7] <= all other A[i]

Note: in my use case A has between roughly 10 000 and 100 000 values, and I'm interested only in the indices of the k=10 smallest values. k will never be > 10.

Accepted answer by unutbu

Use np.argpartition. It does not sort the entire array. It only guarantees that the kth element is in its sorted position and that all smaller elements are moved before it. Thus the first k elements will be the k smallest elements.

import numpy as np

A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
k = 3

idx = np.argpartition(A, k)
print(idx)
# [4 0 7 3 1 2 6 5]

Indexing A with the first k entries of idx returns the k smallest values. Note that these may not be in sorted order.

print(A[idx[:k]])
# [ 0.1  1.   1.5]
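
Because the first k indices are not guaranteed to be ordered, one possible follow-up is to sort only those k entries after partitioning. The sketch below is an illustration, not part of the original answer; the helper name `k_smallest_indices` is hypothetical, and it assumes a 1-D input.

```python
import numpy as np

def k_smallest_indices(a, k):
    """Hypothetical helper: indices of the k smallest values of a 1-D array, ordered by value."""
    a = np.asarray(a)
    if k >= a.size:
        # argpartition needs kth < a.size, so fall back to a full sort
        return np.argsort(a)
    part = np.argpartition(a, k)[:k]    # k smallest, in arbitrary order
    return part[np.argsort(a[part])]    # sort only those k entries

A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
result = k_smallest_indices(A, 3)
print(result.tolist())  # [4, 0, 7]
```

Sorting k entries costs very little when k is small (k <= 10 in the question), so the overall cost stays dominated by the partition step.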


To obtain the k-largest values use

idx = np.argpartition(A, -k)
# [4 0 7 3 1 2 6 5]

A[idx[-k:]]
# [  9.  17.  17.]

WARNING: Do not (re)use idx = np.argpartition(A, k); A[idx[-k:]] to obtain the k largest. That won't always work. For example, these are NOT the 3 largest values in x:

x = np.array([100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0])
idx = np.argpartition(x, 3)
x[idx[-3:]]
# array([ 70,  80, 100])
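
For contrast, here is a small sketch of the approach the answer recommends, applied to the same x: partition with -k and take the last k entries. The final sort is an addition for readability only, since the partition itself does not order the selected values.

```python
import numpy as np

x = np.array([100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0])
k = 3

# Partition around position -k: the last k entries are then the k largest
largest_idx = np.argpartition(x, -k)[-k:]

# Sort just those k values in descending order (optional, for display)
largest = np.sort(x[largest_idx])[::-1]
print(largest.tolist())  # [100, 90, 80]
```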


Here is a comparison against np.argsort, which also works but just sorts the entire array to get the result.

In [2]: x = np.random.randn(100000)

In [3]: %timeit idx0 = np.argsort(x)[:100]
100 loops, best of 3: 8.26 ms per loop

In [4]: %timeit idx1 = np.argpartition(x, 100)[:100]
1000 loops, best of 3: 721 μs per loop

In [5]: np.all(np.sort(np.argsort(x)[:100]) == np.sort(np.argpartition(x, 100)[:100]))
Out[5]: True
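
Outside IPython, a similar comparison can be sketched with the standard-library `timeit` module. The seed, array size, and repetition count below are arbitrary assumptions, and the timings will vary by machine and NumPy version, so no expected numbers are shown.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)        # fixed seed, chosen arbitrarily
x = rng.standard_normal(100_000)
k = 100
runs = 20

# Time a full sort versus a partition, both followed by the same slice
t_sort = timeit.timeit(lambda: np.argsort(x)[:k], number=runs)
t_part = timeit.timeit(lambda: np.argpartition(x, k)[:k], number=runs)
print(f"argsort:      {t_sort / runs * 1e3:.3f} ms per call")
print(f"argpartition: {t_part / runs * 1e3:.3f} ms per call")

# Both approaches should select the same set of indices (order may differ)
same = set(np.argsort(x)[:k].tolist()) == set(np.argpartition(x, k)[:k].tolist())
print(same)
```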

Answer by Cory Kramer

You can use numpy.argsort with slicing

>>> import numpy as np
>>> A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
>>> np.argsort(A)[:3]
array([4, 0, 7], dtype=int32)

Answer by Marcelo Villa-Piñeros

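numpy.partition(your_array, k) is an alternative when you need the values rather than the indices. It returns a copy in which the kth element is in its sorted position and all smaller elements come before it; the first k elements are not themselves guaranteed to be sorted.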
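
As a brief illustration using the array from the question, np.partition returns values, not indices. The final np.sort is an extra step added here so the printed result is ordered.

```python
import numpy as np

A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
k = 3

part = np.partition(A, k)      # partitioned copy of the values
smallest = np.sort(part[:k])   # the k smallest values, now ordered
print(smallest.tolist())       # [0.1, 1.0, 1.5]
```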

Answer by Jeremiah England

For n-dimensional arrays, this function works well. The indices are returned as a tuple of arrays that can be used directly to index the original array. If you want a list of the indices to be returned, then you need to transpose the array before you make a list.

To retrieve the k largest, simply pass in -k.

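import numpy as np

def get_indices_of_k_smallest(arr, k):
    # Partition the flattened array; a negative k selects the |k| largest instead
    idx = np.argpartition(arr.ravel(), k)
    # Map flat indices back to n-d indices and keep the first k (or last |k|) columns
    return tuple(np.array(np.unravel_index(idx, arr.shape))[:, range(min(k, 0), max(k, 0))])
    # if you want it in a list of indices . . .
    # return np.array(np.unravel_index(idx, arr.shape))[:, range(k)].transpose().tolist()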

Example:

r = np.random.RandomState(1234)
arr = r.randint(1, 1000, 2 * 4 * 6).reshape(2, 4, 6)

indices = get_indices_of_k_smallest(arr, 4)
indices
# (array([1, 0, 0, 1], dtype=int64),
#  array([3, 2, 0, 1], dtype=int64),
#  array([3, 0, 3, 3], dtype=int64))

arr[indices]
# array([ 4, 31, 54, 77])

%%timeit
get_indices_of_k_smallest(arr, 4)
# 17.1 μs ± 651 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
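
The negative-k case and the list-of-indices form can be checked on a small array whose answer is easy to verify by eye. This is a simplified sketch written for illustration; the helper name `k_extreme_indices` is hypothetical and is not the function from the answer above.

```python
import numpy as np

def k_extreme_indices(arr, k):
    """Hypothetical helper: k > 0 gives the k smallest, k < 0 the |k| largest."""
    flat_idx = np.argpartition(arr.ravel(), k)
    chosen = flat_idx[:k] if k > 0 else flat_idx[k:]
    # Convert flat positions back to a tuple of per-axis index arrays
    return np.unravel_index(chosen, arr.shape)

arr = np.array([[9, 2, 7],
                [4, 8, 1]])

smallest = k_extreme_indices(arr, 2)
largest = k_extreme_indices(arr, -2)
print(sorted(arr[smallest].tolist()))  # [1, 2]
print(sorted(arr[largest].tolist()))   # [8, 9]

# As a list of [row, column] pairs instead of a tuple of arrays
pairs = np.transpose(smallest).tolist()
print(pairs)
```

The order of the selected entries is left unspecified by the partition, which is why the values are sorted before printing.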