[Mlir-commits] [mlir] d17b005 - [mlir][scf] Relax requirements for loops fusion (#79187)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 09:17:09 PST 2024
Author: fabrizio-indirli
Date: 2024-01-30T20:17:06+03:00
New Revision: d17b005e46e240b2c95801a14e2c6fc5baa5b3f7
URL: https://github.com/llvm/llvm-project/commit/d17b005e46e240b2c95801a14e2c6fc5baa5b3f7
DIFF: https://github.com/llvm/llvm-project/commit/d17b005e46e240b2c95801a14e2c6fc5baa5b3f7.diff
LOG: [mlir][scf] Relax requirements for loops fusion (#79187)
Enable the fusion of parallel loops even when the first loop contains
multiple write accesses to the same buffer, provided that all of those
accesses use the same indices.
Fix LIT test cases whose loops were not being fused.
Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c..8f2ab5f5e6dc1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -83,13 +83,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
if (write == bufferStores.end())
return WalkResult::advance();
- // Allow only single write access per buffer.
- if (write->second.size() != 1)
+ // Check that at least one store was retrieved.
+ if (!write->second.size())
return WalkResult::interrupt();
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
// Check that the load indices of secondPloop coincide with store indices of
// firstPloop for the same memrefs.
- auto storeIndices = write->second.front();
auto loadIndices = load.getIndices();
if (storeIndices.size() != loadIndices.size())
return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e52471..110168ba6eca5 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -13,9 +13,9 @@ func.func @fuse_empty_loops() {
return
}
// CHECK-LABEL: func @fuse_empty_loops
-// CHECK: [[C2:%.*]] = arith.constant 2 : index
-// CHECK: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: scf.reduce
@@ -24,16 +24,15 @@ func.func @fuse_empty_loops() {
// -----
-func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
- %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c1fp = arith.constant 1.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
- %sum_elem = arith.addf %B_elem, %C_elem : f32
+ %sum_elem = arith.addf %B_elem, %c1fp : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
@@ -41,89 +40,90 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
memref.dealloc %sum : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_two
-// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
-// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
-// CHECK: [[C2:%.*]] = arith.constant 2 : index
-// CHECK: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
-// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
-// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
-// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
// -----
-func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
- %result: memref<100x10xf32>) {
- %c100 = arith.constant 100 : index
- %c10 = arith.constant 10 : index
+func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+ %c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %broadcast_rhs = memref.alloc() : memref<100x10xf32>
- %diff = memref.alloc() : memref<100x10xf32>
- scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
- %rhs_elem = memref.load %rhs[%i] : memref<100xf32>
- memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
+ %c1fp = arith.constant 1.0 : f32
+ %c2fp = arith.constant 2.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ %prod = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %c1fp : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
- scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
- %lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
- %broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
- %diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
- memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %c2fp : f32
+ memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
scf.reduce
}
- scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
- %diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
- %exp_elem = math.exp %diff_elem : f32
- memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
- scf.reduce
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %res_elem = arith.addf %A_elem, %c2fp : f32
+ memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
}
- memref.dealloc %broadcast_rhs : memref<100x10xf32>
- memref.dealloc %diff : memref<100x10xf32>
+ memref.dealloc %sum : memref<2x2xf32>
+ memref.dealloc %prod : memref<2x2xf32>
return
}
// CHECK-LABEL: func @fuse_three
-// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
-// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
-// CHECK: [[C100:%.*]] = arith.constant 100 : index
-// CHECK: [[C10:%.*]] = arith.constant 10 : index
-// CHECK: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
-// CHECK: [[BROADCAST_RHS:%.*]] = memref.alloc()
-// CHECK: [[DIFF:%.*]] = memref.alloc()
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
+// CHECK-DAG: [[C2FP:%.*]] = arith.constant 2.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: [[PROD:%.*]] = memref.alloc()
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
-// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
-// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
-// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
-// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
-// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
-// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
-// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
-// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
-// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
-// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
+// CHECK: memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
+// CHECK: memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
-// CHECK: memref.dealloc [[BROADCAST_RHS]]
-// CHECK: memref.dealloc [[DIFF]]
+// CHECK: memref.dealloc [[SUM]]
+// CHECK: memref.dealloc [[PROD]]
// -----
@@ -310,17 +310,16 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
// -----
-func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
- %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c1fp = arith.constant 1.0 : f32
%sum = memref.alloc() : memref<2x2xf32>
scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
- %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
- %sum_elem = arith.addf %B_elem, %C_elem : f32
+ %sum_elem = arith.addf %B_elem, %c1fp : f32
memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
scf.reduce
}
@@ -328,7 +327,7 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
%sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
%product_elem = arith.mulf %sum_elem, %A_elem : f32
- memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
scf.reduce
}
}
@@ -336,23 +335,23 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
return
}
// CHECK-LABEL: func @nested_fuse
-// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
-// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
-// CHECK: [[C2:%.*]] = arith.constant 2 : index
-// CHECK: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1.
// CHECK: [[SUM:%.*]] = memref.alloc()
// CHECK: scf.parallel
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
-// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
-// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
-// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
// CHECK: scf.reduce
// CHECK: }
// CHECK: }
@@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}
-
// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0fp = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0fp = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK: [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
More information about the Mlir-commits
mailing list