How to do product of matrices in PyTorch
Note: this page is a translation of a popular Stack Overflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must keep it under the same license, cite the original URL and author information, and attribute it to the original authors (not me): Stack Overflow
Original: http://stackoverflow.com/questions/44524901/
Asked by blckbird
In numpy I can do a simple matrix multiplication like this:
import numpy
a = numpy.arange(2*3).reshape(3,2)  # 3x2 matrix
b = numpy.arange(2).reshape(2,1)    # 2x1 matrix
print(a)
print(b)
print(a.dot(b))                     # 3x1 matrix product
However, when I am trying this with PyTorch Tensors, this does not work:
import torch
a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)  # shape (3, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)                # shape (2, 1)
print(a)
print(a.size())
print(b)
print(b.size())
print(torch.dot(a, b))  # raises RuntimeError
This code throws the following error:
RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503
Any ideas on how matrix multiplication can be done in PyTorch?
Answer by mexmex
You're looking for
torch.mm(a,b)
Note that torch.dot() behaves differently from np.dot(). There's been some discussion about what would be desirable here. Specifically, in the PyTorch version that produced the error above, torch.dot() treats both a and b as 1D vectors (irrespective of their original shape) and computes their inner product. The error is thrown because this behaviour makes your a a vector of length 6 and your b a vector of length 2; hence their inner product can't be computed. For matrix multiplication in PyTorch, use torch.mm(). NumPy's np.dot(), in contrast, is more flexible: it computes the inner product for 1D arrays and performs matrix multiplication for 2D arrays.
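As a rough illustration, the tensors from the question can be multiplied with torch.mm and the result compared against NumPy. This is only a sketch; the variable names product and np_product are made up for the example.

```python
import numpy as np
import torch

# The tensors from the question: a has shape (3, 2), b has shape (2, 1).
a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)

# Matrix product: (3, 2) x (2, 1) -> (3, 1)
product = torch.mm(a, b)
print(product.size())

# The same product computed by NumPy's dot on the underlying arrays
np_product = a.numpy().dot(b.numpy())
print(np.allclose(product.numpy(), np_product))
```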
By popular demand, the function torch.matmul performs matrix multiplication if both arguments are 2D and computes their dot product if both arguments are 1D. For inputs of such dimensions, its behaviour is the same as np.dot. It also lets you do broadcasting or matrix x matrix, matrix x vector and vector x vector operations in batches. For more info, see its docs.
import torch
n, m, k, j = 4, 2, 3, 5  # example sizes, chosen arbitrarily
# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])
# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])
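The batched and broadcasting cases mentioned above might look like the following sketch. The sizes and the names M, v, batch and W are arbitrary choices for illustration, not taken from the answer.

```python
import torch

M = torch.rand(3, 4)          # a single matrix
v = torch.rand(4)             # a vector
batch = torch.rand(10, 3, 4)  # a batch of ten 3x4 matrices
W = torch.rand(4, 5)          # one matrix shared across the whole batch

mat_vec = torch.matmul(M, v)        # matrix x vector -> shape (3,)
batch_mat = torch.matmul(batch, W)  # W is broadcast over the batch -> shape (10, 3, 5)
batch_vec = torch.matmul(batch, v)  # batched matrix x vector -> shape (10, 3)

print(mat_vec.shape, batch_mat.shape, batch_vec.shape)
```

Each slice of the batched result should match an ordinary 2D product, e.g. batch_mat[0] against torch.mm(batch[0], W).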
Answer by BiBi
If you want to do a matrix (rank 2 tensor) multiplication, you can do it in four equivalent ways:
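import torch
A = torch.rand(2, 3)  # example matrices, sizes chosen arbitrarily
B = torch.rand(3, 4)
AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+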
There are a few subtleties. From the PyTorch documentation:
torch.mm does not broadcast. For broadcasting matrix products, see torch.matmul().
For instance, you cannot multiply two 1-dimensional vectors with torch.mm, nor multiply batched matrices (rank 3). To this end, you should use the more versatile torch.matmul. For an extensive list of the broadcasting behaviours of torch.matmul, see the documentation.
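A small experiment can make this restriction concrete. The sketch below (with made-up names such as cases, mm_works and matmul_shapes) tries torch.mm on 1D vectors and on batched matrices, where it is expected to raise a RuntimeError, and then runs torch.matmul on the same inputs.

```python
import torch

cases = {
    "1D vectors": (torch.rand(3), torch.rand(3)),
    "batched matrices": (torch.rand(8, 2, 3), torch.rand(8, 3, 4)),
}

mm_works = {}
matmul_shapes = {}
for name, (x, y) in cases.items():
    # torch.mm only accepts 2D inputs, so it should fail here
    try:
        torch.mm(x, y)
        mm_works[name] = True
    except RuntimeError:
        mm_works[name] = False
    # torch.matmul handles both cases
    matmul_shapes[name] = tuple(torch.matmul(x, y).shape)

print(mm_works)
print(matmul_shapes)
```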
For element-wise multiplication, you can simply do (if A and B have the same shape)
A * B # element-wise matrix multiplication (Hadamard product)
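To see the difference between the two products on concrete numbers, here is a minimal sketch with two small 2x2 matrices (the values are arbitrary, picked so the arithmetic is easy to check by hand).

```python
import torch

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])

# Element-wise: each entry is A[i][j] * B[i][j]
hadamard = A * B
# Matrix product: each entry is a row of A dotted with a column of B
matrix = A @ B

print(hadamard)  # [[5, 12], [21, 32]]
print(matrix)    # [[19, 22], [43, 50]]
```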
Answer by David Jung
Use torch.mm(a, b) or torch.matmul(a, b). For 2D matrices, both give the same result.
>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>
There's one more option that may be good to know: the @ operator (credit to @Simon H.).
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743, 0.5989, -0.1390],
[ 0.8699, -0.3445, 1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743, 0.5989, -0.1390],
[ 0.8699, -0.3445, 1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743, 0.5989, -0.1390],
[ 0.8699, -0.3445, 1.4122, -0.5826]])
The three give the same results.
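Rather than comparing the printed values by eye, the equality can also be checked programmatically. This is just one possible way to do it; the seed and the names results and all_match are arbitrary.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so the run is repeatable
a = torch.randn(2, 3)
b = torch.randn(3, 4)

# The three spellings of the same 2D matrix product
results = [a @ b, a.mm(b), a.matmul(b)]
all_match = all(torch.allclose(results[0], r) for r in results[1:])
print(all_match)
```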
Related links:
Matrix multiplication operator
PEP 465 -- A dedicated infix operator for matrix multiplication