Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must do so under the same license, link to the original, and attribute it to the original authors (not me): StackOverflow.
Original URL: http://stackoverflow.com/questions/19839539/
How to get faster code than numpy.dot for matrix multiplication?
Asked by mrgloom
In Matrix multiplication using hdf5 I use hdf5 (PyTables) for big matrix multiplication, but I was surprised because using hdf5 it works even faster than using plain numpy.dot with the matrices stored in RAM. What is the reason for this behavior?
And maybe there is some faster function for matrix multiplication in Python, because I still use numpy.dot for the small block multiplications.
Here is some code:
Assume the matrices fit in RAM: test on a 10*1000 x 1000 matrix.
Using default numpy (I think with no BLAS lib). Plain numpy arrays are in RAM: time 9.48 s.
If A, B are in RAM and C is on disk: time 1.48 s.
If A, B, C are all on disk: time 372.25 s.
If I use numpy with MKL, the results are: 0.15, 0.45, 43.5.
The results look reasonable, but I still don't understand why in the first case the blocked multiplication is faster (when A and B are stored in RAM).
    import time

    import numpy as np
    import tables

    n_row = 1000
    n_col = 1000
    n_batch = 10

    def test_plain_numpy():
        A = np.random.rand(n_row, n_col)  # float64 by default
        B = np.random.rand(n_col, n_row)
        t0 = time.time()
        res = np.dot(A, B)
        print(time.time() - t0)

    # A, B in RAM, C on disk
    def test_hdf5_ram():
        # plain numpy arrays as inputs
        A = np.random.rand(n_row, n_col)
        B = np.random.rand(n_col, n_row)
        # settings for all hdf5 files
        atom = tables.Float32Atom()  # storing uint8 would use less memory
        filters = tables.Filters(complevel=9, complib='blosc')  # tune parameters
        Nchunk = 128  # chunk edge length; worth tuning
        chunkshape = (Nchunk, Nchunk)
        chunk_multiple = 1
        block_size = chunk_multiple * Nchunk
        # hdf5 file for the result
        fileName_C = 'CArray_C.h5'
        shape = (A.shape[0], B.shape[1])
        h5f_C = tables.open_file(fileName_C, 'w')
        C = h5f_C.create_carray(h5f_C.root, 'CArray', atom, shape,
                                chunkshape=chunkshape, filters=filters)
        sz = block_size
        t0 = time.time()
        for i in range(0, A.shape[0], sz):
            for j in range(0, B.shape[1], sz):
                for k in range(0, A.shape[1], sz):
                    C[i:i+sz, j:j+sz] += np.dot(A[i:i+sz, k:k+sz], B[k:k+sz, j:j+sz])
        print(time.time() - t0)
        h5f_C.close()

    # A, B, C all on disk
    def test_hdf5_disk():
        # settings for all hdf5 files
        atom = tables.Float32Atom()  # storing uint8 would use less memory
        filters = tables.Filters(complevel=9, complib='blosc')  # tune parameters
        Nchunk = 128
        chunkshape = (Nchunk, Nchunk)
        chunk_multiple = 1
        block_size = chunk_multiple * Nchunk
        # hdf5 file for A, filled batch by batch
        fileName_A = 'carray_A.h5'
        shape_A = (n_row * n_batch, n_col)  # predefined size
        h5f_A = tables.open_file(fileName_A, 'w')
        A = h5f_A.create_carray(h5f_A.root, 'CArray', atom, shape_A,
                                chunkshape=chunkshape, filters=filters)
        for i in range(n_batch):
            data = np.random.rand(n_row, n_col)
            A[i*n_row:(i+1)*n_row] = data[:]
        # hdf5 file for B
        rows = n_col
        cols = n_row
        batches = n_batch
        fileName_B = 'carray_B.h5'
        shape_B = (rows, cols * batches)  # predefined size
        h5f_B = tables.open_file(fileName_B, 'w')
        B = h5f_B.create_carray(h5f_B.root, 'CArray', atom, shape_B,
                                chunkshape=chunkshape, filters=filters)
        sz = rows // batches  # integer division so the slice bounds stay integral
        for i in range(batches):
            data = np.random.rand(sz, cols * batches)
            B[i*sz:(i+1)*sz] = data[:]
        # hdf5 file for the result
        fileName_C = 'CArray_C.h5'
        shape = (A.shape[0], B.shape[1])
        h5f_C = tables.open_file(fileName_C, 'w')
        C = h5f_C.create_carray(h5f_C.root, 'CArray', atom, shape,
                                chunkshape=chunkshape, filters=filters)
        sz = block_size
        t0 = time.time()
        for i in range(0, A.shape[0], sz):
            for j in range(0, B.shape[1], sz):
                for k in range(0, A.shape[1], sz):
                    C[i:i+sz, j:j+sz] += np.dot(A[i:i+sz, k:k+sz], B[k:k+sz, j:j+sz])
        print(time.time() - t0)
        h5f_A.close()
        h5f_B.close()
        h5f_C.close()
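Stripped of the hdf5 storage, the triple loop above is a standard blocked (tiled) matrix multiply, and it can be checked against np.dot directly. A minimal in-memory sketch (block size and shapes chosen arbitrarily for illustration):

```python
import numpy as np

def blocked_dot(A, B, sz=64):
    # same i/j/k tiling as the hdf5 version, but accumulating into a RAM array
    C = np.zeros((A.shape[0], B.shape[1]))
    for i in range(0, A.shape[0], sz):
        for j in range(0, B.shape[1], sz):
            for k in range(0, A.shape[1], sz):
                C[i:i+sz, j:j+sz] += np.dot(A[i:i+sz, k:k+sz], B[k:k+sz, j:j+sz])
    return C

A = np.random.rand(200, 150)
B = np.random.rand(150, 100)
C = blocked_dot(A, B)
```

The slices handle ragged edge blocks automatically, since NumPy clips slice bounds that run past the end of the array.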
Accepted answer by Fred Foo
np.dot dispatches to BLAS when:
- NumPy has been compiled to use BLAS,
- a BLAS implementation is available at run-time,
- your data has one of the dtypes float32, float64, complex64 or complex128, and
- the data is suitably aligned in memory.
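One way to see the dtype condition in practice: running the same multiply on a BLAS-eligible float64 array and on an object-dtype copy (which cannot be dispatched to BLAS) differs by orders of magnitude. An illustrative sketch, with the size kept small so the slow path finishes quickly:

```python
import time
import numpy as np

n = 100
A = np.random.rand(n, n)   # float64: dispatched to BLAS
A_obj = A.astype(object)   # object dtype: falls back to the generic slow loop

t0 = time.time(); np.dot(A, A);         t_blas = time.time() - t0
t0 = time.time(); np.dot(A_obj, A_obj); t_slow = time.time() - t0
print(t_blas, t_slow)      # t_slow is typically orders of magnitude larger
```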
Otherwise, it defaults to using its own, slow, matrix multiplication routine.
Checking your BLAS linkage is described here. In short, check whether there is a file named _dotblas.so or similar in your NumPy installation (in newer NumPy versions, numpy.show_config() reports the linked BLAS directly). When there is, check which BLAS library it is linked against; the reference BLAS is slow, ATLAS is fast, and OpenBLAS and vendor-specific versions such as Intel MKL are even faster. Watch out with multithreaded BLAS implementations, as they don't play nicely with Python's multiprocessing.
Next, check your data alignment by inspecting the flags of your arrays. In versions of NumPy before 1.7.2, both arguments to np.dot should be C-ordered. In NumPy >= 1.7.2 this doesn't matter as much anymore, as special cases for Fortran-ordered arrays have been introduced.
>>> X = np.random.randn(10, 4)
>>> Y = np.random.randn(7, 4).T
>>> X.flags
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
>>> Y.flags
C_CONTIGUOUS : False
F_CONTIGUOUS : True
OWNDATA : False
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
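On NumPy < 1.7.2, a Fortran-ordered argument like Y above could be made C-ordered with an explicit copy before calling np.dot. A small sketch:

```python
import numpy as np

Y = np.random.randn(7, 4).T    # F-contiguous view, not C-contiguous
Y_c = np.ascontiguousarray(Y)  # C-contiguous copy of the same values
print(Y.flags['C_CONTIGUOUS'], Y_c.flags['C_CONTIGUOUS'])
```

The copy has a cost, of course, which is exactly why the Fortran-array special cases in newer NumPy are a win.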
If your NumPy is not linked against BLAS, either (easy) re-install it, or (hard) use the BLAS gemm (general matrix multiply) function from SciPy:
>>> from scipy.linalg import get_blas_funcs
>>> gemm = get_blas_funcs("gemm", [X, Y])
>>> np.all(gemm(1, X, Y) == np.dot(X, Y))
True
This looks easy, but it does hardly any error checking, so you must really know what you're doing.
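One convenience of gemm for the question's blocked loop: it computes alpha*X@Y + beta*C in a single call, so the += accumulation does not need a separate temporary. A sketch assuming SciPy is available (the keyword names follow scipy.linalg.blas.dgemm; for in-place updates C should be Fortran-ordered):

```python
import numpy as np
from scipy.linalg import get_blas_funcs

X = np.random.randn(10, 4)
Y = np.random.randn(4, 7)
C = np.zeros((10, 7), order='F')  # Fortran order lets BLAS update C in place
gemm = get_blas_funcs("gemm", (X, Y))
C = gemm(1.0, X, Y, beta=1.0, c=C, overwrite_c=True)  # C = 1*X@Y + 1*C
```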