<table border="1" cellspacing="0" cellpadding="8">
<tr>
<th>Issue</th>
<td>
<a href="https://github.com/llvm/llvm-project/issues/114855">114855</a>
</td>
</tr>
<tr>
<th>Summary</th>
<td>
[mlir][sparse] mlir-opt crash when lowering softmax with sparse tensors
</td>
</tr>
<tr>
<th>Labels</th>
<td>
mlir
</td>
</tr>
<tr>
<th>Assignees</th>
<td>
</td>
</tr>
<tr>
<th>Reporter</th>
<td>
vmiheer
</td>
</tr>
</table>
<pre>
Here's example MLIR performing softmax on sparse tensors. The softmax expansion itself is performed by the softmax decomposition in (upstream) MLIR.
<details>
<summary>
input.mlir
</summary>
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
module {
func.func @softmax(%arg0: tensor<?x?x?xf32, #sparse>, %arg1: !llvm.ptr) -> tensor<?x?x?xf32, #sparse> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1_i8 = arith.constant 1 : i8
%c2 = arith.constant 2 : index
%cst = arith.constant 0.000000e+00 : f32
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #sparse>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #sparse>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #sparse>
%0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32, #sparse>
%c0_2 = arith.constant 0 : index
%dim_3 = tensor.dim %arg0, %c0_2 : tensor<?x?x?xf32, #sparse>
%c1_4 = arith.constant 1 : index
%dim_5 = tensor.dim %arg0, %c1_4 : tensor<?x?x?xf32, #sparse>
%c2_6 = arith.constant 2 : index
%dim_7 = tensor.dim %arg0, %c2_6 : tensor<?x?x?xf32, #sparse>
%1 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
%cst_8 = arith.constant -3.40282347E+38 : f32
%2 = linalg.fill ins(%cst_8 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #sparse>) outs(%2 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%8 = arith.maxnumf %in, %out : f32
linalg.yield %8 : f32
} -> tensor<?x?xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %3 : tensor<?x?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32, #sparse>) {
^bb0(%in: f32, %in_10: f32, %out: f32):
%8 = arith.subf %in, %in_10 : f32
%9 = math.exp %8 : f32
linalg.yield %9 : f32
} -> tensor<?x?x?xf32, #sparse>
%cst_9 = arith.constant 0.000000e+00 : f32
%5 = linalg.fill ins(%cst_9 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%4 : tensor<?x?x?xf32, #sparse>) outs(%5 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%8 = arith.addf %in, %out : f32
linalg.yield %8 : f32
} -> tensor<?x?xf32>
%7 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %6 : tensor<?x?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32, #sparse>) {
^bb0(%in: f32, %in_10: f32, %out: f32):
%8 = arith.divf %in, %in_10 : f32
linalg.yield %8 : f32
} -> tensor<?x?x?xf32, #sparse>
return %7 : tensor<?x?x?xf32, #sparse>
}
}
```
</details>
Commandline: `mlir-opt --sparsifier input.mlir`
Git sha: 33363521ca24f912cc25530f6cecbca53acce8a3
Discourse discussion: https://discourse.llvm.org/t/sparsifier-crash-while-lowering-softmax/82721
Quick reproduction using Compiler Explorer: https://godbolt.org/z/G845EEjMo
Possible resolutions:
1. Make the sparsifier report a failure for inputs that use unsupported features.
2. One possible lowering:
<details>
<summary>softmax_sparse</summary>
```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
#dense = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : dense) }>
#csr = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {
func.func @softmax(%arg0: tensor<?x?x?xf32, #csrv>, %arg1: !llvm.ptr)
-> tensor<?x?x?xf32, #csrv>
// -> tensor<?x?xf32>
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1_i8 = arith.constant 1 : i8
%c2 = arith.constant 2 : index
%cst = arith.constant 0.000000e+00 : f32
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #csrv>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #csrv>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #csrv>
%0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32, #csrv>
%c0_2 = arith.constant 0 : index
%dim_3 = tensor.dim %arg0, %c0_2 : tensor<?x?x?xf32, #csrv>
%c1_4 = arith.constant 1 : index
%dim_5 = tensor.dim %arg0, %c1_4 : tensor<?x?x?xf32, #csrv>
%c2_6 = arith.constant 2 : index
%dim_7 = tensor.dim %arg0, %c2_6 : tensor<?x?x?xf32, #csrv>
%11 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
%minus_inf = arith.constant -3.40282347E+38 : f32
%21 = linalg.fill ins(%minus_inf : f32) outs(%11 : tensor<?x?xf32>) -> tensor<?x?xf32>
%31 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%21 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%res = sparse_tensor.reduce %in, %out, %minus_inf : f32 {
^bb0(%x0: f32, %x1: f32):
%00 = arith.maxnumf %x0, %x1 : f32
sparse_tensor.yield %00: f32
}
linalg.yield %res : f32
} -> tensor<?x?xf32>
%3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%arg0 : tensor<?x?x?xf32, #csrv>) {
^bb0(%in: f32, %out: f32):
%x = linalg.index 0: index
%y = linalg.index 1: index
%z = linalg.index 2: index
%result = sparse_tensor.unary %in : f32 to f32
present={
^bb0(%in1: f32):
%maxel = tensor.extract %31[%x, %z]: tensor<?x?xf32>
%8 = arith.subf %in1, %maxel : f32
%ret = math.exp %8 : f32
sparse_tensor.yield %ret : f32
}
absent={}
linalg.yield %result : f32
} -> tensor<?x?x?xf32, #csrv>
%1 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
%cst_8 = arith.constant 0. : f32
%2 = linalg.fill ins(%cst_8 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%2 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%res = sparse_tensor.reduce %in, %out, %cst_8 : f32 {
^bb0(%x0: f32, %x1: f32):
%00 = arith.addf %x0, %x1 : f32
sparse_tensor.yield %00: f32
}
linalg.yield %res : f32
} -> tensor<?x?xf32>
%5 = linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel"]}
outs(%3: tensor<?x?x?xf32, #csrv>) {
^bb0(%in: f32):
%x = linalg.index 0: index
%z = linalg.index 2: index
%result = sparse_tensor.unary %in : f32 to f32
present={
^bb0(%in1: f32):
%denom = tensor.extract %4[%x, %z]: tensor<?x?xf32>
%ret = arith.divf %in1, %denom : f32
sparse_tensor.yield %ret : f32
}
absent={}
linalg.yield %result : f32
} -> tensor<?x?x?xf32, #csrv>
// return %3: tensor<?x?x?xf32, #csrv>
return %5:tensor<?x?x?xf32, #csrv>
}
}
```
</details>
</pre>
<img width="1px" height="1px" alt="" src="http://email.email.llvm.org/o/eJzUWk2T4jwO_jXm4oJKbNIJBw7dQ_PuZWt3q_ZOmUQBv5vYKdvphvn1W3Y-CBC--2uoGZoEWZKlx1L8GKY1XwmAKQpeUDAbsNKspZq-5XwNoAZLmWyn_wAFiIQamzVg2LC8yADnGVe4AJVKlXOxwlqmJmcbLAXWBVMasAGhpdIj_N81tF_DpmBCcykwNxqyFHPdaIEEL7etYAKxzAupuXHCAiMSlYU2CliOyMTZHyFvhrxnRH8lYBjPNKKv7S1d5jlT2_YWF0VpRnZYK4LI_FCqfn_yqn8dcUJzVmBEZ5ilKRewyFnhdESJh8gvnPjunVjnhoi-4r5vdnacOv8-fYea6oBbXe3Vogr_CEQsEy5WVnX4gps5XDCDEX3GCQgNlYy7tglRoDUk1ZCu0ASjcNb6hHEukzIDjMKX5g7GaSnikX3DaOzVeUYkQiRgauVZZZXP1lU63zT_U0qsvXZm1oq7tsN8OwwRP8ve8lFhVDuNa1Xtu4it2tir0qK4WY9iKbRhwuAqJlwksKnlrajfJ-ofibaq_QWPzgyJOqpJnxw5rVqbXrdHnnsBIi9eNQcbhf2xCc_d2Boz7rLOShXpuBp5dX7aWSQ8X3gXdPsP6PYv6Cb36G7Csuc45IXZVnBNeF7rd9PrfPbdSrhrNrG36M34Me46WVvQS3lbPBSB2F-Mr0N47U9wKddO3_3-kMXTbcvCOhVeAolTerdT_mmYLGgHHOFJcFgbx3PVZtFbK4Z0NPZIROg4fEXkhUYHqxqRoEJSxgXLVqOUZxnmQldeNWqrEWSCZWnqr04sxNq7k5W113va9WAFAhSPba11OeJiZdudrlpR8FJ1wzrMti-iYGavuAHFjFQLsy2gI00KpliWQYZInRyiIClj-6ywu9UVCmYonO2CYLN_W8r3AnViTe0CddhUgtfl0qvGctHG3iFDlmaXDESfu-NcJLsYyNlGlHmKKz2tggMA1FHfcsiSRsNx2Q9nt-Rz_Fg-24-3Z_b8nZ7E1nGhNyb417mE7nJ_O252YDiHAy4WvvcANHS53MeFU9iXeTd04obmzKxHsClOogQfgWlyG5iurOvaLCZ3P74E56vd5POr3dOPrnY3tty9KAXfVepYklyqcycQ-hHlLvwzyt24js2Nzy9fUesuouHRgpfwt6sL3oMQuaKIKTClEg107nmWtBvo-qL9eEBHdJmLI8rDvf-Sec5EknEBbmdcURhDWRg8HDqjPOWgcIcNadT-xQ3Wa2aHUUqfaED8mJFxOvFJHJMgoF76FEO8jFlAWRxDxGg1cMZ1LEulASdcx6XWti7RZ7w2ptA2i2Ru_W2kRm6vLtUKkblBZL7zahgrptfD9zXPYJjJd1BcrIYtUzCPSEj8yuZ_Sh7_DysolKwrIS41Fyv8S-YFz0Dh102RSQXq2JOVTJYyM7ULvxGZ_xWNg9fXv_8pu6H8t9SaLzPACrTMSmtDt5j0R_g5SXDKeFYqwLxmvargplI5rixmGrAuIObp1vqWAjOlAo3f1zxeY6YAC2mwLotCKgNJzWeREf6XAFw09ptItMYvEV51xBYNyHqJrgNM_WB6K9bq7QeRW4hQd_9-ly46c8purNUnWu2G4Ev4PJvYS2zeztw1pbnRuCuqbsWf7fxte-lU4h9ICe4okD-OFzzKCv5QanBf_UcSg_2Ofx4t2G_vW5nBEy59Ey94wpvvZAX7XfI_iRTMuSj1gov0HmKwUUL80xvmroG-TfOHcoT-j942X0sStm1sjyK8GKiP3DireuL7TwVuunC4ia4_HSX60KFDpzYHm7WNf8anplR6_eTlxmuVHGB0fwbths3z-oms3Z6pf6NXBeZRNuAhMvsLGM-7UXrH4COU3A9cF9tNN7YuoNjrLdu1_PZY3j8n__tYnpyTV6DLzPSsplIwta0WU7tkjOwgyz48
gzCIzi6sJC6uWDo520C210Q2RrHYVIXTQSfY1GH-bUFxXQ9p9fdzx35THWrrPWuujZO5SCGfWMzV0F7VRwsaY7bsRLX9umehV2m7m9Y5eoz86nM9b3SqX3_7ad6Dpz8_ulH_8D69l-FP79EN6_6HNegA3wHQCpVHr8MgfWzjPtS-AyL9nD48cZzKfX33S9tn-3q0j9qyDELmJ5rn-PHe2fS-w4OIpns25k90z09si6dX2Yd2xzoQjlvbnXzcguDeo5MA0eebGb5who_PTc6fmAySKU0mdMIGMPVD6oXjyA-8wXq6XAZ-SknixwkZMwrROGJpGIXxJI2eooQN-JR4ZOz73tifEEK9Ufy0DCNvEtCJH5I4ekJjD3LGs_aUY8C1LmHq--MoCAYZW0Kmp1UBcccvrjAM1NTKD5flSqOxl3Ft9E6D4SZzv9B1A4IZCl5qaj-Y4fZ8xx2e4Pc1iPbEoP017Ts364Pf5Q5KlU0PDka4WZfLUSxzRObWev1nWCj5N8QGkbmbjEZkXs_nbUr-HwAA__98YwFX">