Issue: 114855
Summary: [mlir][sparse] mlir-opt crash when lowering softmax with sparse tensors
Labels: mlir
Assignees:
Reporter: vmiheer
Here is example MLIR that performs softmax on sparse tensors. The softmax expansion itself was produced by the softmax decomposition in (upstream) MLIR.
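For reference, the pre-decomposition form was presumably a single `linalg.softmax` op along the lines of the sketch below (hypothetical: the function name and destination operand are invented here; the softmax dimension is inferred from the `reduction` iterator on `d1` in the decomposed IR):

```mlir
// Hypothetical pre-decomposition form, inferred from the generics below:
// the reduction runs over d1, so softmax is taken along dimension 1.
#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
func.func @softmax_src(%arg0: tensor<?x?x?xf32, #sparse>,
                       %dst: tensor<?x?x?xf32, #sparse>) -> tensor<?x?x?xf32, #sparse> {
  %0 = linalg.softmax dimension(1)
         ins(%arg0 : tensor<?x?x?xf32, #sparse>)
         outs(%dst : tensor<?x?x?xf32, #sparse>) -> tensor<?x?x?xf32, #sparse>
  return %0 : tensor<?x?x?xf32, #sparse>
}
```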
<details>
<summary>
input.mlir
</summary>

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
module {
  func.func @softmax(%arg0: tensor<?x?x?xf32, #sparse>, %arg1: !llvm.ptr) -> tensor<?x?x?xf32, #sparse> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c1_i8 = arith.constant 1 : i8
    %c2 = arith.constant 2 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #sparse>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #sparse>
    %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #sparse>
    %0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32, #sparse>
    %c0_2 = arith.constant 0 : index
    %dim_3 = tensor.dim %arg0, %c0_2 : tensor<?x?x?xf32, #sparse>
    %c1_4 = arith.constant 1 : index
    %dim_5 = tensor.dim %arg0, %c1_4 : tensor<?x?x?xf32, #sparse>
    %c2_6 = arith.constant 2 : index
    %dim_7 = tensor.dim %arg0, %c2_6 : tensor<?x?x?xf32, #sparse>
    %1 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
    %cst_8 = arith.constant -3.40282347E+38 : f32
    %2 = linalg.fill ins(%cst_8 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #sparse>) outs(%2 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.maxnumf %in, %out : f32
      linalg.yield %8 : f32
    } -> tensor<?x?xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %3 : tensor<?x?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32, #sparse>) {
    ^bb0(%in: f32, %in_10: f32, %out: f32):
      %8 = arith.subf %in, %in_10 : f32
      %9 = math.exp %8 : f32
      linalg.yield %9 : f32
    } -> tensor<?x?x?xf32, #sparse>
    %cst_9 = arith.constant 0.000000e+00 : f32
    %5 = linalg.fill ins(%cst_9 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
    %6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%4 : tensor<?x?x?xf32, #sparse>) outs(%5 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.addf %in, %out : f32
      linalg.yield %8 : f32
    } -> tensor<?x?xf32>
    %7 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %6 : tensor<?x?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32, #sparse>) {
    ^bb0(%in: f32, %in_10: f32, %out: f32):
      %8 = arith.divf %in, %in_10 : f32
      linalg.yield %8 : f32
    } -> tensor<?x?x?xf32, #sparse>
    return %7 : tensor<?x?x?xf32, #sparse>
  }
}
```
</details>

Command line: `mlir-opt --sparsifier input.mlir`
Git SHA: 33363521ca24f912cc25530f6cecbca53acce8a3
Discourse discussion: https://discourse.llvm.org/t/sparsifier-crash-while-lowering-softmax/82721
Quick reproduction using Compiler Explorer: https://godbolt.org/z/G845EEjMo

Possible resolutions:
1. Have the sparsifier fail gracefully, with a proper diagnostic, when the input uses features it does not support, instead of crashing.
2. Rewrite the kernel with explicit `sparse_tensor.reduce` and `sparse_tensor.unary` ops. One possible lowering:
<details>
<summary>softmax_sparse</summary>

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
#dense = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : dense) }>
#csr = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {
  func.func @softmax(%arg0: tensor<?x?x?xf32, #csrv>, %arg1: !llvm.ptr) -> tensor<?x?x?xf32, #csrv> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c1_i8 = arith.constant 1 : i8
    %c2 = arith.constant 2 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #csrv>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #csrv>
    %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #csrv>
    %0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32, #csrv>
    %c0_2 = arith.constant 0 : index
    %dim_3 = tensor.dim %arg0, %c0_2 : tensor<?x?x?xf32, #csrv>
    %c1_4 = arith.constant 1 : index
    %dim_5 = tensor.dim %arg0, %c1_4 : tensor<?x?x?xf32, #csrv>
    %c2_6 = arith.constant 2 : index
    %dim_7 = tensor.dim %arg0, %c2_6 : tensor<?x?x?xf32, #csrv>
    %11 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
    %minus_inf = arith.constant -3.40282347E+38 : f32
    %21 = linalg.fill ins(%minus_inf : f32) outs(%11 : tensor<?x?xf32>) -> tensor<?x?xf32>
    // Row-wise max, expressed as an explicit sparse reduction with identity.
    %31 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%21 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %res = sparse_tensor.reduce %in, %out, %minus_inf : f32 {
        ^bb0(%x0: f32, %x1: f32):
          %00 = arith.maxnumf %x0, %x1 : f32
          sparse_tensor.yield %00 : f32
      }
      linalg.yield %res : f32
    } -> tensor<?x?xf32>
    // exp(x - max), applied only to stored entries (empty absent region).
    %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%0 : tensor<?x?x?xf32, #csrv>) {
    ^bb0(%in: f32, %out: f32):
      %x = linalg.index 0 : index
      %y = linalg.index 1 : index
      %z = linalg.index 2 : index
      %result = sparse_tensor.unary %in : f32 to f32
        present={
        ^bb0(%in1: f32):
          %maxel = tensor.extract %31[%x, %z] : tensor<?x?xf32>
          %8 = arith.subf %in1, %maxel : f32
          %ret = math.exp %8 : f32
          sparse_tensor.yield %ret : f32
        }
        absent={}
      linalg.yield %result : f32
    } -> tensor<?x?x?xf32, #csrv>
    %1 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
    %cst_8 = arith.constant 0.0 : f32
    %2 = linalg.fill ins(%cst_8 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
    // Row-wise sum of the exponentials (note: reduces %3, not %arg0).
    %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%3 : tensor<?x?x?xf32, #csrv>) outs(%2 : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %res = sparse_tensor.reduce %in, %out, %cst_8 : f32 {
        ^bb0(%x0: f32, %x1: f32):
          %00 = arith.addf %x0, %x1 : f32
          sparse_tensor.yield %00 : f32
      }
      linalg.yield %res : f32
    } -> tensor<?x?xf32>
    // Normalize each stored entry by its row sum.
    %5 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%3 : tensor<?x?x?xf32, #csrv>) {
    ^bb0(%in: f32):
      %x = linalg.index 0 : index
      %z = linalg.index 2 : index
      %result = sparse_tensor.unary %in : f32 to f32
        present={
        ^bb0(%in1: f32):
          %denom = tensor.extract %4[%x, %z] : tensor<?x?xf32>
          %ret = arith.divf %in1, %denom : f32
          sparse_tensor.yield %ret : f32
        }
        absent={}
      linalg.yield %result : f32
    } -> tensor<?x?x?xf32, #csrv>
    return %5 : tensor<?x?x?xf32, #csrv>
  }
}
```
</details>
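The proposed lowering leans on `sparse_tensor.reduce`, which carries an explicit identity value for the reduction, and `sparse_tensor.unary`, which splits the computation into one region for stored entries and one for implicit zeros. A minimal doc-style fragment of the unary op for illustration only (in real code it must appear inside a `linalg.generic` body, as above):

```mlir
// Fragment, for illustration: negate stored entries; the empty absent
// region yields no value for implicit zeros, keeping the result sparse.
%result = sparse_tensor.unary %a : f32 to f32
  present={
  ^bb0(%x: f32):
    %0 = arith.negf %x : f32
    sparse_tensor.yield %0 : f32
  }
  absent={}
```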