如何使用PyTorch torch.max()
时间:2020-02-23 14:43:43 来源:igfitidea点击:
在本文中,我们将介绍如何使用PyTorch torch.max()函数。
如您所料,这是一个非常简单的功能,但有趣的是,它具有的功能超出了您的想象。
让我们使用一些简单的示例来看看如何使用此功能。
注意:在撰写本文时,使用的PyTorch版本是PyTorch 1.5.0。
PyTorch torch.max()–基本语法
要使用PyTorchtorch.max()
,首先导入torch
。
import torch
现在,此函数返回张量中元素之间的最大值。
PyTorch torch.max()的默认行为
默认行为是返回单个元素和对应于全局最大元素的索引。
max_element = torch.max(input_tensor)
这是一个例子:
p = torch.randn([2, 3]) print(p) max_element = torch.max(p) print(max_element)
输出
tensor([[-0.0665, 2.7976, 0.9753], [ 0.0688, -1.0376, 1.4443]]) tensor(2.7976)
确实,这使我们在张量中具有全局最大元素!
沿尺寸使用torch.max()
但是,您可能希望沿特定维度获得最大的张量,而不是单个元素。
要指定尺寸(轴–在" numpy"中),还有另一个可选的关键字参数,称为" dim"。
这代表了我们追求最大的方向。
这将返回一个元组max_elements和max_indices。
max_elements->张量的所有最大元素。
max_indices->对应于最大元素的索引。
max_elements, max_indices = torch.max(input_tensor, dim)
这将返回一个Tensor,它在维度" dim"上具有最大的元素。
现在来看一些示例。
p = torch.randn([2, 3]) print(p) # Get the maximum along dim = 0 (axis = 0) max_elements, max_idxs = torch.max(p, dim=0) print(max_elements) print(max_idxs)
输出
tensor([[-0.0665, 2.7976, 0.9753], [ 0.0688, -1.0376, 1.4443]]) tensor([0.0688, 2.7976, 1.4443]) tensor([1, 0, 1])
如您所见,我们在维度0上找到最大值(在列上找到最大值)。
同样,我们获得与元素相对应的索引。
例如," 0.0688"沿第0列的索引为" 1"
同样,如果要在行中查找最大值,请使用dim = 1
。
# Get the maximum along dim = 1 (axis = 1) max_elements, max_idxs = torch.max(p, dim=1) print(max_elements) print(max_idxs)
输出
tensor([2.7976, 1.4443]) tensor([1, 2])
实际上,我们获得了沿着该行的最大元素以及相应的索引(沿着该行)。
使用torch.max()进行比较
我们还可以使用" torch.max()"来获取两个张量之间的最大值。
output_tensor = torch.max(a, b)
其中" a"和" b"必须具有相同的尺寸,或者必须是"可广播的"张量。
这是比较两个具有相同尺寸的张量的简单示例。
p = torch.randn([2, 3]) q = torch.randn([2, 3]) print("p =", p) print("q =",q) # Compare elements of p and q and get the maximum max_elements = torch.max(p, q) print(max_elements)
输出
p = tensor([[-0.0665, 2.7976, 0.9753], [ 0.0688, -1.0376, 1.4443]]) q = tensor([[-0.0678, 0.2042, 0.8254], [-0.1530, 0.0581, -0.3694]]) tensor([[-0.0665, 2.7976, 0.9753], [ 0.0688, 0.0581, 1.4443]])
确实,我们得到的输出张量具有在p和q之间的最大元素。