如何使用PyTorch torch.max()

时间:2020-02-23 14:43:43  来源:igfitidea点击:

在本文中,我们将介绍如何使用PyTorch torch.max()函数。

如您所料,这是一个非常简单的功能,但有趣的是,它具有的功能超出了您的想象。

让我们使用一些简单的示例来看看如何使用此功能。

注意:在撰写本文时,使用的PyTorch版本是PyTorch 1.5.0。

PyTorch torch.max()–基本语法

要使用PyTorchtorch.max(),首先导入torch

import torch

现在,此函数返回张量中元素之间的最大值。

PyTorch torch.max()的默认行为

默认行为是返回单个元素和对应于全局最大元素的索引。

max_element = torch.max(input_tensor)

这是一个例子:

p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)

输出

tensor([[-0.0665,  2.7976,  0.9753],
      [ 0.0688, -1.0376,  1.4443]])
tensor(2.7976)

确实,这使我们在张量中具有全局最大元素!

沿尺寸使用torch.max()

但是,您可能希望沿特定维度获得最大的张量,而不是单个元素。

要指定尺寸(轴–在" numpy"中),还有另一个可选的关键字参数,称为" dim"。

这代表了我们追求最大的方向。

这将返回一个元组max_elements和max_indices。

  • max_elements->张量的所有最大元素。

  • max_indices->对应于最大元素的索引。

max_elements, max_indices = torch.max(input_tensor, dim)

这将返回一个Tensor,它在维度" dim"上具有最大的元素。

现在来看一些示例。

p = torch.randn([2, 3])
print(p)

# Get the maximum along dim = 0 (axis = 0)
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)

输出

tensor([[-0.0665,  2.7976,  0.9753],
      [ 0.0688, -1.0376,  1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])

如您所见,我们在维度0上找到最大值(在列上找到最大值)。

同样,我们获得与元素相对应的索引。
例如," 0.0688"沿第0列的索引为" 1"

同样,如果要在行中查找最大值,请使用dim = 1

# Get the maximum along dim = 1 (axis = 1)
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)

输出

tensor([2.7976, 1.4443])
tensor([1, 2])

实际上,我们获得了沿着该行的最大元素以及相应的索引(沿着该行)。

使用torch.max()进行比较

我们还可以使用" torch.max()"来获取两个张量之间的最大值。

output_tensor = torch.max(a, b)

其中" a"和" b"必须具有相同的尺寸,或者必须是"可广播的"张量。

这是比较两个具有相同尺寸的张量的简单示例。

p = torch.randn([2, 3])
q = torch.randn([2, 3])

print("p =", p)
print("q =",q)

# Compare elements of p and q and get the maximum
max_elements = torch.max(p, q)

print(max_elements)

输出

p = tensor([[-0.0665,  2.7976,  0.9753],
      [ 0.0688, -1.0376,  1.4443]])
q = tensor([[-0.0678,  0.2042,  0.8254],
      [-0.1530,  0.0581, -0.3694]])
tensor([[-0.0665,  2.7976,  0.9753],
      [ 0.0688,  0.0581,  1.4443]])

确实,我们得到的输出张量具有在p和q之间的最大元素。