[Mlir-commits] [mlir] [mlir][scf] Considering affine.apply when fusing scf::ParallelOp (PR #80145)
Ivan Butygin
llvmlistbot at llvm.org
Wed Jan 31 09:09:53 PST 2024
================
@@ -480,3 +480,49 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+func.func @fuse_same_indices_by_affine_apply(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %sum = memref.alloc() : memref<2x3xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
+ memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
+ %sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x3xf32>
+ return
+}
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fuse_same_indices_by_affine_apply
+// CHECK-SAME: (%[[ARG0:.+]]: memref<2x2xf32>, %[[ARG1:.+]]: memref<2x2xf32>) {
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
----------------
Hardcode84 wrote:
nit: SSA names are not guaranteed to be numbers (an op can override its asm interface, e.g. `OpAsmOpInterface`, to generate arbitrary names), so it's better to use `.*`.
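The nit can be illustrated with a small sketch. It assumes a digits-only capture such as `[0-9]+` as the overly strict pattern; the exact pattern the comment targets is not visible in the excerpt. It also uses Python's `re`, which only approximates FileCheck's regex dialect. `arith.constant` suggests result names like `%c0` through `OpAsmOpInterface`, so a numeric-only capture misses it while `.*` still binds the name. `NUMERIC_ONLY`, `ANY_NAME`, and `capture` are hypothetical helpers for illustration.

```python
import re

# Hypothetical patterns mimicking a FileCheck capture like %[[C0:...]].
# Python's `re` is only an approximation of FileCheck's regex dialect.
NUMERIC_ONLY = r"%(?P<C0>[0-9]+) = arith\.constant 0 : index"
ANY_NAME = r"%(?P<C0>.*) = arith\.constant 0 : index"

# arith.constant suggests a readable name (%c0) via OpAsmOpInterface,
# so the printed IR does not use a purely numeric SSA name here.
printed_line = "%c0 = arith.constant 0 : index"


def capture(pattern: str, line: str):
    """Return the captured SSA name, or None if the pattern does not match."""
    m = re.search(pattern, line)
    return m.group("C0") if m else None


print(capture(NUMERIC_ONLY, printed_line))  # None: 'c0' is not all digits
print(capture(ANY_NAME, printed_line))      # c0
```

The `.*` capture still works when the printer falls back to a numeric name such as `%0`, which is why it is the safer default.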
https://github.com/llvm/llvm-project/pull/80145
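The new test `fuse_same_indices_by_affine_apply` covers the case named in the PR title. The first loop stores to `%sum[%i, %1]` and the second loads from `%sum[%i, %1]`, and in each loop `%1` comes from that loop's own `affine.apply` with the same map. Fusion is safe only if both accesses hit the same element in every iteration. The sketch below models this under a purely syntactic assumption: two indices match if they are corresponding induction variables, or if they are `affine.apply` results with identical map text and pairwise-matching operands. It is not the implementation from PR #80145. `IV`, `AffineApply`, `equivalent`, and `same_access` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class IV:
    """Induction variable of an scf.parallel loop, identified by loop and position."""
    loop: str
    dim: int


@dataclass(frozen=True)
class AffineApply:
    """Result of an affine.apply: the map (as text) plus its operands."""
    affine_map: str
    operands: Tuple["Index", ...]


Index = Union[IV, AffineApply]


def equivalent(a: Index, b: Index, iv_map: dict) -> bool:
    """True if index `b` (second loop) provably equals index `a` (first loop)."""
    if isinstance(b, IV):
        # Map the second loop's IV onto the corresponding first-loop IV.
        return iv_map.get(b) == a
    if isinstance(a, AffineApply) and isinstance(b, AffineApply):
        # Same map and pairwise-equivalent operands => same value.
        return (a.affine_map == b.affine_map
                and len(a.operands) == len(b.operands)
                and all(equivalent(x, y, iv_map)
                        for x, y in zip(a.operands, b.operands)))
    return False


def same_access(store_idx, load_idx, iv_map) -> bool:
    """True if a store and a load touch the same element in each iteration."""
    return len(store_idx) == len(load_idx) and all(
        equivalent(s, l, iv_map) for s, l in zip(store_idx, load_idx))


# Model of the test: both loops use (%i, %j) with identical bounds and steps.
i1, j1 = IV("first", 0), IV("first", 1)
i2, j2 = IV("second", 0), IV("second", 1)
iv_map = {i2: i1, j2: j1}
SUM_MAP = "(d0, d1) -> (d0 + d1)"

store = (i1, AffineApply(SUM_MAP, (i1, j1)))   # %sum[%i, %1] in loop 1
load = (i2, AffineApply(SUM_MAP, (i2, j2)))    # %sum[%i, %1] in loop 2
swapped = (i2, AffineApply(SUM_MAP, (j2, i2)))  # operands reordered

print(same_access(store, load, iv_map))     # True
print(same_access(store, swapped, iv_map))  # False
```

The check is deliberately conservative. Reordered operands are rejected even though `d0 + d1` is commutative, so the sketch never claims equality it cannot see structurally.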