On Feb 22, 2014 2:03 PM, "Nathaniel Smith" <n...@pobox.com> wrote:
>
> Hi all,
>
> Currently numpy's 'dot' acts a bit weird for ndim>2 or ndim<1. In
> practice this doesn't usually matter much, because these are very
> rarely used. But I would like to nail down the behaviour so we can
> say something precise in the matrix multiplication PEP. So here's
> one proposal.
>
> # CURRENT:
>
> dot(0d, any) -> scalar multiplication
> dot(any, 0d) -> scalar multiplication
> dot(1d, 1d) -> inner product
> dot(2d, 1d) -> treat 1d as column matrix, matrix-multiply, then
>   discard added axis
> dot(1d, 2d) -> treat 1d as row matrix, matrix-multiply, then
>   discard added axis
> dot(2d, 2d) -> matrix multiply
> dot(2-or-more d, 2-or-more d) -> a complicated outer product thing:
>   specifically, if the inputs have shapes (r, n, m) and (s, m, k),
>   then numpy returns an array with shape (r, s, n, k), created like:
>
>     for i in range(r):
>         for j in range(s):
>             output[i, j, :, :] = np.dot(input1[i, :, :], input2[j, :, :])
>
> # PROPOSED:
>
> General rule: given dot on shape1, shape2, we try to match these
> shapes against two templates like
>
>     (..., n?, m) and (..., m, k?)
>
> where ... indicates zero or more dimensions, and ? indicates an
> optional axis. ? axes are always matched before ... axes, so for an
> input with ndim >= 2, the ? axis is always matched. An unmatched ?
> axis is treated as having size 1.
>
> Next, the ... axes are broadcast against each other in the usual way
> (prepending 1s to make the lengths the same, and requiring
> corresponding entries to either match or have the value 1). And then
> the actual computations are performed using the usual broadcasting
> rules.
>
> Finally, we return an output with shape (..., n?, k?). Here "..."
> indicates the result of broadcasting the input ...'s against each
> other. And n? and k? mean: "either the value taken from the input
> shape, if the corresponding entry was matched -- but if no match was
> made, then we leave this entry out."
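The template-matching rule above can be sketched as a small pure-Python shape function. This is illustrative only: `proposed_dot_shape` is a hypothetical helper, not part of numpy, and `np.broadcast_shapes` is used as a stand-in for the "broadcast the `...` axes" step:

```python
import numpy as np

def proposed_dot_shape(shape1, shape2):
    """Output shape of dot(a, b) under the PROPOSED rule.

    Matches shape1 against (..., n?, m) and shape2 against (..., m, k?).
    A '?' axis is matched only when ndim >= 2; an unmatched '?' axis
    acts as size 1 during the computation and is omitted from the output.
    """
    if len(shape1) == 0 or len(shape2) == 0:
        raise ValueError("0d arguments to dot are an error under the proposal")
    has_n = len(shape1) >= 2          # was the n? axis matched?
    has_k = len(shape2) >= 2          # was the k? axis matched?
    m1 = shape1[-1]
    m2 = shape2[-2] if has_k else shape2[-1]
    if m1 != m2:
        raise ValueError("core dimension mismatch: %d vs %d" % (m1, m2))
    # Broadcast the '...' axes against each other in the usual way.
    batch1 = shape1[:-2] if has_n else shape1[:-1]
    batch2 = shape2[:-2] if has_k else shape2[:-1]
    batch = np.broadcast_shapes(batch1, batch2)
    # Copy n and k into the output only if they were matched.
    tail = ()
    if has_n:
        tail += (shape1[-2],)
    if has_k:
        tail += (shape2[-1],)
    return tuple(batch) + tail
```

Running it on the small cases reproduces the table above: `(3,) . (3,)` gives `()` (inner product), `(2, 3) . (3,)` gives `(2,)`, `(3,) . (3, 4)` gives `(4,)`, and `(5, 2, 3) . (5, 3, 4)` gives `(5, 2, 4)`.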
> The idea is that just as a column vector on the right is "m x 1", a
> 1d vector on the right is treated as "m x <nothing>". For purposes
> of actually computing the product, <nothing> acts like 1, as
> mentioned above. But it makes a difference in what we return: in
> each of these cases we copy the input shape into the output, so we
> can get an output with shape (n, <nothing>), or (<nothing>, k), or
> (<nothing>, <nothing>), which work out to be (n,), (k,), and (),
> respectively. This gives a (somewhat) intuitive principle for why
> dot(1d, 1d), dot(1d, 2d), and dot(2d, 1d) are handled the way they
> are, and a general template for extending this behaviour to other
> operations like gufunc 'solve'.
>
> Anyway, the end result of this is that the PROPOSED behaviour
> differs from the current behaviour in the following ways:
>
> - Passing 0d arrays to 'dot' becomes an error. (This in particular
>   is an important thing to know, because if core Python adds an
>   operator for 'dot', then we must decide what it should do for
>   Python scalars, which are logically 0d.)
>
> - ndim>2 arrays are now handled by aligning and broadcasting the
>   extra axes, instead of taking an outer product. So dot((r, m, n),
>   (r, n, k)) returns (r, m, k), not (r, r, m, k).
>
> Comments?
The proposed behavior for ndim > 2 is what matrix_multiply (is it
still in umath_tests?) does. The nice thing about the proposed new
behavior is that the old behavior is easy to reproduce by fiddling a
little with the shape of the first argument, while the opposite is
not true.

Jaime

> --
> Nathaniel J. Smith
> Postdoctoral researcher - Informatics - University of Edinburgh
> http://vorpus.org
_______________________________________________
NumPy-Discussion mailing list
NumPy-Discussion@scipy.org
http://mail.scipy.org/mailman/listinfo/numpy-discussion