[Mlir-commits] [mlir] f202d32 - [mlir][SCF] Add canonicalization pattern for scf::For to eliminate yields that just forward.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 4 03:37:21 PST 2020
Author: Nicolas Vasilache
Date: 2020-11-04T11:36:27Z
New Revision: f202d32216c64b1ae8853a0506b85674cf52126a
URL: https://github.com/llvm/llvm-project/commit/f202d32216c64b1ae8853a0506b85674cf52126a
DIFF: https://github.com/llvm/llvm-project/commit/f202d32216c64b1ae8853a0506b85674cf52126a.diff
LOG: [mlir][SCF] Add canonicalization pattern for scf::For to eliminate yields that just forward.
For instance:
```
func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
%a = call @make_i32() : () -> (i32)
%b = call @make_i32() : () -> (i32)
%r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
%c = call @make_i32() : () -> (i32)
scf.yield %0, %c, %2 : i32, i32, i32
}
return %r#0, %r#1, %r#2 : i32, i32, i32
}
```
Canonicalizes as:
```
func @for_yields_3(%arg0: index, %arg1: index, %arg2: index) -> (i32, i32, i32) {
%0 = call @make_i32() : () -> i32
%1 = call @make_i32() : () -> i32
%2 = scf.for %arg3 = %arg0 to %arg1 step %arg2 iter_args(%arg4 = %0) -> (i32) {
%3 = call @make_i32() : () -> i32
scf.yield %3 : i32
}
return %0, %2, %1 : i32, i32, i32
}
```
Differential Revision: https://reviews.llvm.org/D90745
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index bf81d7ff2177..1dc6ef1f68a4 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -197,6 +197,8 @@ def ForOp : SCF_Op<"for",
/// value for `index`.
OperandRange getSuccessorEntryOperands(unsigned index);
}];
+
+ let hasCanonicalizer = 1;
}
def IfOp : SCF_Op<"if",
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 56932ff1f30e..ea607a3a402a 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -370,6 +370,120 @@ ValueVector mlir::scf::buildLoopNest(
});
}
+namespace {
+// Fold away scf.for iter_args whose region argument is yielded unchanged by
+// the terminator. Such a value never changes across iterations, so every use
+// of the region argument and of the matching loop result is replaced by the
+// corresponding init operand, which is defined outside of the loop region.
+// The surviving iter_args keep their relative order in the new loop.
+//
+// The body of the original loop is moved (`mergeBlocks`/`mergeBlockBefore`)
+// into the new loop rather than cloned.
+struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const final {
+    Block &oldBlock = forOp.region().front();
+    auto yieldOp = cast<scf::YieldOp>(oldBlock.getTerminator());
+    unsigned numIterOperands = forOp.getNumIterOperands();
+
+    // `keepMask[i]` is true when iter_arg `i` survives in the new loop.
+    // `newBlockTransferArgs` is a flat 1-1 replacement list for the old block
+    // arguments (iv first), consumed by `mergeBlocks`/`mergeBlockBefore`; it
+    // plays the role of a BlockAndValueMapping. `newResultValues` holds the
+    // replacement for each old result. Null entries are placeholders filled
+    // in once the new loop exists.
+    SmallVector<bool, 4> keepMask;
+    keepMask.reserve(numIterOperands);
+    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newResultValues;
+    newBlockTransferArgs.reserve(1 + numIterOperands);
+    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
+    newIterArgs.reserve(numIterOperands);
+    newResultValues.reserve(forOp.getNumResults());
+    for (auto it : llvm::zip(forOp.getIterOperands(),   // init from outside
+                             forOp.getRegionIterArgs(), // iter inside region
+                             yieldOp.getOperands())) {  // iter yield
+      Value init = std::get<0>(it);
+      // Forwarded: the region iter_arg is yielded as-is, so both its uses in
+      // the body and the matching loop result can be replaced by `init`.
+      bool forwarded = std::get<1>(it) == std::get<2>(it);
+      keepMask.push_back(!forwarded);
+      if (forwarded) {
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
+        continue;
+      }
+      newIterArgs.push_back(init);
+      newBlockTransferArgs.push_back(Value()); // placeholder with null value
+      newResultValues.push_back(Value());      // placeholder with null value
+    }
+
+    // Nothing is forwarded: the pattern does not apply.
+    if (llvm::all_of(keepMask, [](bool keep) { return keep; }))
+      return failure();
+
+    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+        newIterArgs);
+    Block &newBlock = newForOp.region().front();
+
+    // Fill the null placeholders with the new loop's iv, iter_args, results.
+    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
+    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
+         idx != e; ++idx) {
+      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
+      Value &newResultVal = newResultValues[idx];
+      assert(static_cast<bool>(blockTransferArg) ==
+                 static_cast<bool>(newResultVal) &&
+             "placeholders must be paired");
+      if (!blockTransferArg) {
+        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
+        newResultVal = newForOp.getResult(collapsedIdx++);
+      }
+    }
+
+    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
+           "unexpected argument size mismatch");
+
+    // All iter_args forwarded: the scf::ForOp builder already created a
+    // zero-operand terminator. Merge before it, then erase the original
+    // terminator, which has been merged in right before it.
+    if (newIterArgs.empty()) {
+      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
+      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
+      rewriter.replaceOp(forOp, newResultValues);
+      return success();
+    }
+
+    // Some iter_args survive: the builder created no terminator, so the
+    // merged original yield becomes the terminator. Replace it with a yield
+    // of only the operands of the surviving iter_args.
+    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+    SmallVector<Value, 4> filteredOperands;
+    filteredOperands.reserve(newIterArgs.size());
+    for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
+      if (keepMask[idx])
+        filteredOperands.push_back(mergedYieldOp.getOperand(idx));
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(mergedYieldOp);
+      rewriter.create<scf::YieldOp>(mergedYieldOp.getLoc(), filteredOperands);
+    }
+    rewriter.eraseOp(mergedYieldOp);
+    rewriter.replaceOp(forOp, newResultValues);
+    return success();
+  }
+};
+} // namespace
+
+void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                        MLIRContext *ctx) {
+  patterns.insert<ForOpIterArgsFolder>(ctx); // Drop forwarded iter_args.
+}
+
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index a96786076109..bd91f3bde403 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -137,3 +137,38 @@ func @all_unused() {
// CHECK: call @side_effect() : () -> ()
// CHECK: }
// CHECK: return
+
+// -----
+
+func @make_i32() -> i32
+
+func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
+ %a = call @make_i32() : () -> (i32)
+ %b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 {
+ scf.yield %0 : i32
+ }
+ return %b : i32
+}
+// The only iter_arg is yielded unchanged, so the loop folds away entirely.
+// CHECK-LABEL: func @for_yields_2
+// CHECK-NEXT: %[[R:.*]] = call @make_i32() : () -> i32
+// CHECK-NEXT: return %[[R]] : i32
+
+func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
+ %a = call @make_i32() : () -> (i32)
+ %b = call @make_i32() : () -> (i32)
+ %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
+ %c = call @make_i32() : () -> (i32)
+ scf.yield %0, %c, %2 : i32, i32, i32
+ }
+ return %r#0, %r#1, %r#2 : i32, i32, i32
+}
+// %0 and %2 are yielded unchanged and fold to %a and %b; only %1 stays carried.
+// CHECK-LABEL: func @for_yields_3
+// CHECK-NEXT: %[[a:.*]] = call @make_i32() : () -> i32
+// CHECK-NEXT: %[[b:.*]] = call @make_i32() : () -> i32
+// CHECK-NEXT: %[[r1:.*]] = scf.for {{.*}} iter_args(%arg4 = %[[a]]) -> (i32) {
+// CHECK-NEXT: %[[c:.*]] = call @make_i32() : () -> i32
+// CHECK-NEXT: scf.yield %[[c]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
More information about the Mlir-commits
mailing list