[Mlir-commits] [mlir] 4184018 - [mlir][SCF] Canonicalize nested ParallelOp's
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat May 22 04:00:50 PDT 2021
Author: Butygin
Date: 2021-05-22T14:00:00+03:00
New Revision: 4184018253e720b0f2449b2b83ce27fc682f8579
URL: https://github.com/llvm/llvm-project/commit/4184018253e720b0f2449b2b83ce27fc682f8579
DIFF: https://github.com/llvm/llvm-project/commit/4184018253e720b0f2449b2b83ce27fc682f8579.diff
LOG: [mlir][SCF] Canonicalize nested ParallelOp's
Differential Revision: https://reviews.llvm.org/D102799
Added:
Modified:
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 91f1e7a3e7c0d..c7b2836e04386 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1705,11 +1705,70 @@ struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
}
};
+struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
+ using OpRewritePattern<ParallelOp>::OpRewritePattern;
+ /// Merges `op` with the ParallelOp that is the only non-terminator op of its body into one ParallelOp.
+ LogicalResult matchAndRewrite(ParallelOp op,
+ PatternRewriter &rewriter) const override {
+ Block &outerBody = op.getLoopBody().front();
+ if (!llvm::hasSingleElement(outerBody.without_terminator()))
+ return failure();
+ // The single remaining op must itself be a ParallelOp, otherwise there is nothing to merge.
+ auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
+ if (!innerOp)
+ return failure();
+ // Helper: true if `val` occurs in `range`.
+ auto hasVal = [](const auto &range, Value val) {
+ return llvm::find(range, val) != range.end();
+ };
+ // Bail out if any inner bound or step is a block argument of the outer body.
+ for (auto val : outerBody.getArguments())
+ if (hasVal(innerOp.lowerBound(), val) ||
+ hasVal(innerOp.upperBound(), val) || hasVal(innerOp.step(), val))
+ return failure();
+
+ // Reductions are not supported yet: bail out if either loop has init values.
+ if (!op.initVals().empty() || !innerOp.initVals().empty())
+ return failure();
+ // Clones the inner body (minus terminator); outer args map to the leading iterVals, inner args to the trailing.
+ auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
+ ValueRange iterVals, ValueRange) {
+ Block &innerBody = innerOp.getLoopBody().front();
+ assert(iterVals.size() ==
+ (outerBody.getNumArguments() + innerBody.getNumArguments()));
+ BlockAndValueMapping mapping;
+ mapping.map(outerBody.getArguments(),
+ iterVals.take_front(outerBody.getNumArguments()));
+ mapping.map(innerBody.getArguments(),
+ iterVals.take_back(innerBody.getNumArguments()));
+ for (Operation &op : innerBody.without_terminator())
+ builder.clone(op, mapping);
+ };
+ // Helper: returns the values of `first` followed by those of `second` in one SmallVector<Value>.
+ auto concatValues = [](const auto &first, const auto &second) {
+ SmallVector<Value> ret;
+ ret.reserve(first.size() + second.size());
+ ret.assign(first.begin(), first.end());
+ ret.append(second.begin(), second.end());
+ return ret;
+ };
+ // Outer operands go first, matching the take_front/take_back split used in bodyBuilder.
+ auto newLowerBounds = concatValues(op.lowerBound(), innerOp.lowerBound());
+ auto newUpperBounds = concatValues(op.upperBound(), innerOp.upperBound());
+ auto newSteps = concatValues(op.step(), innerOp.step());
+ // Replace `op` with a ParallelOp over the concatenated bounds and steps, populated by bodyBuilder.
+ rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
+ newSteps, llvm::None, bodyBuilder);
+ return success();
+ }
+};
+
} // namespace
void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(context);
+ results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
+ MergeNestedParallelLoops>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8f12c90b7729d..6b8867a7a9ce2 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -99,6 +99,41 @@ func @single_iteration_reduce(%A: index, %B: index) -> (index, index) {
// -----
+func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %1 = memref.dim %0, %c0 : memref<?x?x?xf64>
+ %2 = memref.dim %0, %c1 : memref<?x?x?xf64>
+ %3 = memref.dim %0, %c2 : memref<?x?x?xf64>
+ %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf64>
+ scf.parallel (%arg1) = (%c0) to (%1) step (%c1) {
+ scf.parallel (%arg2) = (%c0) to (%2) step (%c1) {
+ scf.parallel (%arg3) = (%c0) to (%3) step (%c1) {
+ %5 = memref.load %0[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
+ memref.store %5, %4[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
+ scf.yield
+ }
+ scf.yield
+ }
+ scf.yield
+ }
+ return %4 : memref<?x?x?xf64>
+}
+// Expected result: the three nested one-dimensional loops above become a single three-dimensional scf.parallel.
+// CHECK-LABEL: func @nested_parallel(
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[B0:%.*]] = memref.dim {{.*}}, [[C0]]
+// CHECK: [[B1:%.*]] = memref.dim {{.*}}, [[C1]]
+// CHECK: [[B2:%.*]] = memref.dim {{.*}}, [[C2]]
+// CHECK: scf.parallel ([[V0:%.*]], [[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[B0]], [[B1]], [[B2]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK: memref.load {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
+// CHECK: memref.store {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
+// -----
+
func private @side_effect()
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
@@ -632,7 +667,7 @@ func @cond_prop(%arg0 : i1) -> index {
} else {
%v2 = "test.get_some_value"() : () -> i32
scf.yield %c2 : index
- }
+ }
scf.yield %res1 : index
} else {
%res2 = scf.if %arg0 -> index {
@@ -641,7 +676,7 @@ func @cond_prop(%arg0 : i1) -> index {
} else {
%v4 = "test.get_some_value"() : () -> i32
scf.yield %c4 : index
- }
+ }
scf.yield %res2 : index
}
return %res : index
More information about the Mlir-commits
mailing list