[Mlir-commits] [mlir] d17b005 - [mlir][scf] Relax requirements for loops fusion (#79187)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 09:17:09 PST 2024


Author: fabrizio-indirli
Date: 2024-01-30T20:17:06+03:00
New Revision: d17b005e46e240b2c95801a14e2c6fc5baa5b3f7

URL: https://github.com/llvm/llvm-project/commit/d17b005e46e240b2c95801a14e2c6fc5baa5b3f7
DIFF: https://github.com/llvm/llvm-project/commit/d17b005e46e240b2c95801a14e2c6fc5baa5b3f7.diff

LOG: [mlir][scf] Relax requirements for loops fusion (#79187)

Enable the fusion of parallel loops also when the 1st loop contains
multiple write accesses to the same buffer, if the accesses are always
on the same indices.
Fix LIT test cases whose loops were not being fused.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
    mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c..8f2ab5f5e6dc1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -83,13 +83,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
     if (write == bufferStores.end())
       return WalkResult::advance();
 
-    // Allow only single write access per buffer.
-    if (write->second.size() != 1)
+    // Check that at least one store was retrieved
+    if (!write->second.size())
       return WalkResult::interrupt();
 
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
     // Check that the load indices of secondPloop coincide with store indices of
     // firstPloop for the same memrefs.
-    auto storeIndices = write->second.front();
     auto loadIndices = load.getIndices();
     if (storeIndices.size() != loadIndices.size())
       return WalkResult::interrupt();

diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e52471..110168ba6eca5 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -13,9 +13,9 @@ func.func @fuse_empty_loops() {
   return
 }
 // CHECK-LABEL: func @fuse_empty_loops
-// CHECK:        [[C2:%.*]] = arith.constant 2 : index
-// CHECK:        [[C0:%.*]] = arith.constant 0 : index
-// CHECK:        [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:    [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:    [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:    [[C1:%.*]] = arith.constant 1 : index
 // CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
 // CHECK:          scf.reduce
@@ -24,16 +24,15 @@ func.func @fuse_empty_loops() {
 
 // -----
 
-func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
-                    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
   %sum = memref.alloc()  : memref<2x2xf32>
   scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
     %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
-    %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
-    %sum_elem = arith.addf %B_elem, %C_elem : f32
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
     memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
@@ -41,89 +40,90 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
     %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
     %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
     %product_elem = arith.mulf %sum_elem, %A_elem : f32
-    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
   memref.dealloc %sum : memref<2x2xf32>
   return
 }
 // CHECK-LABEL: func @fuse_two
-// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
-// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
-// CHECK:      [[C2:%.*]] = arith.constant 2 : index
-// CHECK:      [[C0:%.*]] = arith.constant 0 : index
-// CHECK:      [[C1:%.*]] = arith.constant 1 : index
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
 // CHECK:      [[SUM:%.*]] = memref.alloc()
 // CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
 // CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
-// CHECK:        [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
-// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
 // CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
-// CHECK:        memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
 // CHECK:        scf.reduce
 // CHECK:      }
 // CHECK:      memref.dealloc [[SUM]]
 
 // -----
 
-func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
-                      %result: memref<100x10xf32>) {
-  %c100 = arith.constant 100 : index
-  %c10 = arith.constant 10 : index
+func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %broadcast_rhs = memref.alloc() : memref<100x10xf32>
-  %diff = memref.alloc() : memref<100x10xf32>
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
-    %rhs_elem = memref.load %rhs[%i] : memref<100xf32>
-    memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
+  %c1fp = arith.constant 1.0 : f32
+  %c2fp = arith.constant 2.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  %prod = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %c1fp : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
-    %lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32>
-    %broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32>
-    %diff_elem = arith.subf %lhs_elem, %broadcast_rhs_elem : f32
-    memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %c2fp : f32
+    memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
-    %diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32>
-    %exp_elem = math.exp %diff_elem : f32
-    memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32>
-    scf.reduce
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { 
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %res_elem = arith.addf %A_elem, %c2fp : f32
+    memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
   }
-  memref.dealloc %broadcast_rhs : memref<100x10xf32>
-  memref.dealloc %diff : memref<100x10xf32>
+  memref.dealloc %sum : memref<2x2xf32>
+  memref.dealloc %prod : memref<2x2xf32>
   return
 }
 // CHECK-LABEL: func @fuse_three
-// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
-// CHECK-SAME:  [[RESULT:%.*]]: memref<100x10xf32>) {
-// CHECK:      [[C100:%.*]] = arith.constant 100 : index
-// CHECK:      [[C10:%.*]] = arith.constant 10 : index
-// CHECK:      [[C0:%.*]] = arith.constant 0 : index
-// CHECK:      [[C1:%.*]] = arith.constant 1 : index
-// CHECK:      [[BROADCAST_RHS:%.*]] = memref.alloc()
-// CHECK:      [[DIFF:%.*]] = memref.alloc()
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
+// CHECK-DAG:  [[C2FP:%.*]] = arith.constant 2.
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+// CHECK:      [[PROD:%.*]] = memref.alloc()
 // CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
-// CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
-// CHECK:        [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
-// CHECK:        memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
-// CHECK:        [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
-// CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
-// CHECK:        [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
-// CHECK:        memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
-// CHECK:        [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
-// CHECK:        [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
-// CHECK:        memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[C2FP]]
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[PROD]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[RES_ELEM:%.*]] = arith.addf [[A_ELEM]], [[C2FP]]
+// CHECK:        memref.store [[RES_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
 // CHECK:        scf.reduce
 // CHECK:      }
-// CHECK:      memref.dealloc [[BROADCAST_RHS]]
-// CHECK:      memref.dealloc [[DIFF]]
+// CHECK:      memref.dealloc [[SUM]]
+// CHECK:      memref.dealloc [[PROD]]
 
 // -----
 
@@ -310,17 +310,16 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
 
 // -----
 
-func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
-                    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
+  %c1fp = arith.constant 1.0 : f32
   %sum = memref.alloc()  : memref<2x2xf32>
   scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
     scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
       %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
-      %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
-      %sum_elem = arith.addf %B_elem, %C_elem : f32
+      %sum_elem = arith.addf %B_elem, %c1fp : f32
       memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
       scf.reduce
     }
@@ -328,7 +327,7 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
       %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
       %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
       %product_elem = arith.mulf %sum_elem, %A_elem : f32
-      memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+      memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
       scf.reduce
     }
   }
@@ -336,23 +335,23 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
   return
 }
 // CHECK-LABEL: func @nested_fuse
-// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
-// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
-// CHECK:      [[C2:%.*]] = arith.constant 2 : index
-// CHECK:      [[C0:%.*]] = arith.constant 0 : index
-// CHECK:      [[C1:%.*]] = arith.constant 1 : index
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C1FP:%.*]] = arith.constant 1.
 // CHECK:      [[SUM:%.*]] = memref.alloc()
 // CHECK:      scf.parallel
 // CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
 // CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
 // CHECK:          [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
-// CHECK:          [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
-// CHECK:          [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
+// CHECK:          [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]]
 // CHECK:          memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:   scf.parallel
 // CHECK:          [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
 // CHECK:          [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
 // CHECK:          [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
-// CHECK:          memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:          memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
 // CHECK:          scf.reduce
 // CHECK:        }
 // CHECK:      }
@@ -382,8 +381,102 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
   }
   return
 }
-
 // %sum and %result may alias with other args, do not fuse loops
 // CHECK-LABEL: func @do_not_fuse_alias
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0fp = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0F32:%.*]] = arith.constant 0.
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+// CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0fp = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0fp, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %B[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) {
+// CHECK-DAG:  [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG:  [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG:  [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG:  [[C0F32:%.*]] = arith.constant 0.
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[C0]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:     scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK:        [[SUM_ELEM:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+// CHECK:      memref.dealloc [[SUM]]


        


More information about the Mlir-commits mailing list