Attached is a patch that: 1) removes the .expand() call that was being made in matrix_multiply(); 2) adds a doctest and improves the docstring; 3) provides an alternative implementation that is somewhat clearer, but about 5% slower.
~Luke --~--~---------~--~----~------------~-------~--~----~ You received this message because you are subscribed to the Google Groups "sympy-patches" group. To post to this group, send email to sympy-patches@googlegroups.com To unsubscribe from this group, send email to sympy-patches+unsubscr...@googlegroups.com For more options, visit this group at http://groups.google.com/group/sympy-patches?hl=en -~----------~----~----~----~------~----~------~--~---
From 994e7239dae0f2f7aab4cae57fdfdccb561edde1 Mon Sep 17 00:00:00 2001
From: Luke Peterson <hazelnu...@gmail.com>
Date: Sat, 8 Aug 2009 21:03:31 -0700
Subject: [PATCH] Removed .expand() in matrix_multiply, added doctests, alternate implementation

---
 sympy/matrices/matrices.py |   44 +++++++++++++++++++++++++++++++++++++-------
 1 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/sympy/matrices/matrices.py b/sympy/matrices/matrices.py
index 72bb129..ae04a32 100644
--- a/sympy/matrices/matrices.py
+++ b/sympy/matrices/matrices.py
@@ -1474,20 +1474,50 @@ class Matrix(object):
         count += 1
     return v
 
-def matrix_multiply(A,B):
+def matrix_multiply(A, B):
     """
-    Return A*B.
+    Matrix product A*B.
+
+    A and B must be of compatible dimensions: if A is an m x k matrix and B
+    is a k x n matrix, the product is an m x n matrix.  Otherwise
+    ShapeError is raised.
+
+    Example:
+
+    >>> from sympy import Matrix
+    >>> A = Matrix([[1, 2, 3], [4, 5, 6]])
+    >>> B = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    >>> A*B
+    [30, 36, 42]
+    [66, 81, 96]
+    >>> B*A
+    Traceback (most recent call last):
+    ...
+    ShapeError
     """
+    # The following implementation is equivalent, but about 5% slower:
+    #ma, na = A.shape
+    #mb, nb = B.shape
+    #
+    #if na != mb:
+    #    raise ShapeError()
+    #product = Matrix(ma, nb, lambda i, j: 0)
+    #for i in xrange(ma):
+    #    for j in xrange(nb):
+    #        s = 0
+    #        for k in xrange(na):
+    #            s += A[i, k]*B[k, j]
+    #        product[i, j] = s
+    #return product
     if A.shape[1] != B.shape[0]:
         raise ShapeError()
     blst = B.T.tolist()
     alst = A.tolist()
-    return Matrix(A.shape[0], B.shape[1], lambda i,j:
-                                    reduce(lambda k,l: k+l,
-                                    map(lambda n,m: n*m,
+    return Matrix(A.shape[0], B.shape[1], lambda i, j:
+                                    reduce(lambda k, l: k+l,
+                                    map(lambda n, m: n*m,
                                     alst[i],
-                                    blst[j])).expand())
-                                    # .expand() is a test
+                                    blst[j])))
 
 def matrix_add(A,B):
     """Return A+B"""
-- 
1.6.0.4