[Mlir-commits] [mlir] 11b67aa - [mlir][scf] NFC - refactor the implementation of outlineIfOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jan 5 02:02:45 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-05T05:02:26-05:00
New Revision: 11b67aaffb0125f7c996abdf2b2fc2c61334f462

URL: https://github.com/llvm/llvm-project/commit/11b67aaffb0125f7c996abdf2b2fc2c61334f462
DIFF: https://github.com/llvm/llvm-project/commit/11b67aaffb0125f7c996abdf2b2fc2c61334f462.diff

LOG: [mlir][scf] NFC - refactor the implementation of outlineIfOp

This revision refactors the implementation of outlineIfOp to expose
a finer-grain functionality `outlineSingleBlockRegion` that will be
reused in other contexts.

Differential Revision: https://reviews.llvm.org/D116591

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Utils.h
    mlir/lib/Dialect/SCF/Transforms/Utils.cpp
    mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
index 7ed3e8c03f163..a062783d0bf60 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -13,17 +13,20 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_H_
 
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 class FuncOp;
+class Location;
 class Operation;
 class OpBuilder;
+class Region;
+class RewriterBase;
 class ValueRange;
 class Value;
-class AffineExpr;
-class Operation;
 
 namespace scf {
 class IfOp;
@@ -55,16 +58,34 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                               ValueRange newYieldedValues,
                               bool replaceLoopResults = true);
 
+/// Outline a region with a single block into a new FuncOp.
+/// Assumes the FuncOp result types are the types of the yielded operands of the
+/// single block. This constraint makes it easy to determine the result.
+/// This method also clones the `arith::ConstantIndexOp` at the start of
+/// `outlinedFuncBody` to allow simple canonicalizations.
+/// Creates a new FuncOp and thus cannot be used in a FunctionPass.
+/// The client is responsible for providing a unique `funcName` that will not
+/// collide with another FuncOp name.
+// TODO: support more than single-block regions.
+// TODO: more flexible constant handling.
+FailureOr<FuncOp> outlineSingleBlockRegion(RewriterBase &rewriter, Location loc,
+                                           Region &region, StringRef funcName);
+
 /// Outline the then and/or else regions of `ifOp` as follows:
 ///  - if `thenFn` is not null, `thenFnName` must be specified and the `then`
 ///    region is inlined into a new FuncOp that is captured by the pointer.
 ///  - if `elseFn` is not null, `elseFnName` must be specified and the `else`
 ///    region is inlined into a new FuncOp that is captured by the pointer.
-void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
-                 StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
+/// Creates new FuncOps and thus cannot be used in a FunctionPass.
+/// The client is responsible for providing a unique `thenFnName`/`elseFnName`
+/// that will not collide with another FuncOp name.
+LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
+                          StringRef thenFnName, FuncOp *elseFn,
+                          StringRef elseFnName);
 
-/// Get a list of innermost parallel loops contained in `rootOp`. Innermost parallel
-/// loops are those that do not contain further parallel loops themselves.
+/// Get a list of innermost parallel loops contained in `rootOp`. Innermost
+/// parallel loops are those that do not contain further parallel loops
+/// themselves.
 bool getInnermostParallelLoops(Operation *rootOp,
                                SmallVectorImpl<scf::ParallelOp> &result);
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
index c8e557cbe5afa..abd7408f86cd1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -12,12 +12,15 @@
 
 #include "mlir/Dialect/SCF/Utils.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/RegionUtils.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
@@ -77,51 +80,124 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
   return newLoop;
 }
 
-void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
-                       StringRef thenFnName, FuncOp *elseFn,
-                       StringRef elseFnName) {
-  Location loc = ifOp.getLoc();
-  MLIRContext *ctx = ifOp.getContext();
-  auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
-    assert(!funcName.empty() && "Expected function name for outlining");
-    assert(ifOrElseRegion.getBlocks().size() <= 1 &&
-           "Expected at most one block");
-
-    // Outline before current function.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(ifOp->getParentOfType<FuncOp>());
-
-    SetVector<Value> captures;
-    getUsedValuesDefinedAbove(ifOrElseRegion, captures);
-
-    ValueRange values(captures.getArrayRef());
-    FunctionType type =
-        FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
-    auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
-    b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+/// Outline a region with a single block into a new FuncOp.
+/// Assumes the FuncOp result types are the types of the yielded operands of the
+/// single block. This constraint makes it easy to determine the result.
+/// This method also clones the `arith::ConstantIndexOp` at the start of
+/// `outlinedFuncBody` to allow simple canonicalizations.
+// TODO: support more than single-block regions.
+// TODO: more flexible constant handling.
+FailureOr<FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
+                                                 Location loc, Region &region,
+                                                 StringRef funcName) {
+  assert(!funcName.empty() && "funcName cannot be empty");
+  if (!region.hasOneBlock())
+    return failure();
+
+  Block *originalBlock = &region.front();
+  Operation *originalTerminator = originalBlock->getTerminator();
+
+  // Outline before current function.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(region.getParentOfType<FuncOp>());
+
+  SetVector<Value> captures;
+  getUsedValuesDefinedAbove(region, captures);
+
+  ValueRange outlinedValues(captures.getArrayRef());
+  SmallVector<Type> outlinedFuncArgTypes;
+  // Region's arguments are exactly the first block's arguments as per
+  // Region::getArguments().
+  // Func's arguments are cat(region's arguments, captures arguments).
+  llvm::append_range(outlinedFuncArgTypes, region.getArgumentTypes());
+  llvm::append_range(outlinedFuncArgTypes, outlinedValues.getTypes());
+  FunctionType outlinedFuncType =
+      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
+                        originalTerminator->getOperandTypes());
+  auto outlinedFunc = rewriter.create<FuncOp>(loc, funcName, outlinedFuncType);
+  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
+
+  // Merge blocks while replacing the original block operands.
+  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
+  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
+  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(outlinedFuncBody);
+    rewriter.mergeBlocks(
+        originalBlock, outlinedFuncBody,
+        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
+    // Explicitly set up a new ReturnOp terminator.
+    rewriter.setInsertionPointToEnd(outlinedFuncBody);
+    rewriter.create<ReturnOp>(loc, originalTerminator->getResultTypes(),
+                              originalTerminator->getOperands());
+  }
+
+  // Reconstruct the block that was deleted and add a
+  // terminator(call_results).
+  Block *newBlock = rewriter.createBlock(
+      &region, region.begin(),
+      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments));
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(newBlock);
+    SmallVector<Value> callValues;
+    llvm::append_range(callValues, newBlock->getArguments());
+    llvm::append_range(callValues, outlinedValues);
+    Operation *call = rewriter.create<CallOp>(loc, outlinedFunc, callValues);
+
+    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
+    // Clone `originalTerminator` to take the callOp results then erase it from
+    // `outlinedFuncBody`.
     BlockAndValueMapping bvm;
-    for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
-      bvm.map(std::get<0>(it), std::get<1>(it));
-    for (Operation &op : ifOrElseRegion.front().without_terminator())
-      b.clone(op, bvm);
-
-    Operation *term = ifOrElseRegion.front().getTerminator();
-    SmallVector<Value, 4> terminatorOperands;
-    for (auto op : term->getOperands())
-      terminatorOperands.push_back(bvm.lookup(op));
-    b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
-
-    ifOrElseRegion.front().clear();
-    b.setInsertionPointToEnd(&ifOrElseRegion.front());
-    Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
-    b.create<scf::YieldOp>(loc, call->getResults());
-    return outlinedFunc;
-  };
-
-  if (thenFn && !ifOp.getThenRegion().empty())
-    *thenFn = outline(ifOp.getThenRegion(), thenFnName);
-  if (elseFn && !ifOp.getElseRegion().empty())
-    *elseFn = outline(ifOp.getElseRegion(), elseFnName);
+    bvm.map(originalTerminator->getOperands(), call->getResults());
+    rewriter.clone(*originalTerminator, bvm);
+    rewriter.eraseOp(originalTerminator);
+  }
+
+  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
+  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
+  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
+                                               outlinedValues.size()))) {
+    Value orig = std::get<0>(it);
+    Value repl = std::get<1>(it);
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointToStart(outlinedFuncBody);
+      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
+        BlockAndValueMapping bvm;
+        repl = rewriter.clone(*cst, bvm)->getResult(0);
+      }
+    }
+    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
+      return outlinedFunc->isProperAncestor(opOperand.getOwner());
+    });
+  }
+
+  return outlinedFunc;
+}
+
+LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
+                                StringRef thenFnName, FuncOp *elseFn,
+                                StringRef elseFnName) {
+  IRRewriter rewriter(b);
+  Location loc = ifOp.getLoc();
+  FailureOr<FuncOp> outlinedFuncOpOrFailure;
+  if (thenFn && !ifOp.getThenRegion().empty()) {
+    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
+        rewriter, loc, ifOp.getThenRegion(), thenFnName);
+    if (failed(outlinedFuncOpOrFailure))
+      return failure();
+    *thenFn = *outlinedFuncOpOrFailure;
+  }
+  if (elseFn && !ifOp.getElseRegion().empty()) {
+    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
+        rewriter, loc, ifOp.getElseRegion(), elseFnName);
+    if (failed(outlinedFuncOpOrFailure))
+      return failure();
+    *elseFn = *outlinedFuncOpOrFailure;
+  }
+  return success();
 }
 
 bool mlir::getInnermostParallelLoops(Operation *rootOp,

diff  --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 47136d228fb35..56f54f5e7dd92 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -59,21 +59,26 @@ class TestSCFForUtilsPass
 };
 
 class TestSCFIfUtilsPass
-    : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
+    : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
 public:
   StringRef getArgument() const final { return "test-scf-if-utils"; }
   StringRef getDescription() const final { return "test scf.if utils"; }
   explicit TestSCFIfUtilsPass() = default;
 
-  void runOnFunction() override {
+  void runOnOperation() override {
     int count = 0;
-    FuncOp func = getFunction();
-    func.walk([&](scf::IfOp ifOp) {
+    getOperation().walk([&](scf::IfOp ifOp) {
       auto strCount = std::to_string(count++);
       FuncOp thenFn, elseFn;
       OpBuilder b(ifOp);
-      outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
-                  &elseFn, std::string("outlined_else") + strCount);
+      IRRewriter rewriter(b);
+      if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
+                             std::string("outlined_then") + strCount, &elseFn,
+                             std::string("outlined_else") + strCount))) {
+        this->signalPassFailure();
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
     });
   }
 };


        


More information about the Mlir-commits mailing list