<table border="1" cellspacing="0" cellpadding="8">
<tr>
<th>Issue</th>
<td>
<a href="https://github.com/llvm/llvm-project/issues/95230">95230</a>
</td>
</tr>
<tr>
<th>Summary</th>
<td>
Unexpected Behavior on Affine-Loop-Fusion
</td>
</tr>
<tr>
<th>Labels</th>
<td>
new issue
</td>
</tr>
<tr>
<th>Assignees</th>
<td>
</td>
</tr>
<tr>
<th>Reporter</th>
<td>
sgjzfzzf
</td>
</tr>
</table>
<pre>
Hi, I'm developing a compiler for ONNX based on MLIR, and I'm trying to optimize the generated code with the Affine passes, but I need help: an error occurs in the LayerNormalization operator. Here is the code my compiler generated automatically.
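For context, this IR computes the standard LayerNormalization over the last axis (here `d = 768`, with `eps = 9.99999996E-13`, and the all-ones/all-zeros constants standing in for gamma and beta):
```math
\mu = \frac{1}{d} \sum_{i=1}^{d} x_i, \qquad \sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2, \qquad y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma_i + \beta_i
```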
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %cst = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %c0_1 = arith.constant 0 : index
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant 1.000000e+00 : f32
    %view_4 = memref.view %view[%c0_1][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst : memref<768xf32>
    %1 = bufferization.to_memref %cst_0 : memref<768xf32>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.addf %in, %out : f32
      linalg.yield %2 : f32
    }
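    // Divide each per-row sum by 768 to obtain the mean.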
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
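    // Center the input: %arg0 = %arg1 - mean.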
    linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %view_4 : memref<1x128x768xf32>, memref<1x128x1xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %out: f32):
      %2 = arith.subf %in, %in_5 : f32
      linalg.yield %2 : f32
    }
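    // Reset the accumulator, then reduce the sum of squares into %view_4.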
    linalg.fill ins(%cst_2 : f32) outs(%view_4 : memref<1x128x1xf32>)
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.mulf %in, %in : f32
      %3 = arith.addf %out, %2 : f32
      linalg.yield %3 : f32
    }
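    // Divide by 768 to obtain the variance.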
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
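    // Normalize, scale, and shift: (x - mean) / sqrt(var + eps) * gamma + beta.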
    linalg.generic {indexing_maps = [#map, #map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %view_4, %0, %1 : memref<1x128x768xf32>, memref<1x128x1xf32>, memref<768xf32>, memref<768xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %in_6: f32, %in_7: f32, %out: f32):
      %cst_8 = arith.constant 9.99999996E-13 : f32
      %2 = arith.addf %in_5, %cst_8 : f32
      %3 = math.sqrt %2 : f32
      %4 = arith.divf %cst_3, %3 : f32
      %5 = arith.mulf %in, %4 : f32
      %6 = arith.mulf %5, %in_6 : f32
      %7 = arith.addf %6, %in_7 : f32
      linalg.yield %7 : f32
    }
    return
  }
}
```
Then, I use `mlir-opt-18 -convert-linalg-to-affine-loops <filename>` to lower it to the Affine dialect.
```mlir
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_0 = arith.constant 7.680000e+02 : f32
    %cst_1 = arith.constant 1.000000e+00 : f32
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_4 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %view_5 = memref.view %view[%c0][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_4 : memref<768xf32>
    %1 = bufferization.to_memref %cst_3 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.addf %2, %3 : f32
          affine.store %4, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.subf %2, %3 : f32
          affine.store %4, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          affine.store %cst_2, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.mulf %2, %2 : f32
          %5 = arith.addf %3, %4 : f32
          affine.store %5, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
          %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
          affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    return
  }
}
```
After that, I decided to use the Affine-Loop-Fusion pass to optimize it with `mlir-opt-18 -convert-linalg-to-affine-loops -affine-loop-fusion=mode=greedy <filename>`.
```mlir
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c0_0 = arith.constant 0 : index
    %c0_1 = arith.constant 0 : index
    %c0_2 = arith.constant 0 : index
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c0_5 = arith.constant 0 : index
    %c0_6 = arith.constant 0 : index
    %c0_7 = arith.constant 0 : index
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_8 = arith.constant 7.680000e+02 : f32
    %cst_9 = arith.constant 1.000000e+00 : f32
    %cst_10 = arith.constant 0.000000e+00 : f32
    %cst_11 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_12 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0_13 = arith.constant 0 : index
    %view = memref.view %arg2[%c0_13][] : memref<512xi8> to memref<512xi8>
    %view_14 = memref.view %view[%c0_13][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_12 : memref<768xf32>
    %1 = bufferization.to_memref %cst_11 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.addf %4, %5 : f32
          affine.store %6, %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        %2 = affine.load %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        %3 = arith.divf %2, %cst_8 : f32
        affine.store %3, %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_2, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.subf %4, %5 : f32
          affine.store %6, %arg0[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
        }
        affine.store %cst_10, %view_14[%c0_4, %arg4, %c0_3] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg0[%c0_5, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.mulf %4, %4 : f32
          %7 = arith.addf %5, %6 : f32
          affine.store %7, %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        affine.for %arg5 = 0 to 768 {
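          // Reported issue: this once-per-row division has been fused into the 768-trip loop.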
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
          affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
          %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
          %10 = arith.addf %7, %cst : f32
          %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
          %15 = arith.addf %14, %9 : f32
          affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    return
  }
}
```
Please take a look at the last `affine.for` loop. The before/after comparison follows:
```mlir
    // ...
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
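          // Before fusion: the division executes once per (%arg3, %arg4) pair.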
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
          %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
          affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        // ...
        affine.for %arg5 = 0 to 768 {
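          // After fusion: the division now executes on every one of the 768 iterations.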
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
          affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
          %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
          %10 = arith.addf %7, %cst : f32
          %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
          %15 = arith.addf %14, %9 : f32
          affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
```
The pass fuses the `divf` instruction into the wrong loop. In the 128-iteration loop, the `divf` should be executed only once per iteration of `%arg4`, but because of the incorrect fusion it is now executed 768 times, repeatedly dividing the same memory location, as the sketch below illustrates. I also checked this on a real example, and the output changes after the pass runs.
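To make the semantic difference concrete, here is a minimal pseudo-C sketch (the `mean` values and variable names are illustrative, not taken from the IR) of what happens to a single row's mean before and after fusion:
```c
#include <stdio.h>

int main(void) {
    float before = 768.0f, after = 768.0f;

    /* Before fusion: the division runs once per row (%arg4). */
    before = before / 768.0f;      /* result: 1.0 */

    /* After fusion: the same read-modify-write division sits inside
       the 768-trip loop, so it runs 768 times per row. */
    for (int j = 0; j < 768; ++j)
        after = after / 768.0f;    /* divided 768 times, underflows toward 0 */

    printf("before = %g, after = %g\n", before, after);
    return 0;
}
```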
Could you please provide me with some information on this issue? Is it a bug, or is something wrong with my code or my optimization pipeline? Thank you for reading!
</pre>