<table border="1" cellspacing="0" cellpadding="8">
<tr>
<th>Issue</th>
<td>
<a href=https://github.com/llvm/llvm-project/issues/61820>61820</a>
</td>
</tr>
<tr>
<th>Summary</th>
<td>
mlir/Affine: sibling loop fusion with `fusion-maximal` enabled leads to an invalid transformation
</td>
</tr>
<tr>
<th>Labels</th>
<td>
new issue
</td>
</tr>
<tr>
<th>Assignees</th>
<td>
</td>
</tr>
<tr>
<th>Reporter</th>
<td>
rohany
</td>
</tr>
</table>
<pre>
The following MLIR file:
```
func.func @f(%input : memref<10xf32>, %output : memref<10xf32>, %reduc : memref<10xf32>) {
  %zero = arith.constant 0. : f32
  %one = arith.constant 1. : f32
  affine.for %i = 0 to 10 {
    %0 = affine.load %input[%i] : memref<10xf32>
    %2 = arith.addf %0, %one : f32
    affine.store %2, %output[%i] : memref<10xf32>
  }
  affine.for %i = 0 to 10 {
    %0 = affine.load %input[%i] : memref<10xf32>
    %1 = affine.load %reduc[0] : memref<10xf32>
    %2 = arith.addf %0, %1 : f32
    affine.store %2, %reduc[0] : memref<10xf32>
  }
  return
}
```
When run with `bin/mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=sibling fusion-maximal}))'`, it produces:
```
module {
  func.func @f(%arg0: memref<10xf32>, %arg1: memref<10xf32>, %arg2: memref<10xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    affine.for %arg3 = 0 to 10 {
      %0 = affine.load %arg0[%arg3] : memref<10xf32>
      %1 = arith.addf %0, %cst_0 : f32
      affine.store %1, %arg1[%arg3] : memref<10xf32>
      affine.for %arg4 = 0 to 10 {
        %2 = affine.load %arg0[%arg4] : memref<10xf32>
        %3 = affine.load %arg2[0] : memref<10xf32>
        %4 = arith.addf %2, %3 : f32
        affine.store %4, %arg2[0] : memref<10xf32>
      }
    }
    return
  }
}
```
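This looks incorrect to me. The entire second loop has been nested inside the first instead of being merged with it, so the reduction into `%arg2[0]` now runs 100 times (10 outer × 10 inner iterations) instead of 10. Without `fusion-maximal`, the correct fusion is applied: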
```
module {
  func.func @f(%arg0: memref<10xf32>, %arg1: memref<10xf32>, %arg2: memref<10xf32>) {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    affine.for %arg3 = 0 to 10 {
      %0 = affine.load %arg0[%arg3] : memref<10xf32>
      %1 = arith.addf %0, %cst_0 : f32
      affine.store %1, %arg1[%arg3] : memref<10xf32>
      %2 = affine.load %arg0[%arg3] : memref<10xf32>
      %3 = affine.load %arg2[0] : memref<10xf32>
      %4 = arith.addf %2, %3 : f32
      affine.store %4, %arg2[0] : memref<10xf32>
    }
    return
  }
}
```
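To make the miscompile concrete, here is a rough Python model of the three loop nests: the input IR, the `fusion-maximal` output, and the default sibling-fusion output. It counts the stores to `%reduc[0]` (`%arg2[0]`). This is only a sketch of the loop semantics, not MLIR. It assumes `%reduc` is zero-initialized by the caller and uses small integer-valued inputs so every float sum is exact. The function names are hypothetical.

```python
# Sketch: model each loop nest from the report with Python lists.
# Assumptions: %reduc is zero-initialized by the caller (the IR itself does
# not initialize it), and inputs are small integers so float sums are exact.

N = 10


def original(inp):
    """The input IR: two separate loops."""
    out, reduc, stores = [0.0] * N, [0.0] * N, 0
    for i in range(N):
        out[i] = inp[i] + 1.0
    for i in range(N):
        reduc[0] = inp[i] + reduc[0]
        stores += 1
    return out, reduc[0], stores


def maximal_fused(inp):
    """Output with fusion-maximal: the whole second loop nested in the first."""
    out, reduc, stores = [0.0] * N, [0.0] * N, 0
    for i in range(N):
        out[i] = inp[i] + 1.0
        for j in range(N):
            reduc[0] = inp[j] + reduc[0]
            stores += 1
    return out, reduc[0], stores


def default_fused(inp):
    """Output without fusion-maximal: both bodies share one loop."""
    out, reduc, stores = [0.0] * N, [0.0] * N, 0
    for i in range(N):
        out[i] = inp[i] + 1.0
        reduc[0] = inp[i] + reduc[0]
        stores += 1
    return out, reduc[0], stores


if __name__ == "__main__":
    inp = [float(k) for k in range(N)]
    for name, fn in [("original", original),
                     ("maximal_fused", maximal_fused),
                     ("default_fused", default_fused)]:
        _, total, stores = fn(inp)
        print(f"{name}: reduc[0] = {total}, stores = {stores}")
    # original: reduc[0] = 45.0, stores = 10
    # maximal_fused: reduc[0] = 450.0, stores = 100
    # default_fused: reduc[0] = 45.0, stores = 10
```

In this model the `%output` values agree in all three versions. Only the reduction differs: the maximal-fusion result is ten times the original sum. The default fusion preserves both the value and the order of the updates.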
</pre>