[Mlir-commits] [mlir] 4184018 - [mlir][SCF] Canonicalize nested ParallelOps

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat May 22 04:00:50 PDT 2021


Author: Butygin
Date: 2021-05-22T14:00:00+03:00
New Revision: 4184018253e720b0f2449b2b83ce27fc682f8579

URL: https://github.com/llvm/llvm-project/commit/4184018253e720b0f2449b2b83ce27fc682f8579
DIFF: https://github.com/llvm/llvm-project/commit/4184018253e720b0f2449b2b83ce27fc682f8579.diff

LOG: [mlir][SCF] Canonicalize nested ParallelOps

Differential Revision: https://reviews.llvm.org/D102799

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 91f1e7a3e7c0d..c7b2836e04386 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1705,11 +1705,70 @@ struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
   }
 };
 
+/// Merges an scf.parallel whose body holds only a nested scf.parallel into one
+/// loop over the outer induction variables followed by the inner ones.
+struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &outerBody = op.getLoopBody().front();
+    if (!llvm::hasSingleElement(outerBody.without_terminator()))
+      return failure();
+
+    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
+    if (!innerOp)
+      return failure();
+
+    // Reductions are not supported yet.
+    if (!op.initVals().empty() || !innerOp.initVals().empty())
+      return failure();
+
+    // Inner bounds must not depend on the outer induction variables.
+    for (Value iv : outerBody.getArguments())
+      if (llvm::is_contained(innerOp.lowerBound(), iv) ||
+          llvm::is_contained(innerOp.upperBound(), iv) ||
+          llvm::is_contained(innerOp.step(), iv))
+        return failure();
+
+    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
+                           ValueRange iterVals, ValueRange) {
+      Block &innerBody = innerOp.getLoopBody().front();
+      assert(iterVals.size() ==
+             (outerBody.getNumArguments() + innerBody.getNumArguments()));
+      BlockAndValueMapping mapping;
+      mapping.map(outerBody.getArguments(),
+                  iterVals.take_front(outerBody.getNumArguments()));
+      mapping.map(innerBody.getArguments(),
+                  iterVals.take_back(innerBody.getNumArguments()));
+      for (Operation &nestedOp : innerBody.without_terminator())
+        builder.clone(nestedOp, mapping);
+    };
+
+    auto concatValues = [](const auto &first, const auto &second) {
+      SmallVector<Value> ret;
+      ret.reserve(first.size() + second.size());
+      ret.assign(first.begin(), first.end());
+      ret.append(second.begin(), second.end());
+      return ret;
+    };
+
+    auto newLowerBounds = concatValues(op.lowerBound(), innerOp.lowerBound());
+    auto newUpperBounds = concatValues(op.upperBound(), innerOp.upperBound());
+    auto newSteps = concatValues(op.step(), innerOp.step());
+
+    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
+                                            newSteps, llvm::None, bodyBuilder);
+    return success();
+  }
+};
+
 } // namespace
 
 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(context);
+  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(context);
+  results.add<MergeNestedParallelLoops>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8f12c90b7729d..6b8867a7a9ce2 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -99,6 +99,41 @@ func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
 
 // -----
 
+func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %1 = memref.dim %0, %c0 : memref<?x?x?xf64>
+  %2 = memref.dim %0, %c1 : memref<?x?x?xf64>
+  %3 = memref.dim %0, %c2 : memref<?x?x?xf64>
+  %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf64>
+  scf.parallel (%arg1) = (%c0) to (%1) step (%c1) {
+    scf.parallel (%arg2) = (%c0) to (%2) step (%c1) {
+      scf.parallel (%arg3) = (%c0) to (%3) step (%c1) {
+        %5 = memref.load %0[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
+        memref.store %5, %4[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
+        scf.yield
+      }
+      scf.yield
+    }
+    scf.yield
+  }
+  return %4 : memref<?x?x?xf64>
+}
+
+// CHECK-LABEL:   func @nested_parallel(
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C2:%.*]] = constant 2 : index
+// CHECK:           [[B0:%.*]] = memref.dim {{.*}}, [[C0]]
+// CHECK:           [[B1:%.*]] = memref.dim {{.*}}, [[C1]]
+// CHECK:           [[B2:%.*]] = memref.dim {{.*}}, [[C2]]
+// CHECK:           scf.parallel ([[V0:%.*]], [[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[B0]], [[B1]], [[B2]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK-NEXT:      memref.load {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
+// CHECK-NEXT:      memref.store {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
+
+// -----
+
 func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
@@ -632,7 +667,7 @@ func @cond_prop(%arg0 : i1) -> index {
     } else {
       %v2 = "test.get_some_value"() : () -> i32
       scf.yield %c2 : index
-    } 
+    }
     scf.yield %res1 : index
   } else {
     %res2 = scf.if %arg0 -> index {
@@ -641,7 +676,7 @@ func @cond_prop(%arg0 : i1) -> index {
     } else {
       %v4 = "test.get_some_value"() : () -> i32
       scf.yield %c4 : index
-    } 
+    }
     scf.yield %res2 : index
   }
   return %res : index


        


More information about the Mlir-commits mailing list