Hi, it's nice to see Strassen has attracted attention again. I would like to
know which hardware you used and how many cores.
Actually, it's easy to implement Strassen in TVM, and I have tested the
algorithm with two different implementations.
TE version:
```python
import tvm  # uses the pre-0.7 API (tvm.compute / tvm.placeholder at top level)

GEMM_COUNT = 0     # gives each generated GEMM stage a unique name
GEMM_LEVEL = 0     # tracks the recursion depth for naming
DIRECT_SIZE = 512  # below this size, fall back to a direct GEMM


def strassen_gemm(N):
    def gemm(A, B, N, name=""):
        global GEMM_COUNT
        if name != "":
            name += "G%d_" % GEMM_COUNT
            GEMM_COUNT += 1
        if N > DIRECT_SIZE:
            return strassen(A, B, N, name)
        else:
            return direct(A, B, N, name)

    def direct(A, B, N, name):
        k = tvm.reduce_axis((0, N), name="k")
        C = tvm.compute(A.shape,
                        lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=k),
                        name=name + 'C')
        return C

    def split(A, new_n, ori_name="Matrix"):
        A11 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i][j], name=ori_name + "11")
        A12 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i][j + new_n], name=ori_name + "12")
        A21 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i + new_n][j], name=ori_name + "21")
        A22 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i + new_n][j + new_n],
                          name=ori_name + "22")
        return A11, A12, A21, A22

    def sub(A, B, N, name):
        return tvm.compute((N, N),
                           lambda i, j: A[i][j] - B[i][j], name=name)

    def add(A, B, N, name):
        return tvm.compute((N, N),
                           lambda i, j: A[i][j] + B[i][j], name=name)

    def strassen(A, B, N, name):
        global GEMM_LEVEL
        new_n = N // 2
        A11, A12, A21, A22 = split(A, new_n, name + "A")
        B11, B12, B21, B22 = split(B, new_n, name + "B")
        # The ten S-terms of Strassen's algorithm.
        S1 = sub(B12, B22, new_n, name + "S1")
        S2 = add(A11, A12, new_n, name + "S2")
        S3 = add(A21, A22, new_n, name + "S3")
        S4 = sub(B21, B11, new_n, name + "S4")
        S5 = add(A11, A22, new_n, name + "S5")
        S6 = add(B11, B22, new_n, name + "S6")
        S7 = sub(A12, A22, new_n, name + "S7")
        S8 = add(B21, B22, new_n, name + "S8")
        S9 = sub(A11, A21, new_n, name + "S9")
        S10 = add(B11, B12, new_n, name + "S10")
        level = GEMM_LEVEL
        GEMM_LEVEL += 1
        # The seven half-size products (each may recurse).
        P1 = gemm(A11, S1, new_n, name + "L%d_" % level)
        P2 = gemm(S2, B22, new_n, name + "L%d_" % level)
        P3 = gemm(S3, B11, new_n, name + "L%d_" % level)
        P4 = gemm(A22, S4, new_n, name + "L%d_" % level)
        P5 = gemm(S5, S6, new_n, name + "L%d_" % level)
        P6 = gemm(S7, S8, new_n, name + "L%d_" % level)
        P7 = gemm(S9, S10, new_n, name + "L%d_" % level)
        C11 = tvm.compute((new_n, new_n),
                          lambda i, j: P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j],
                          name=name + "C11")
        C12 = add(P1, P2, new_n, name + "C12")
        C21 = add(P3, P4, new_n, name + "C21")
        C22 = tvm.compute((new_n, new_n),
                          lambda i, j: P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j],
                          name=name + "C22")
        # Stitch the four quadrants back into one N x N result.
        C = tvm.compute((N, N),
                        lambda i, j: tvm.if_then_else(
                            i < new_n,
                            tvm.if_then_else(j < new_n,
                                             C11[i][j], C12[i][j - new_n]),
                            tvm.if_then_else(j < new_n,
                                             C21[i - new_n][j],
                                             C22[i - new_n][j - new_n])),
                        name=name + "C")
        return C

    A = tvm.placeholder((N, N), name="A")
    B = tvm.placeholder((N, N), name="B")
    C = gemm(A, B, N)
    sch = tvm.create_schedule(C.op)
    return sch, [A, B, C]
```
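For completeness, a minimal driver for the TE version might look like the
sketch below (not part of my original experiments; the `llvm` target, float32
data, and the loose tolerance are just illustrative choices, and the tolerance
is loose because Strassen reorders the floating-point accumulation):
```python
import numpy as np

N = 1024
sch, args = strassen_gemm(N)
func = tvm.build(sch, args, target="llvm")  # target is an assumption

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(N, N).astype("float32"), ctx)
b = tvm.nd.array(np.random.rand(N, N).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((N, N), dtype="float32"), ctx)
func(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() @ b.asnumpy(), rtol=1e-3)
```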
Relay version (I even tried merging the GEMMs of the 7 sub-matrices into a
single `batch_matmul`):
```python
from tvm import relay

direct_size = 512  # threshold below which the 7 products are merged


def strassen_gemm(N, K, M, max_level=1):
    # A: [N, K], B: [K, M], C: [N, M]
    def gemm(A, B, N, K, M, level):
        if (level < max_level and N % 2 == 0 and
                K % 2 == 0 and M % 2 == 0):
            return strassen(A, B, N, K, M, level)
        else:
            return direct(A, B, N, K, M)

    def direct(A, B, N, K, M):
        # nn.dense expects the weight as [M, K], hence the transpose.
        return relay.nn.dense(A, relay.transpose(B, [1, 0]))

    def split(A, new_x, new_y):
        A11 = relay.strided_slice(A, [0, 0], [new_x, new_y])
        A12 = relay.strided_slice(A, [0, new_y], [new_x, new_y * 2])
        A21 = relay.strided_slice(A, [new_x, 0], [new_x * 2, new_y])
        A22 = relay.strided_slice(A, [new_x, new_y], [new_x * 2, new_y * 2])
        return A11, A12, A21, A22

    def strassen(A, B, N, K, M, level):
        new_n, new_k, new_m = N // 2, K // 2, M // 2
        A11, A12, A21, A22 = split(A, new_n, new_k)
        B11, B12, B21, B22 = split(B, new_k, new_m)
        S1 = B12 - B22
        P1 = gemm(A11, S1, new_n, new_k, new_m, level + 1)
        S2 = A11 + A12
        P2 = gemm(S2, B22, new_n, new_k, new_m, level + 1)
        C12 = P1 + P2
        S3 = A21 + A22
        P3 = gemm(S3, B11, new_n, new_k, new_m, level + 1)
        S4 = B21 - B11
        P4 = gemm(A22, S4, new_n, new_k, new_m, level + 1)
        C21 = P3 + P4
        S5 = A11 + A22
        S6 = B11 + B22
        P5 = gemm(S5, S6, new_n, new_k, new_m, level + 1)
        S7 = A12 - A22
        S8 = B21 + B22
        P6 = gemm(S7, S8, new_n, new_k, new_m, level + 1)
        C11 = P5 + P4 - P2 + P6
        S9 = A11 - A21
        S10 = B11 + B12
        P7 = gemm(S9, S10, new_n, new_k, new_m, level + 1)
        C22 = P5 + P1 - P3 - P7
        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    def strassen_merge(A, B, N):
        # Square-matrix variant: below direct_size, batch the 7 products
        # into a single batch_matmul instead of 7 separate GEMMs.
        new_n = N // 2
        A11, A12, A21, A22 = split(A, new_n, new_n)
        B11, B12, B21, B22 = split(B, new_n, new_n)
        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12
        if new_n > direct_size:
            P1 = gemm(A11, S1, new_n, new_n, new_n, 0)
            P2 = gemm(S2, B22, new_n, new_n, new_n, 0)
            P3 = gemm(S3, B11, new_n, new_n, new_n, 0)
            P4 = gemm(A22, S4, new_n, new_n, new_n, 0)
            P5 = gemm(S5, S6, new_n, new_n, new_n, 0)
            P6 = gemm(S7, S8, new_n, new_n, new_n, 0)
            P7 = gemm(S9, S10, new_n, new_n, new_n, 0)
        else:
            # Stack the 7 LHS/RHS operands into [7, new_n, new_n] tensors.
            Merge_A = [relay.expand_dims(a, 0)
                       for a in [A11, S2, S3, A22, S5, S7, S9]]
            Merge_A = relay.concatenate(Merge_A, 0)
            Merge_B = [relay.expand_dims(b, 0)
                       for b in [S1, B22, B11, S4, S6, S8, S10]]
            Merge_B = relay.concatenate(Merge_B, 0)
            # batch_matmul computes A @ B^T, hence the transpose of Merge_B.
            Merge_C = relay.nn.batch_matmul(
                Merge_A, relay.transpose(Merge_B, [0, 2, 1]))
            ss = relay.split(Merge_C, 7)
            P1 = relay.reshape(ss[0], [new_n, new_n])
            P2 = relay.reshape(ss[1], [new_n, new_n])
            P3 = relay.reshape(ss[2], [new_n, new_n])
            P4 = relay.reshape(ss[3], [new_n, new_n])
            P5 = relay.reshape(ss[4], [new_n, new_n])
            P6 = relay.reshape(ss[5], [new_n, new_n])
            P7 = relay.reshape(ss[6], [new_n, new_n])
        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7
        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    A = relay.var("A", shape=(N, K))
    B = relay.var("B", shape=(K, M))
    C = gemm(A, B, N, K, M, 0)
    return A, B, C
```
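To run the Relay version, something like the following should work (a rough
sketch; `create_executor`, the `llvm` target, and float32 inputs are
illustrative choices, and the exact executor API differs a bit across TVM
versions):
```python
import numpy as np
import tvm
from tvm import relay

N = K = M = 1024
A, B, C = strassen_gemm(N, K, M, max_level=1)
mod = tvm.IRModule.from_expr(relay.Function([A, B], C))

a = np.random.rand(N, K).astype("float32")
b = np.random.rand(K, M).astype("float32")
ex = relay.create_executor("graph", mod=mod, target="llvm")
out = ex.evaluate()(a, b)
# Loose tolerance: Strassen changes the floating-point accumulation order.
np.testing.assert_allclose(out.asnumpy(), a @ b, rtol=1e-3)
```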
In the end, the measured performance is not good. Only in a 4-core
`1024*1024*1024` case with `direct_size = 512` did I get better performance
with Strassen than with a direct GEMM.
I think there are several reasons for this:
1. The TE version contains too many stages, which makes it hard to schedule,
even with the auto-scheduling tool Ansor.
2. The Relay version contains some unnatural `slice` and `concat` ops, which
are not friendly to memory access.
3. Ops tend to perform better on GEMMs of larger size. When we split a single
GEMM into 7 sub-matrix GEMMs, these smaller GEMMs are likely to achieve lower
GFLOPS.
4. MNN manages its memory access and compute threads well. It can even run the
7 sub-matrix GEMMs in parallel, while TVM does not support inter-op
parallelism.
5. As for the Strassen algorithm itself, in my understanding it does save
computation in single-threaded execution (in theory it reduces the complexity
from O(n^3) to O(n^log2(7)) ≈ O(n^2.81)), but in a multi-threaded situation I
think it will not be as beneficial (see the sketch below).
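To make point 5 concrete, here is a quick back-of-the-envelope FLOP count
(plain Python, not a benchmark; the 18 half-size additions per level come from
the 10 S-terms plus the 8 add/sub ops that assemble C):
```python
# Rough FLOP model: direct GEMM vs. Strassen at a given recursion depth.
def direct_flops(n):
    return 2 * n ** 3  # n*n outputs, 2n flops each

def strassen_flops(n, levels):
    if levels == 0:
        return direct_flops(n)
    half = n // 2
    # 7 half-size products, plus 18 half-size matrix additions/subtractions.
    return 7 * strassen_flops(half, levels - 1) + 18 * half * half

n = 1024
for lv in range(4):
    print(lv, round(strassen_flops(n, lv) / direct_flops(n), 3))
# Prints roughly 1.0, 0.877, 0.772, 0.683: each level cuts the
# multiplication work by about 7/8, which is where O(n^log2(7)) comes
# from, but the extra additions are pure memory traffic, and that cost
# hurts more as the core count grows.
```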
So my conclusions are:
1. Strassen should be more effective with few CPU cores, e.g. on an ARM CPU
with only 4 or 8 cores, which is exactly the target hardware of MNN. On an
Intel CPU with more cores, I don't think we can benefit from Strassen.
2. MNN does have better memory/thread management, since it is hand-written
C++. TVM does not seem able to do the same thing with codegen.