[Numpy-discussion] matrixmultiply implementation

Christopher Lee clee at gnwy100.wuh.wustl.edu
Fri Dec 1 19:43:12 EST 2000


I'm curious whether anyone has looked at the implementation of
dot/matrixmultiply recently.  I was benchmarking a LAPACK-based matrix
inversion and was surprised to find that the accuracy check took longer
than the inversion itself.  The bottleneck was the matrix multiply.  It
appears that a BLAS dgemm implementation could make the multiply roughly
four times faster.  Here is a sample benchmark using PyLapack's
dblas.dgemm for comparison.  Results for a 1000 x 1000 matrix inversion:

  time elapsed for inverse (sec) 10.050768
  time elapsed for matrixmultiply (sec) 24.142714
  time elapsed for dgemm-matrixmultiply (sec) 5.997188
  max error is: 3.30398486348e-10

The "dot" code doesn't look too bad, so I might be able to add dgemm
support myself, at least for some cases.

-chris

P.S. The test code is included below.
#####################################
import time
from Numeric import *
from LinearAlgebra import *
import dblas

for N in [5, 100, 1000]:
    a = reshape(arange(float(N*N)), (N,N)) + identity(N)
    # print "typecode of array is: ", a.typecode()
    start = time.time()
    inv_a = inverse(a)
    stop = time.time()

    # print inv_a

    startmult = time.time()
    b =  matrixmultiply(a, inv_a)
    stopmult = time.time()
    residual = b - identity(N)

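    # DGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
    # computes C := ALPHA*op(A)*op(B) + BETA*C.  C starts as zeros, so
    # the result is the plain product even with BETA = 1.
    C = zeros(a.shape, a.typecode())

    alpha = array([1.0])
    beta  = array([1.0])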
    startdgemm = time.time()
    dblas.dgemm('N','N', N,N,N, alpha, 
                 a,N,       # A, LDA
                 inv_a, N,  # B, LDB
                 beta,  # BETA
                 C, N)
    stopdgemm = time.time()
    # print C

    print "%d x %d matrix inversion" % (N,N)
    print "max error is:",
    print maximum.reduce(fabs(ravel(residual)))
    print "time elapsed for inverse (sec) %f" % (stop-start)
    print "time elapsed for matrixmultiply (sec) %f" % (stopmult-startmult)
    print "time elapsed for dgemm-matrixmultiply (sec) %f" % (stopdgemm-startdgemm)
    print
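
For reference, the update that the dgemm call above performs can be sketched in plain Python. This is only an illustration of the BLAS contract C := alpha*A*B + beta*C for the 'N','N' (no transpose) case. The name dgemm_reference is hypothetical, and nested lists stand in for Numeric arrays; it is not PyLapack's API and is far too slow for real use.

```python
def dgemm_reference(alpha, a, b, beta, c):
    """Return alpha * (a times b) + beta * c for matrices stored as lists of rows.

    Hypothetical helper for illustration only; it mirrors the no-transpose
    case of BLAS dgemm but returns a new matrix instead of updating c in place.
    """
    m, k, n = len(a), len(b), len(b[0])
    result = []
    for i in range(m):
        row = []
        for j in range(n):
            # Inner product of row i of a with column j of b.
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            row.append(alpha * acc + beta * c[i][j])
        result.append(row)
    return result


# Small accuracy check in the spirit of the benchmark: a times inverse(a)
# should be the identity.  The 2 x 2 values are chosen to be exact in binary.
a = [[2.0, 0.0], [0.0, 4.0]]
inv_a = [[0.5, 0.0], [0.0, 0.25]]
c = [[0.0, 0.0], [0.0, 0.0]]

product = dgemm_reference(1.0, a, inv_a, 1.0, c)
max_error = max(abs(product[i][j] - (1.0 if i == j else 0.0))
                for i in range(2) for j in range(2))
```

Because c is all zeros, beta = 1.0 contributes nothing here, which is why the benchmark can pass beta = 1.0 and still get the plain product.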


