[Mlir-commits] [mlir] a6b2877 - [MLIR] Make ParallelLoopFusion pass scan through all nested regions.

Alexander Belyaev llvmlistbot at llvm.org
Thu May 7 04:47:49 PDT 2020


Author: Alexander Belyaev
Date: 2020-05-07T13:47:30+02:00
New Revision: a6b2877f4c6b7e2e4776f00151c5451dca68027a

URL: https://github.com/llvm/llvm-project/commit/a6b2877f4c6b7e2e4776f00151c5451dca68027a
DIFF: https://github.com/llvm/llvm-project/commit/a6b2877f4c6b7e2e4776f00151c5451dca68027a.diff

LOG: [MLIR] Make ParallelLoopFusion pass scan through all nested regions.

Differential Revision: https://reviews.llvm.org/D79558

Added: 
    

Modified: 
    mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
    mlir/test/Dialect/Loops/parallel-loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
index afd32b2069a8..f16a11851771 100644
--- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -162,11 +162,12 @@ namespace {
 struct ParallelLoopFusion
     : public LoopParallelLoopFusionBase<ParallelLoopFusion> {
   void runOnOperation() override {
-    for (Region &region : getOperation()->getRegions())
-      naivelyFuseParallelOps(region);
+    getOperation()->walk([&](Operation *child) {
+      for (Region &region : child->getRegions())
+        naivelyFuseParallelOps(region);
+    });
   }
 };
-
 } // namespace
 
 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {

diff  --git a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
index 13993304c62a..f9c59b44354a 100644
--- a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
@@ -307,3 +307,53 @@ func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
 // CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
 // CHECK:        loop.parallel
 // CHECK:        loop.parallel
+
+// -----
+
+func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+                    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc()  : memref<2x2xf32>
+  loop.parallel (%k) = (%c0) to (%c2) step (%c1) {
+    loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      %B_elem = load %B[%i, %j] : memref<2x2xf32>
+      %C_elem = load %C[%i, %j] : memref<2x2xf32>
+      %sum_elem = addf %B_elem, %C_elem : f32
+      store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+      loop.yield
+    }
+    loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
+      %A_elem = load %A[%i, %j] : memref<2x2xf32>
+      %product_elem = mulf %sum_elem, %A_elem : f32
+      store %product_elem, %result[%i, %j] : memref<2x2xf32>
+      loop.yield
+    }
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @nested_fuse
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C2:%.*]] = constant 2 : index
+// CHECK:      [[C0:%.*]] = constant 0 : index
+// CHECK:      [[C1:%.*]] = constant 1 : index
+// CHECK:      [[SUM:%.*]] = alloc()
+// CHECK:      loop.parallel
+// CHECK:        loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:          [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
+// CHECK:          store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:          store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:          loop.yield
+// CHECK:        }
+// CHECK:      }
+// CHECK:      dealloc [[SUM]]


        


More information about the Mlir-commits mailing list