[Mlir-commits] [mlir] [mlir][scf] Relax requirements for loops fusion (PR #79187)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 05:54:39 PST 2024
================
@@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}
-
// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0fp = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK: [[C0:%.*]] = arith.constant 0 : index
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[C2:%.*]] = arith.constant 2 : index
+// CHECK: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0fp = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK: [[C0:%.*]] = arith.constant 0 : index
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[C2:%.*]] = arith.constant 2 : index
+// CHECK: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
----------------
fabrizio-indirli wrote:
**✓** Done
https://github.com/llvm/llvm-project/pull/79187
More information about the Mlir-commits mailing list