[Mlir-commits] [mlir] 2a01d7f - [mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Aug 7 11:50:29 PDT 2020
Author: Nicolas Vasilache
Date: 2020-08-07T14:49:49-04:00
New Revision: 2a01d7f7b6487b87bfb4722d53fcba30129ded13
URL: https://github.com/llvm/llvm-project/commit/2a01d7f7b6487b87bfb4722d53fcba30129ded13
DIFF: https://github.com/llvm/llvm-project/commit/2a01d7f7b6487b87bfb4722d53fcba30129ded13.diff
LOG: [mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp
Differential Revision: https://reviews.llvm.org/D85449
Added:
mlir/test/Transforms/scf-if-utils.mlir
mlir/test/Transforms/scf-loop-utils.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Utils.h
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/lib/Dialect/SCF/Transforms/Utils.cpp
mlir/test/lib/Transforms/TestSCFUtils.cpp
Removed:
mlir/test/Transforms/loop-utils.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
index 7f8ebd3a4260..9311d0fffbef 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -13,11 +13,15 @@
#ifndef MLIR_DIALECT_SCF_UTILS_H_
#define MLIR_DIALECT_SCF_UTILS_H_
+#include "mlir/Support/LLVM.h"
+
namespace mlir {
+class FuncOp;
class OpBuilder;
class ValueRange;
namespace scf {
+class IfOp;
class ForOp;
class ParallelOp;
} // end namespace scf
@@ -46,5 +50,12 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
ValueRange newYieldedValues,
bool replaceLoopResults = true);
+/// Outline the then and/or else regions of `ifOp` as follows:
+/// - if `thenFn` is not null, `thenFnName` must be specified and the `then`
+/// region is outlined into a new FuncOp that is returned through `thenFn`.
+/// - if `elseFn` is not null, `elseFnName` must be specified and the `else`
+/// region is outlined into a new FuncOp that is returned through `elseFn`.
+/// Each outlined region must have at most one block. Values used in the region
+/// but defined above it become the new function's arguments, and the region
+/// body is replaced by a call to the new function whose results are yielded.
+/// An empty region is not outlined and its out-pointer is left unchanged.
+/// New functions are inserted right before the function enclosing `ifOp`.
+void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
+                 StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
} // end namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 7a54ace0bf8f..341780c21c60 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -17,4 +17,5 @@ add_mlir_dialect_library(MLIRSCFTransforms
MLIRSCF
MLIRStandardOps
MLIRSupport
- )
+ MLIRTransformUtils
+)
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
index 6ae360a34abc..baa4608f1f9f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -13,7 +13,12 @@
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#include "llvm/ADT/SetVector.h"
using namespace mlir;
@@ -71,3 +76,50 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
return newLoop;
}
+
+/// Outlines the then and/or else region of `ifOp` into new FuncOps.
+///
+/// For each requested region (non-null out-pointer and non-empty region):
+///   - values used inside the region but defined above it become the
+///     arguments of the new function, in first-use order;
+///   - the new function returns the types produced by `ifOp`;
+///   - the region body is cloned into the function, and its terminator
+///     becomes a `return` of the remapped yielded values;
+///   - the original region is then reduced to a call of the new function
+///     followed by an `scf.yield` of the call results.
+/// New functions are inserted right before the FuncOp enclosing `ifOp`.
+/// Regions that are empty are skipped and their out-pointer is left untouched.
+///
+/// Preconditions (asserted): `ifOp` is nested in a FuncOp, each outlined
+/// region has at most one block, and a non-empty name is given for it.
+/// NOTE(review): no uniqueness check is done on the function names; the
+/// caller must avoid clashing with existing symbols in the module.
+void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
+                       StringRef thenFnName, FuncOp *elseFn,
+                       StringRef elseFnName) {
+  Location loc = ifOp.getLoc();
+  MLIRContext *ctx = ifOp.getContext();
+  auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
+    assert(!funcName.empty() && "Expected function name for outlining");
+    assert(ifOrElseRegion.getBlocks().size() <= 1 &&
+           "Expected at most one block");
+
+    // Outline before the function enclosing `ifOp`; without such a parent
+    // there is no valid insertion point for the new function.
+    auto parentFunc = ifOp.getParentOfType<FuncOp>();
+    assert(parentFunc && "Expected scf.if to be nested in a FuncOp");
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(parentFunc);
+
+    llvm::SetVector<Value> captures;
+    getUsedValuesDefinedAbove(ifOrElseRegion, captures);
+
+    // Captured values become the function arguments; the results mirror the
+    // scf.if results, which are fed by the region's yield.
+    ValueRange values(captures.getArrayRef());
+    FunctionType type =
+        FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
+    auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
+    b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    for (Operation &op : ifOrElseRegion.front().without_terminator())
+      b.clone(op, bvm);
+
+    // Turn the region terminator into a return of the remapped yielded
+    // values. Yielded values defined above the region were captured as
+    // arguments and those defined inside were cloned, so every operand has a
+    // mapping. A return has no results, so only operands are forwarded.
+    Operation *term = ifOrElseRegion.front().getTerminator();
+    SmallVector<Value, 4> returnOperands;
+    returnOperands.reserve(term->getNumOperands());
+    for (Value yielded : term->getOperands())
+      returnOperands.push_back(bvm.lookup(yielded));
+    b.create<ReturnOp>(loc, returnOperands);
+
+    // Replace the original region body with a call to the outlined function.
+    ifOrElseRegion.front().clear();
+    b.setInsertionPointToEnd(&ifOrElseRegion.front());
+    Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
+    b.create<scf::YieldOp>(loc, call->getResults());
+    return outlinedFunc;
+  };
+
+  // Empty regions (e.g. a missing else) are not outlined; the corresponding
+  // out-pointer is left as-is.
+  if (thenFn && !ifOp.thenRegion().empty())
+    *thenFn = outline(ifOp.thenRegion(), thenFnName);
+  if (elseFn && !ifOp.elseRegion().empty())
+    *elseFn = outline(ifOp.elseRegion(), elseFnName);
+}
diff --git a/mlir/test/Transforms/scf-if-utils.mlir b/mlir/test/Transforms/scf-if-utils.mlir
new file mode 100644
index 000000000000..0b51d99aee6d
--- /dev/null
+++ b/mlir/test/Transforms/scf-if-utils.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) -> i8 {
+// CHECK-NEXT: %{{.*}} = "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
+// CHECK-NEXT: return %{{.*}} : i8
+// CHECK-NEXT: }
+// CHECK: func @outlined_else0(%{{.*}}: i8) -> i8 {
+// CHECK-NEXT: return %{{.*}}0 : i8
+// CHECK-NEXT: }
+// CHECK: func @outline_if_else(
+// CHECK-NEXT: %{{.*}} = scf.if %{{.*}} -> (i8) {
+// CHECK-NEXT: %{{.*}} = call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> i8
+// CHECK-NEXT: scf.yield %{{.*}} : i8
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %{{.*}} = call @outlined_else0(%{{.*}}) : (i8) -> i8
+// CHECK-NEXT: scf.yield %{{.*}} : i8
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// Both branches yield an i8: the then branch captures %cond and %b, while the
+// else branch only forwards %c, so its outlined function returns its argument.
+func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  %r = scf.if %cond -> (i8) {
+    %r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
+    scf.yield %r : i8
+  } else {
+    scf.yield %c : i8
+  }
+  return
+}
+
+// -----
+
+// CHECK: func @outlined_then0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
+// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func @outline_if(
+// CHECK-NEXT: scf.if %{{.*}} {
+// CHECK-NEXT: call @outlined_then0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// scf.if without results and without an else region: only the then branch is
+// outlined.
+func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+// CHECK: func @outlined_then0() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: memref<?xf32>) {
+// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func @outline_empty_if_else(
+// CHECK-NEXT: scf.if %{{.*}} {
+// CHECK-NEXT: call @outlined_then0() : () -> ()
+// CHECK-NEXT: } else {
+// CHECK-NEXT: call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// The then branch holds only its implicit terminator, so it is outlined into a
+// function with no arguments and an empty body; the else branch captures
+// %cond and %b.
+func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+  } else {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+  }
+  return
+}
diff --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir
similarity index 94%
rename from mlir/test/Transforms/loop-utils.mlir
rename to mlir/test/Transforms/scf-loop-utils.mlir
index 3d3dadfba179..bebeee8ce3cf 100644
--- a/mlir/test/Transforms/loop-utils.mlir
+++ b/mlir/test/Transforms/scf-loop-utils.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
// CHECK-LABEL: @hoist
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
diff --git a/mlir/test/lib/Transforms/TestSCFUtils.cpp b/mlir/test/lib/Transforms/TestSCFUtils.cpp
index ba06bbcc8860..4b99f2550cc8 100644
--- a/mlir/test/lib/Transforms/TestSCFUtils.cpp
+++ b/mlir/test/lib/Transforms/TestSCFUtils.cpp
@@ -21,9 +21,10 @@
using namespace mlir;
namespace {
-class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
+class TestSCFForUtilsPass
+ : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
public:
- explicit TestSCFUtilsPass() {}
+ explicit TestSCFForUtilsPass() {}
void runOnFunction() override {
FuncOp func = getFunction();
@@ -49,10 +50,31 @@ class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
loop.erase();
}
};
+
+/// Test pass that outlines the then/else regions of every scf.if in the
+/// module into functions named `outlined_then<N>` / `outlined_else<N>`, where
+/// N counts the scf.if ops visited across the whole module.
+///
+/// This must be a module pass rather than a FunctionPass: outlining inserts
+/// new FuncOps into the parent module, and function passes may run
+/// concurrently on sibling functions and must not modify operations outside
+/// the function they are scheduled on. A module-wide counter also keeps the
+/// generated symbol names unique when the module contains several functions.
+class TestSCFIfUtilsPass
+    : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
+public:
+  explicit TestSCFIfUtilsPass() {}
+
+  void runOnOperation() override {
+    int count = 0;
+    ModuleOp module = getOperation();
+    module.walk([&](scf::IfOp ifOp) {
+      auto strCount = std::to_string(count++);
+      FuncOp thenFn, elseFn;
+      OpBuilder b(ifOp);
+      outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
+                  &elseFn, std::string("outlined_else") + strCount);
+    });
+  }
+};
} // end namespace
namespace mlir {
+/// Registers the scf.for and scf.if utility test passes under the
+/// `test-scf-for-utils` and `test-scf-if-utils` pass arguments.
void registerTestSCFUtilsPass() {
-  PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
+  PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
+                                        "test scf.for utils");
+  PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
+                                       "test scf.if utils");
}
} // namespace mlir
More information about the Mlir-commits
mailing list