Okay. Let's assume the input shape is `(s_1, s_2, ..., s_{a-1}, s_a, s_{a+1}, ..., s_{b-1}, s_b, s_{b+1}, ..., s_n)`, where `a = axis1`, `b = axis2` and `a < b`. What would the output shape be? It would be `(s_1, ..., s_{a-1}, s_{a+1}, ..., s_{b-1}, s_{b+1}, ..., s_n, s_d)`, where `s_d` is the length of the diagonal of an `s_a x s_b` matrix (i.e. `min(s_a, s_b)` for the main diagonal).

Note that the segments `s_1, ..., s_{a-1}`, `s_{a+1}, ..., s_{b-1}` and `s_{b+1}, ..., s_n` are left untouched (mapped directly to the output), so we can merge each of them into a single dimension. To get rid of `s_a` and `s_b`, we can likewise merge each one into the segment that follows it, `s_{a+1}, ..., s_{b-1}` and `s_{b+1}, ..., s_n` respectively. This reshapes the input into `(s_1 * ... * s_{a-1}, s_a * s_{a+1} * ... * s_{b-1}, s_b * s_{b+1} * ... * s_n)` and the output into `(s_1 * ... * s_{a-1}, s_{a+1} * ... * s_{b-1}, s_{b+1} * ... * s_n, s_d)`.

In this way, to get `output(x1, x2, x3, x4)` we simply take the flat index of `input(x1, x2, x3)` and add the offset `x4 * (stride1 + stride2)`, where `stride1` and `stride2` are the strides of `axis1` and `axis2` in the input. This works because `x2 < s_{a+1} * ... * s_{b-1}` and `x3 < s_{b+1} * ... * s_n`, so `input(x1, x2, x3)` is exactly the element whose indices along `axis1` and `axis2` are both 0; moving `x4` steps along the diagonal then advances both of those indices by `x4`.
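The reshape-and-offset scheme above can be checked with a small NumPy sketch. This is not MXNet's actual kernel: the function name `diagonal_via_reshape` is hypothetical, it assumes `axis1 < axis2`, a row-major (C-contiguous) layout and the main diagonal only (no offset), and it uses plain Python loops for clarity rather than speed. It compares against `numpy.diagonal`, which also drops both axes and appends the diagonal as the last dimension.

```python
import numpy as np
from math import prod


def diagonal_via_reshape(x, axis1, axis2):
    """Hypothetical sketch: main diagonal of x over (axis1, axis2) via the
    3-segment reshape plus stride offset. Assumes axis1 < axis2."""
    assert 0 <= axis1 < axis2 < x.ndim
    shape = x.shape
    sa, sb = shape[axis1], shape[axis2]
    sd = min(sa, sb)                       # main-diagonal length of an sa x sb matrix
    lead = prod(shape[:axis1])             # s_1 * ... * s_{a-1} (1 if empty)
    mid = prod(shape[axis1 + 1:axis2])     # s_{a+1} * ... * s_{b-1}
    trail = prod(shape[axis2 + 1:])        # s_{b+1} * ... * s_n
    # Row-major strides (in elements) of axis1 and axis2 in the input.
    stride1 = mid * sb * trail
    stride2 = trail
    flat = np.ascontiguousarray(x).reshape(-1)
    out = np.empty((lead, mid, trail, sd), dtype=x.dtype)
    for x1 in range(lead):
        for x2 in range(mid):
            for x3 in range(trail):
                # Flat index of input(x1, x2, x3) in the
                # (lead, sa * mid, sb * trail) view: indices on axis1/axis2 are 0.
                base = x1 * (sa * mid * sb * trail) + x2 * (sb * trail) + x3
                for x4 in range(sd):
                    out[x1, x2, x3, x4] = flat[base + x4 * (stride1 + stride2)]
    # Un-merge the three untouched segments to get the final output shape.
    out_shape = shape[:axis1] + shape[axis1 + 1:axis2] + shape[axis2 + 1:] + (sd,)
    return out.reshape(out_shape)


x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
print(diagonal_via_reshape(x, 1, 3).shape)
```

Because all three merged segments keep their relative order, a single row-major reshape of the `(lead, mid, trail, s_d)` result recovers the unmerged output shape without any transposition.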

[ Full content available at: https://github.com/apache/incubator-mxnet/pull/12430 ]