Okay. Let's assume the input shape is `(s_1, s_2, ..., s_{a-1}, s_a, s_{a+1}, ..., s_{b-1}, s_b, s_{b+1}, ..., s_n)`, where `a = axis1`, `b = axis2` and `a < b`. The output shape is then `(s_1, ..., s_{a-1}, s_{a+1}, ..., s_{b-1}, s_{b+1}, ..., s_n, s_d)`, where `s_d` is the length of the diagonal of an `s_a x s_b` matrix (i.e. `min(s_a, s_b)` for the main diagonal).

Note that the three segments `s_1, ..., s_{a-1}`, `s_{a+1}, ..., s_{b-1}` and `s_{b+1}, ..., s_n` are left untouched (mapped directly from input to output), so each segment can be merged into a single dimension. To get rid of `s_a` and `s_b`, we can also merge each of them into the segment that follows it, `s_{a+1}, ..., s_{b-1}` and `s_{b+1}, ..., s_n` respectively. This reshapes the input into `(s_1 * ... * s_{a-1}, s_a * s_{a+1} * ... * s_{b-1}, s_b * s_{b+1} * ... * s_n)` and the output into `(s_1 * ... * s_{a-1}, s_{a+1} * ... * s_{b-1}, s_{b+1} * ... * s_n, s_d)`, where an empty product is 1.

In this way, to get `output(x1, x2, x3, x4)`, we simply take the flat index of `input(x1, x2, x3)` in the reshaped input and add the offset `x4 * (stride1 + stride2)`, where `stride1` and `stride2` are the strides of `axis1` and `axis2` in the original input. This works because `s_a` and `s_b` are the outermost factors of their merged dimensions: `input(x1, x2, x3)` is the element whose `axis1` and `axis2` indices are both 0, and the offset advances both indices to `x4`.
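A minimal pure-Python sketch of the reshape-plus-offset indexing described above, working on a flat row-major buffer instead of an actual MXNet NDArray. The function name `diagonal_via_reshape` is hypothetical; the sketch assumes non-negative axes and the main diagonal only (no diagonal offset), and it is an illustration of the idea rather than the kernel from the PR.

```python
from math import prod


def diagonal_via_reshape(data, shape, axis1, axis2):
    """Hypothetical helper: main diagonal over (axis1, axis2) of a flat row-major buffer.

    Returns (out_shape, out_data); out_shape drops both axes and appends the
    diagonal length as the last dimension.
    """
    # Assumes non-negative axes. For the main diagonal i_a == i_b, so the
    # order of the two axes does not matter and we can sort them (a < b).
    a, b = sorted((axis1, axis2))
    sd = min(shape[a], shape[b])  # main-diagonal length

    # Sizes of the three untouched segments; an empty product is 1.
    p1 = prod(shape[:a])
    p2 = prod(shape[a + 1:b])
    p3 = prod(shape[b + 1:])

    # Row-major strides of axis a and axis b in the original input.
    stride_a = p2 * shape[b] * p3
    stride_b = p3

    out = []
    for x1 in range(p1):
        for x2 in range(p2):
            for x3 in range(p3):
                # Flat index of input(x1, x2, x3) in the 3-D view
                # (p1, s_a * p2, s_b * p3); here i_a = i_b = 0.
                base = (x1 * shape[a] * p2 + x2) * shape[b] * p3 + x3
                for x4 in range(sd):
                    # The offset moves both i_a and i_b to x4.
                    out.append(data[base + x4 * (stride_a + stride_b)])

    out_shape = tuple(shape[:a]) + tuple(shape[a + 1:b]) + tuple(shape[b + 1:]) + (sd,)
    return out_shape, out


# Usage: a (2, 3, 4) tensor holding 0..23, diagonal over axes 0 and 2.
print(diagonal_via_reshape(list(range(24)), (2, 3, 4), 0, 2))
# prints ((3, 2), [0, 13, 4, 17, 8, 21])
```

The output layout (the two axes removed, the diagonal appended last) is the same as NumPy's `np.diagonal(x, axis1=0, axis2=2)`, which gives `[[0, 13], [4, 17], [8, 21]]` for the same input.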
[ Full content available at: https://github.com/apache/incubator-mxnet/pull/12430 ]
