[Mlir-commits] [mlir] 2a01d7f - [mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp

Nicolas Vasilache llvmlistbot at llvm.org
Fri Aug 7 11:50:29 PDT 2020


Author: Nicolas Vasilache
Date: 2020-08-07T14:49:49-04:00
New Revision: 2a01d7f7b6487b87bfb4722d53fcba30129ded13

URL: https://github.com/llvm/llvm-project/commit/2a01d7f7b6487b87bfb4722d53fcba30129ded13
DIFF: https://github.com/llvm/llvm-project/commit/2a01d7f7b6487b87bfb4722d53fcba30129ded13.diff

LOG: [mlir][SCF] Add utility to outline the then and else branches of an scf.IfOp

Differential Revision: https://reviews.llvm.org/D85449

Added: 
    mlir/test/Transforms/scf-if-utils.mlir
    mlir/test/Transforms/scf-loop-utils.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Utils.h
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/lib/Dialect/SCF/Transforms/Utils.cpp
    mlir/test/lib/Transforms/TestSCFUtils.cpp

Removed: 
    mlir/test/Transforms/loop-utils.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
index 7f8ebd3a4260..9311d0fffbef 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -13,11 +13,15 @@
 #ifndef MLIR_DIALECT_SCF_UTILS_H_
 #define MLIR_DIALECT_SCF_UTILS_H_
 
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
+class FuncOp;
 class OpBuilder;
 class ValueRange;
 
 namespace scf {
+class IfOp;
 class ForOp;
 class ParallelOp;
 } // end namespace scf
@@ -46,5 +50,12 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                               ValueRange newYieldedValues,
                               bool replaceLoopResults = true);
 
+/// Outline the then and/or else regions of `ifOp` into new FuncOps inserted
+/// before the function containing `ifOp`, replacing each outlined region's
+/// body with a call to its FuncOp. If `thenFn` (resp. `elseFn`) is not null,
+/// `thenFnName` (resp. `elseFnName`) must be non-empty and the new FuncOp is
+/// returned through the pointer. Regions without blocks are not outlined.
+void outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
+                 StringRef thenFnName, FuncOp *elseFn, StringRef elseFnName);
 } // end namespace mlir
 #endif // MLIR_DIALECT_SCF_UTILS_H_

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 7a54ace0bf8f..341780c21c60 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -17,4 +17,5 @@ add_mlir_dialect_library(MLIRSCFTransforms
   MLIRSCF
   MLIRStandardOps
   MLIRSupport
+  MLIRTransformUtils
   )

diff  --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
index 6ae360a34abc..baa4608f1f9f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -13,7 +13,12 @@
 #include "mlir/Dialect/SCF/Utils.h"
 
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
@@ -71,3 +76,70 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
 
   return newLoop;
 }
+
+// Outlines the then and/or else region of `ifOp` into new FuncOps; the full
+// contract is documented on the declaration in mlir/Dialect/SCF/Utils.h.
+void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
+                       StringRef thenFnName, FuncOp *elseFn,
+                       StringRef elseFnName) {
+  Location loc = ifOp.getLoc();
+  MLIRContext *ctx = ifOp.getContext();
+
+  // Outlines the single block of `ifOrElseRegion` into a new FuncOp named
+  // `funcName`, inserted before the function enclosing `ifOp`, and rewrites
+  // the block to call that FuncOp and yield the call results.
+  auto outline = [&](Region &ifOrElseRegion, StringRef funcName) {
+    assert(!funcName.empty() && "Expected function name for outlining");
+    // Callers skip regions without blocks, so exactly one block is expected.
+    assert(ifOrElseRegion.getBlocks().size() == 1 &&
+           "Expected exactly one block");
+    Block &body = ifOrElseRegion.front();
+
+    FuncOp parentFunc = ifOp.getParentOfType<FuncOp>();
+    assert(parentFunc && "Expected scf.if to be nested in a FuncOp");
+
+    // Outline before current function.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(parentFunc);
+
+    // Values used in the region but defined above it become the arguments of
+    // the outlined function and the operands of the replacing call.
+    llvm::SetVector<Value> captures;
+    getUsedValuesDefinedAbove(ifOrElseRegion, captures);
+
+    ValueRange values(captures.getArrayRef());
+    FunctionType type =
+        FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
+    // NOTE: `funcName` is used verbatim; avoiding clashes with existing
+    // symbols is the caller's responsibility.
+    auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
+    b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    for (Operation &op : body.without_terminator())
+      b.clone(op, bvm);
+
+    // Return the remapped scf.yield operands. ReturnOp has no results, so
+    // only the operands are passed to its builder.
+    Operation *term = body.getTerminator();
+    SmallVector<Value, 4> returnOperands;
+    returnOperands.reserve(term->getNumOperands());
+    for (Value operand : term->getOperands())
+      returnOperands.push_back(bvm.lookup(operand));
+    b.create<ReturnOp>(loc, returnOperands);
+
+    // Replace the original body with a call to the outlined function followed
+    // by an scf.yield of the call results.
+    body.clear();
+    b.setInsertionPointToEnd(&body);
+    Operation *call = b.create<CallOp>(loc, outlinedFunc, values);
+    b.create<scf::YieldOp>(loc, call->getResults());
+    return outlinedFunc;
+  };
+
+  if (thenFn && !ifOp.thenRegion().empty())
+    *thenFn = outline(ifOp.thenRegion(), thenFnName);
+  if (elseFn && !ifOp.elseRegion().empty())
+    *elseFn = outline(ifOp.elseRegion(), elseFnName);
+}

diff  --git a/mlir/test/Transforms/scf-if-utils.mlir b/mlir/test/Transforms/scf-if-utils.mlir
new file mode 100644
index 000000000000..0b51d99aee6d
--- /dev/null
+++ b/mlir/test/Transforms/scf-if-utils.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-if-utils -split-input-file %s | FileCheck %s
+
+// -----
+
+// Both regions are outlined; captured values become function arguments and
+// the else branch simply returns the value it yielded.
+//      CHECK: func @outlined_then0(%[[A0:[a-zA-Z0-9]*]]: i1, %[[A1:[a-zA-Z0-9]*]]: memref<?xf32>) -> i8 {
+// CHECK-NEXT:   %[[R:[a-zA-Z0-9]*]] = "some_op"(%[[A0]], %[[A1]]) : (i1, memref<?xf32>) -> i8
+// CHECK-NEXT:   return %[[R]] : i8
+// CHECK-NEXT: }
+//      CHECK: func @outlined_else0(%[[ARG:[a-zA-Z0-9]*]]: i8) -> i8 {
+// CHECK-NEXT:   return %[[ARG]] : i8
+// CHECK-NEXT: }
+//      CHECK: func @outline_if_else(
+// CHECK-SAME:     %[[COND:[a-zA-Z0-9]*]]: i1, %{{.*}}: index, %[[B:[a-zA-Z0-9]*]]: memref<?xf32>, %[[C:[a-zA-Z0-9]*]]: i8
+// CHECK-NEXT:   %{{.*}} = scf.if %[[COND]] -> (i8) {
+// CHECK-NEXT:     %{{.*}} = call @outlined_then0(%[[COND]], %[[B]]) : (i1, memref<?xf32>) -> i8
+// CHECK-NEXT:     scf.yield %{{.*}} : i8
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     %{{.*}} = call @outlined_else0(%[[C]]) : (i8) -> i8
+// CHECK-NEXT:     scf.yield %{{.*}} : i8
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @outline_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  %r = scf.if %cond -> (i8) {
+    %r = "some_op"(%cond, %b) : (i1, memref<?xf32>) -> (i8)
+    scf.yield %r : i8
+  } else {
+    scf.yield %c : i8
+  }
+  return
+}
+
+// -----
+
+// Only the then region exists; it is outlined into a function with no results.
+//      CHECK: func @outlined_then0(%[[A0:[a-zA-Z0-9]*]]: i1, %[[A1:[a-zA-Z0-9]*]]: memref<?xf32>) {
+// CHECK-NEXT:   "some_op"(%[[A0]], %[[A1]]) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+//      CHECK: func @outline_if(
+// CHECK-SAME:     %[[COND:[a-zA-Z0-9]*]]: i1, %{{.*}}: index, %[[B:[a-zA-Z0-9]*]]: memref<?xf32>
+// CHECK-NEXT:   scf.if %[[COND]] {
+// CHECK-NEXT:     call @outlined_then0(%[[COND]], %[[B]]) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @outline_if(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+// An empty then region is outlined into a function with no arguments.
+//      CHECK: func @outlined_then0() {
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+//      CHECK: func @outlined_else0(%[[A0:[a-zA-Z0-9]*]]: i1, %[[A1:[a-zA-Z0-9]*]]: memref<?xf32>) {
+// CHECK-NEXT:   "some_op"(%[[A0]], %[[A1]]) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+//      CHECK: func @outline_empty_if_else(
+// CHECK-SAME:     %[[COND:[a-zA-Z0-9]*]]: i1, %{{.*}}: index, %[[B:[a-zA-Z0-9]*]]: memref<?xf32>
+// CHECK-NEXT:   scf.if %[[COND]] {
+// CHECK-NEXT:     call @outlined_then0() : () -> ()
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     call @outlined_else0(%[[COND]], %[[B]]) : (i1, memref<?xf32>) -> ()
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+  } else {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+  }
+  return
+}

diff  --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir
similarity index 94%
rename from mlir/test/Transforms/loop-utils.mlir
rename to mlir/test/Transforms/scf-loop-utils.mlir
index 3d3dadfba179..bebeee8ce3cf 100644
--- a/mlir/test/Transforms/loop-utils.mlir
+++ b/mlir/test/Transforms/scf-loop-utils.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
 
 // CHECK-LABEL: @hoist
 //  CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,

diff  --git a/mlir/test/lib/Transforms/TestSCFUtils.cpp b/mlir/test/lib/Transforms/TestSCFUtils.cpp
index ba06bbcc8860..4b99f2550cc8 100644
--- a/mlir/test/lib/Transforms/TestSCFUtils.cpp
+++ b/mlir/test/lib/Transforms/TestSCFUtils.cpp
@@ -21,9 +21,10 @@
 using namespace mlir;
 
 namespace {
-class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
+class TestSCFForUtilsPass
+    : public PassWrapper<TestSCFForUtilsPass, FunctionPass> {
 public:
-  explicit TestSCFUtilsPass() {}
+  explicit TestSCFForUtilsPass() {}
 
   void runOnFunction() override {
     FuncOp func = getFunction();
@@ -49,10 +50,36 @@ class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
       loop.erase();
   }
 };
+
+/// Outlines the then/else regions of every scf.if in the module into new
+/// functions `outlined_then<N>` / `outlined_else<N>`. This is a module pass
+/// because outlining inserts FuncOps next to the enclosing function, which a
+/// FunctionPass is not allowed to do; the module-wide counter `N` also keeps
+/// the generated symbol names distinct across functions.
+class TestSCFIfUtilsPass
+    : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
+public:
+  explicit TestSCFIfUtilsPass() {}
+
+  void runOnOperation() override {
+    int count = 0;
+    ModuleOp module = getOperation();
+    module.walk([&](scf::IfOp ifOp) {
+      auto strCount = std::to_string(count++);
+      FuncOp thenFn, elseFn;
+      OpBuilder b(ifOp);
+      outlineIfOp(b, ifOp, &thenFn, std::string("outlined_then") + strCount,
+                  &elseFn, std::string("outlined_else") + strCount);
+    });
+  }
+};
 } // end namespace
 
 namespace mlir {
 void registerTestSCFUtilsPass() {
-  PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
+  PassRegistration<TestSCFForUtilsPass>("test-scf-for-utils",
+                                        "test scf.for utils");
+  PassRegistration<TestSCFIfUtilsPass>("test-scf-if-utils",
+                                       "test scf.if utils");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list