[Mlir-commits] [mlir] f202d32 - [mlir][SCF] Add canonicalization pattern for scf::For to eliminate yields that just forward.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 4 03:37:21 PST 2020


Author: Nicolas Vasilache
Date: 2020-11-04T11:36:27Z
New Revision: f202d32216c64b1ae8853a0506b85674cf52126a

URL: https://github.com/llvm/llvm-project/commit/f202d32216c64b1ae8853a0506b85674cf52126a
DIFF: https://github.com/llvm/llvm-project/commit/f202d32216c64b1ae8853a0506b85674cf52126a.diff

LOG: [mlir][SCF] Add canonicalization pattern for scf::For to eliminate yields that just forward.

For instance:
```
func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
  %a = call @make_i32() : () -> (i32)
  %b = call @make_i32() : () -> (i32)
  %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
    %c = call @make_i32() : () -> (i32)
    scf.yield %0, %c, %2 : i32, i32, i32
  }
  return %r#0, %r#1, %r#2 : i32, i32, i32
}
```

This canonicalizes to:
```
  func @for_yields_3(%arg0: index, %arg1: index, %arg2: index) -> (i32, i32, i32) {
    %0 = call @make_i32() : () -> i32
    %1 = call @make_i32() : () -> i32
    %2 = scf.for %arg3 = %arg0 to %arg1 step %arg2 iter_args(%arg4 = %0) -> (i32) {
      %3 = call @make_i32() : () -> i32
      scf.yield %3 : i32
    }
    return %0, %2, %1 : i32, i32, i32
  }
```
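
The pattern is registered through `ForOp::getCanonicalizationPatterns`, so it is
exercised by the regular canonicalization pass. For example, assuming the standard
`mlir-opt` driver and the input above saved in a file named `input.mlir` (a
hypothetical name):
```
# Runs the canonicalizer; forwarded iter_args of scf.for are folded away.
mlir-opt -canonicalize input.mlir
```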

Differential Revision: https://reviews.llvm.org/D90745

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index bf81d7ff2177..1dc6ef1f68a4 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -197,6 +197,8 @@ def ForOp : SCF_Op<"for",
     /// value for `index`.
     OperandRange getSuccessorEntryOperands(unsigned index);
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def IfOp : SCF_Op<"if",

diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 56932ff1f30e..ea607a3a402a 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -370,6 +370,120 @@ ValueVector mlir::scf::buildLoopNest(
                        });
 }
 
+namespace {
+// Fold away ForOp iter arguments whose region block argument is yielded back
+// unchanged. The corresponding result is then just the init value, which is
+// defined outside the ForOp region, so its uses can simply be forwarded.
+//
+// The implementation uses `mergeBlockBefore` to steal the content of the
+// original ForOp and avoid cloning.
+struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const final {
+    bool canonicalize = false;
+    Block &block = forOp.region().front();
+    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+
+    // An internal flat vector of block transfer arguments,
+    // `newBlockTransferArgs`, keeps the 1-1 mapping from original to
+    // transformed block arguments. It plays the role of a
+    // BlockAndValueMapping for the particular use case of calling into
+    // `mergeBlockBefore`.
+    SmallVector<bool, 4> keepMask;
+    keepMask.reserve(yieldOp.getNumOperands());
+    SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
+        newResultValues;
+    newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
+    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
+    newIterArgs.reserve(forOp.getNumIterOperands());
+    newYieldValues.reserve(yieldOp.getNumOperands());
+    newResultValues.reserve(forOp.getNumResults());
+    for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
+                             forOp.getRegionIterArgs(), // iter inside region
+                             yieldOp.getOperands()      // iter yield
+                             )) {
+      // `forwarded` is true when the region iter arg is yielded unchanged.
+      bool forwarded = (std::get<1>(it) == std::get<2>(it));
+      keepMask.push_back(!forwarded);
+      canonicalize |= forwarded;
+      if (forwarded) {
+        newBlockTransferArgs.push_back(std::get<0>(it));
+        newResultValues.push_back(std::get<0>(it));
+        continue;
+      }
+      newIterArgs.push_back(std::get<0>(it));
+      newYieldValues.push_back(std::get<2>(it));
+      newBlockTransferArgs.push_back(Value()); // placeholder with null value
+      newResultValues.push_back(Value());      // placeholder with null value
+    }
+
+    if (!canonicalize)
+      return failure();
+
+    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
+        newIterArgs);
+    Block &newBlock = newForOp.region().front();
+
+    // Replace the null placeholders with newly constructed values.
+    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
+    for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
+         idx != e; ++idx) {
+      Value &blockTransferArg = newBlockTransferArgs[1 + idx];
+      Value &newResultVal = newResultValues[idx];
+      assert((blockTransferArg && newResultVal) ||
+             (!blockTransferArg && !newResultVal));
+      if (!blockTransferArg) {
+        blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
+        newResultVal = newForOp.getResult(collapsedIdx++);
+      }
+    }
+
+    Block &oldBlock = forOp.region().front();
+    assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
+           "unexpected argument size mismatch");
+
+    // No results case: the scf::ForOp builder already created a zero
+    // result terminator. Merge before this terminator and just get rid of the
+    // original terminator that has been merged in.
+    if (newIterArgs.empty()) {
+      auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
+      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
+      rewriter.replaceOp(forOp, newResultValues);
+      return success();
+    }
+
+    // No terminator case (with iter args the builder creates no terminator):
+    // merge the old block in and rewrite the merged terminator.
+    auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(mergedTerminator);
+      SmallVector<Value, 4> filteredOperands;
+      filteredOperands.reserve(newResultValues.size());
+      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
+        if (keepMask[idx])
+          filteredOperands.push_back(mergedTerminator.getOperand(idx));
+      rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
+                                    filteredOperands);
+    };
+
+    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+    cloneFilteredTerminator(mergedYieldOp);
+    rewriter.eraseOp(mergedYieldOp);
+    rewriter.replaceOp(forOp, newResultValues);
+    return success();
+  }
+};
+} // namespace
+
+void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<ForOpIterArgsFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index a96786076109..bd91f3bde403 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -137,3 +137,38 @@ func @all_unused() {
 // CHECK:             call @side_effect() : () -> ()
 // CHECK:           }
 // CHECK:           return
+
+// -----
+
+func @make_i32() -> i32
+
+func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {
+  %a = call @make_i32() : () -> (i32)
+  %b = scf.for %i = %lb to %ub step %step iter_args(%0 = %a) -> i32 {
+    scf.yield %0 : i32
+  }
+  return %b : i32
+}
+
+// CHECK-LABEL:   func @for_yields_2
+//  CHECK-NEXT:     %[[R:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:     return %[[R]] : i32
+
+func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
+  %a = call @make_i32() : () -> (i32)
+  %b = call @make_i32() : () -> (i32)
+  %r:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %a, %2 = %b) -> (i32, i32, i32) {
+    %c = call @make_i32() : () -> (i32)
+    scf.yield %0, %c, %2 : i32, i32, i32
+  }
+  return %r#0, %r#1, %r#2 : i32, i32, i32
+}
+
+// CHECK-LABEL:   func @for_yields_3
+//  CHECK-NEXT:     %[[a:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:     %[[b:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:     %[[r1:.*]] = scf.for {{.*}} iter_args(%arg4 = %[[a]]) -> (i32) {
+//  CHECK-NEXT:       %[[c:.*]] = call @make_i32() : () -> i32
+//  CHECK-NEXT:       scf.yield %[[c]] : i32
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:     return %[[a]], %[[r1]], %[[b]] : i32, i32, i32
