[Mlir-commits] [mlir] 11b67aa - [mlir][scf] NFC - refactor the implementation of outlineIfOp
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jan 5 02:02:45 PST 2022
Author: Nicolas Vasilache
Date: 2022-01-05T05:02:26-05:00
New Revision: 11b67aaffb0125f7c996abdf2b2fc2c61334f462
URL: https://github.com/llvm/llvm-project/commit/11b67aaffb0125f7c996abdf2b2fc2c61334f462
DIFF: https://github.com/llvm/llvm-project/commit/11b67aaffb0125f7c996abdf2b2fc2c61334f462.diff
LOG: [mlir][scf] NFC - refactor the implementation of outlineIfOp
This revision refactors the implementation of outlineIfOp to expose
a finer-grain functionality `outlineSingleBlockRegion` that will be
reused in other contexts.
Differential Revision: https://reviews.llvm.org/D116591
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Utils.h
mlir/lib/Dialect/SCF/Transforms/Utils.cpp
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
index 7ed3e8c03f163..a062783d0bf60 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -13,17 +13,20 @@
#ifndef MLIR_DIALECT_SCF_UTILS_H_
#define MLIR_DIALECT_SCF_UTILS_H_
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
class FuncOp;
+class Location;
class Operation;
class OpBuilder;
+class Region;
+class RewriterBase;
class ValueRange;
class Value;
-class AffineExpr;
-class Operation;
namespace scf {
class IfOp;
@@ -55,16 +58,34 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
ValueRange newYieldedValues,
bool replaceLoopResults = true);
+/// Outline a region with a single block into a new FuncOp.
+/// Assumes the FuncOp result types are the types of the yielded operands of the
+/// single block. This constraint makes it easy to determine the result.
+/// This method also clones captured `arith::ConstantIndexOp`s at the start of
+/// the outlined function body to allow simple canonicalizations.
+/// Creates a new FuncOp and thus cannot be used in a FunctionPass.
+/// The client is responsible for providing a unique `funcName` that will not
+/// collide with another FuncOp name.
+// TODO: support more than single-block regions.
+// TODO: more flexible constant handling.
+FailureOr<FuncOp> outlineSingleBlockRegion(RewriterBase &rewriter, Location loc,
+ Region &region, StringRef funcName);
+
/// Outline the then and/or else regions of `ifOp` as follows:
/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
/// region is inlined into a new FuncOp that is captured by the pointer.
/// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
/// region is inlined into a new FuncOp that is captured by the pointer.
-void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
- StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
+/// Creates new FuncOps and thus cannot be used in a FunctionPass.
+/// The client is responsible for providing a unique `thenFnName`/`elseFnName`
+/// that will not collide with another FuncOp name.
+/// Returns failure if a requested non-empty region does not have exactly one
+/// block; the `then` region may already have been outlined in that case.
+LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
+ StringRef thenFnName, FuncOp *elseFn,
+ StringRef elseFnName);
-/// Get a list of innermost parallel loops contained in `rootOp`. Innermost parallel
-/// loops are those that do not contain further parallel loops themselves.
+/// Get a list of innermost parallel loops contained in `rootOp`. Innermost
+/// parallel loops are those that do not contain further parallel loops
+/// themselves.
bool getInnermostParallelLoops(Operation *rootOp,
SmallVectorImpl<scf::ParallelOp> &result);
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
index c8e557cbe5afa..abd7408f86cd1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -12,12 +12,15 @@
#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
@@ -77,51 +80,124 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
return newLoop;
}
-void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
- StringRef thenFnName, FuncOp *elseFn,
- StringRef elseFnName) {
- Location loc = ifOp.getLoc();
- MLIRContext *ctx = ifOp.getContext();
- auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
- assert(!funcName.empty() && "Expected function name for outlining");
- assert(ifOrElseRegion.getBlocks().size() <= 1 &&
- "Expected at most one block");
-
- // Outline before current function.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(ifOp->getParentOfType<FuncOp>());
-
- SetVector<Value> captures;
- getUsedValuesDefinedAbove(ifOrElseRegion, captures);
-
- ValueRange values(captures.getArrayRef());
- FunctionType type =
- FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
- auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
- b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+/// Outline a region with a single block into a new FuncOp.
+/// Assumes the FuncOp result types are the types of the yielded operands of the
+/// single block. This constraint makes it easy to determine the result.
+/// This method also clones the `arith::ConstantIndexOp` at the start of
+/// `outlinedFuncBody` to allow simple canonicalizations.
+// TODO: support more than single-block regions.
+// TODO: more flexible constant handling.
+FailureOr<FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
+ Location loc, Region &region,
+ StringRef funcName) {
+ assert(!funcName.empty() && "funcName cannot be empty");
+ if (!region.hasOneBlock())
+ return failure();
+
+ Block *originalBlock = &region.front();
+ Operation *originalTerminator = originalBlock->getTerminator();
+
+ // Outline before current function.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(region.getParentOfType<FuncOp>());
+
+ SetVector<Value> captures;
+ getUsedValuesDefinedAbove(region, captures);
+
+ ValueRange outlinedValues(captures.getArrayRef());
+ SmallVector<Type> outlinedFuncArgTypes;
+ // Region's arguments are exactly the first block's arguments as per
+ // Region::getArguments().
+ // Func's arguments are cat(region's arguments, captured values).
+ llvm::append_range(outlinedFuncArgTypes, region.getArgumentTypes());
+ llvm::append_range(outlinedFuncArgTypes, outlinedValues.getTypes());
+ FunctionType outlinedFuncType =
+ FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
+ originalTerminator->getOperandTypes());
+ auto outlinedFunc = rewriter.create<FuncOp>(loc, funcName, outlinedFuncType);
+ Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
+
+ // Merge blocks while replacing the original block operands.
+ // Warning: `mergeBlocks` erases the original block, reconstruct it later.
+ int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
+ auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToEnd(outlinedFuncBody);
+ rewriter.mergeBlocks(
+ originalBlock, outlinedFuncBody,
+ outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
+ // Explicitly set up a new ReturnOp terminator.
+ rewriter.setInsertionPointToEnd(outlinedFuncBody);
+ rewriter.create<ReturnOp>(loc, originalTerminator->getResultTypes(),
+ originalTerminator->getOperands());
+ }
+
+ // Reconstruct the block that was deleted and add a
+ // terminator(call_results).
+ Block *newBlock = rewriter.createBlock(
+ &region, region.begin(),
+ TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments));
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToEnd(newBlock);
+ SmallVector<Value> callValues;
+ llvm::append_range(callValues, newBlock->getArguments());
+ llvm::append_range(callValues, outlinedValues);
+ Operation *call = rewriter.create<CallOp>(loc, outlinedFunc, callValues);
+
+ // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
+ // Clone `originalTerminator` to take the callOp results then erase it from
+ // `outlinedFuncBody`.
BlockAndValueMapping bvm;
- for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
- bvm.map(std::get<0>(it), std::get<1>(it));
- for (Operation &op : ifOrElseRegion.front().without_terminator())
- b.clone(op, bvm);
-
- Operation *term = ifOrElseRegion.front().getTerminator();
- SmallVector<Value, 4> terminatorOperands;
- for (auto op : term->getOperands())
- terminatorOperands.push_back(bvm.lookup(op));
- b.create<ReturnOp>(loc, term->getResultTypes(), terminatorOperands);
-
- ifOrElseRegion.front().clear();
- b.setInsertionPointToEnd(&ifOrElseRegion.front());
- Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
- b.create<scf::YieldOp>(loc, call->getResults());
- return outlinedFunc;
- };
-
- if (thenFn && !ifOp.getThenRegion().empty())
- *thenFn = outline(ifOp.getThenRegion(), thenFnName);
- if (elseFn && !ifOp.getElseRegion().empty())
- *elseFn = outline(ifOp.getElseRegion(), elseFnName);
+ bvm.map(originalTerminator->getOperands(), call->getResults());
+ rewriter.clone(*originalTerminator, bvm);
+ rewriter.eraseOp(originalTerminator);
+ }
+
+ // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
+ // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
+ for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
+ outlinedValues.size()))) {
+ Value orig = std::get<0>(it);
+ Value repl = std::get<1>(it);
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToStart(outlinedFuncBody);
+ if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
+ BlockAndValueMapping bvm;
+ repl = rewriter.clone(*cst, bvm)->getResult(0);
+ }
+ }
+ orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
+ return outlinedFunc->isProperAncestor(opOperand.getOwner());
+ });
+ }
+
+ return outlinedFunc;
+}
+
+LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
+ StringRef thenFnName, FuncOp *elseFn,
+ StringRef elseFnName) {
+ // Use the caller's rewriter directly: copying it into a fresh IRRewriter
+ // would bypass its hooks (e.g. a pattern driver's erase notifications).
+ RewriterBase &rewriter = b;
+ Location loc = ifOp.getLoc();
+ FailureOr<FuncOp> outlinedFuncOpOrFailure;
+ if (thenFn && !ifOp.getThenRegion().empty()) {
+ outlinedFuncOpOrFailure = outlineSingleBlockRegion(
+ rewriter, loc, ifOp.getThenRegion(), thenFnName);
+ if (failed(outlinedFuncOpOrFailure))
+ return failure();
+ *thenFn = *outlinedFuncOpOrFailure;
+ }
+ if (elseFn && !ifOp.getElseRegion().empty()) {
+ outlinedFuncOpOrFailure = outlineSingleBlockRegion(
+ rewriter, loc, ifOp.getElseRegion(), elseFnName);
+ if (failed(outlinedFuncOpOrFailure))
+ return failure();
+ *elseFn = *outlinedFuncOpOrFailure;
+ }
+ return success();
}
bool mlir::getInnermostParallelLoops(Operation *rootOp,
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 47136d228fb35..56f54f5e7dd92 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -59,21 +59,26 @@ class TestSCFForUtilsPass
};
class TestSCFIfUtilsPass
- : public PassWrapper<TestSCFIfUtilsPass, FunctionPass> {
+ : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
public:
StringRef getArgument() const final { return "test-scf-if-utils"; }
StringRef getDescription() const final { return "test scf.if utils"; }
explicit TestSCFIfUtilsPass() = default;
- void runOnFunction() override {
+ // outlineIfOp creates new FuncOps, which a FunctionPass may not do, so this
+ // test pass runs on the enclosing ModuleOp instead.
+ void runOnOperation() override {
int count = 0;
- FuncOp func = getFunction();
- func.walk([&](scf::IfOp ifOp) {
+ getOperation().walk([&](scf::IfOp ifOp) {
auto strCount = std::to_string(count++);
FuncOp thenFn, elseFn;
OpBuilder b(ifOp);
- outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
- &elseFn, std::string("outlined_else") + strCount);
+ IRRewriter rewriter(b);
+ // Fail the pass and stop walking at the first scf.if that cannot be
+ // outlined (e.g. one of its regions has more than one block).
+ if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
+ std::string("outlined_then") + strCount, &elseFn,
+ std::string("outlined_else") + strCount))) {
+ this->signalPassFailure();
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
});
}
};
More information about the Mlir-commits
mailing list