[Mlir-commits] [mlir] cadb7cc - [mlir] SCF: provide function_ref builders for IfOp

Alex Zinenko llvmlistbot at llvm.org
Wed May 27 07:13:06 PDT 2020


Author: Alex Zinenko
Date: 2020-05-27T16:12:58+02:00
New Revision: cadb7ccf2cebcaa2d546db77223bde3d69a162af

URL: https://github.com/llvm/llvm-project/commit/cadb7ccf2cebcaa2d546db77223bde3d69a162af
DIFF: https://github.com/llvm/llvm-project/commit/cadb7ccf2cebcaa2d546db77223bde3d69a162af.diff

LOG: [mlir] SCF: provide function_ref builders for IfOp

Now that OpBuilder is available in `build` functions, it becomes possible to
populate the "then" and "else" regions directly when building the "if"
operation. This is desirable in more structured forms of builders, especially
when conditionals are mixed with loops. Provide new `build` APIs taking
callbacks for body constructors, similarly to scf::ForOp, and replace more
clunky edsc::BlockBuilder uses with these. The original APIs remain available
and go through the new implementation.

Differential Revision: https://reviews.llvm.org/D80527

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
    mlir/include/mlir/Dialect/SCF/SCF.h
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/SCF/EDSC/Builders.cpp
    mlir/lib/Dialect/SCF/SCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
index fa72bd623b25..607ea439d63a 100644
--- a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
@@ -82,6 +82,16 @@ scf::ValueVector loopNestBuilder(
     Value lb, Value ub, Value step, ValueRange iterArgInitValues,
     function_ref<scf::ValueVector(Value, ValueRange)> fun = nullptr);
 
+/// Adapters for building if conditions using the builder and the location
+/// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted
+/// if the condition should not have an 'else' part.
+ValueRange
+conditionBuilder(TypeRange results, Value condition,
+                 function_ref<scf::ValueVector()> thenBody,
+                 function_ref<scf::ValueVector()> elseBody = nullptr);
+ValueRange conditionBuilder(Value condition, function_ref<void()> thenBody,
+                            function_ref<void()> elseBody = nullptr);
+
 } // namespace edsc
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h
index 570e8a31bce4..3974b58cbfbb 100644
--- a/mlir/include/mlir/Dialect/SCF/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/SCF.h
@@ -24,6 +24,8 @@
 namespace mlir {
 namespace scf {
 
+void buildTerminatedBody(OpBuilder &builder, Location loc);
+
 #include "mlir/Dialect/SCF/SCFOpsDialect.h.inc"
 
 #define GET_OP_CLASSES

diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 84f164584003..a57d862d44ff 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -238,7 +238,18 @@ def IfOp : SCF_Op<"if",
     OpBuilder<"OpBuilder &builder, OperationState &result, "
               "Value cond, bool withElseRegion">,
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "TypeRange resultTypes, Value cond, bool withElseRegion">
+              "TypeRange resultTypes, Value cond, bool withElseRegion">,
+    OpBuilder<
+        "OpBuilder &builder, OperationState &result, TypeRange resultTypes, "
+        "Value cond, "
+        "function_ref<void(OpBuilder &, Location)> thenBuilder "
+        "    = buildTerminatedBody, "
+        "function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">,
+    OpBuilder<
+        "OpBuilder &builder, OperationState &result, Value cond, "
+        "function_ref<void(OpBuilder &, Location)> thenBuilder "
+        "    = buildTerminatedBody, "
+        "function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">
   ];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 9868a14c2165..8c72800819a5 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -235,39 +235,38 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
       SmallVector<Type, 1> resultType;
       if (options.unroll)
         resultType.push_back(vectorType);
-      auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
-          ScopedContext::getLocation(), resultType, inBoundsCondition,
-          /*withElseRegion=*/true);
-
-      // 3.a. If in-bounds, progressively lower to a 1-D transfer read.
-      BlockBuilder(&ifOp.thenRegion().front(), Append())([&] {
-        Value vector = load1DVector(majorIvsPlusOffsets);
-        // 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
-        // aggregate. We must yield and merge with the `else` branch.
-        if (options.unroll) {
-          vector = vector_insert(vector, result, majorIvs);
-          (loop_yield(vector));
-          return;
-        }
-        // 3.a.ii. Otherwise, just go through the temporary `alloc`.
-        std_store(vector, alloc, majorIvs);
-      });
-
-      // 3.b. If not in-bounds, splat a 1-D vector.
-      BlockBuilder(&ifOp.elseRegion().front(), Append())([&] {
-        Value vector = std_splat(minorVectorType, xferOp.padding());
-        // 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
-        // aggregate. We must yield and merge with the `then` branch.
-        if (options.unroll) {
-          vector = vector_insert(vector, result, majorIvs);
-          (loop_yield(vector));
-          return;
-        }
-        // 3.b.ii. Otherwise, just go through the temporary `alloc`.
-        std_store(vector, alloc, majorIvs);
-      });
+
+      // 3. If in-bounds, progressively lower to a 1-D transfer read, otherwise
+      // splat a 1-D vector.
+      ValueRange ifResults = conditionBuilder(
+          resultType, inBoundsCondition,
+          [&]() -> scf::ValueVector {
+            Value vector = load1DVector(majorIvsPlusOffsets);
+            // 3.a. If `options.unroll` is true, insert the 1-D vector in the
+            // aggregate. We must yield and merge with the `else` branch.
+            if (options.unroll) {
+              vector = vector_insert(vector, result, majorIvs);
+              return {vector};
+            }
+            // 3.b. Otherwise, just go through the temporary `alloc`.
+            std_store(vector, alloc, majorIvs);
+            return {};
+          },
+          [&]() -> scf::ValueVector {
+            Value vector = std_splat(minorVectorType, xferOp.padding());
+            // 3.c. If `options.unroll` is true, insert the 1-D vector in the
+            // aggregate. We must yield and merge with the `then` branch.
+            if (options.unroll) {
+              vector = vector_insert(vector, result, majorIvs);
+              return {vector};
+            }
+            // 3.d. Otherwise, just go through the temporary `alloc`.
+            std_store(vector, alloc, majorIvs);
+            return {};
+          });
+
       if (!resultType.empty())
-        result = *ifOp.results().begin();
+        result = *ifResults.begin();
     } else {
       // 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read.
       Value loaded1D = load1DVector(majorIvsPlusOffsets);
@@ -336,11 +335,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
     if (inBoundsCondition) {
       // 2.a. If the condition is not null, we need an IfOp, to write
       // conditionally. Progressively lower to a 1-D transfer write.
-      auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
-          ScopedContext::getLocation(), TypeRange{}, inBoundsCondition,
-          /*withElseRegion=*/false);
-      BlockBuilder(&ifOp.thenRegion().front(),
-                   Append())([&] { emitTransferWrite(majorIvsPlusOffsets); });
+      conditionBuilder(inBoundsCondition,
+                       [&] { emitTransferWrite(majorIvsPlusOffsets); });
     } else {
       // 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write.
       emitTransferWrite(majorIvsPlusOffsets);

diff  --git a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp
index 4ce701c1d7f9..090c72fcd91f 100644
--- a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp
@@ -159,3 +159,51 @@ mlir::scf::ValueVector mlir::edsc::loopNestBuilder(
                                 iterArgInitValues.end());
       });
 }
+
+static std::function<void(OpBuilder &, Location)>
+wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
+  (void)expectedTypes;
+  return [=](OpBuilder &builder, Location loc) {
+    ScopedContext context(builder, loc);
+    scf::ValueVector returned = body();
+    assert(ValueRange(returned).getTypes() == expectedTypes &&
+           "'if' body builder returned values of unexpected type");
+    builder.create<scf::YieldOp>(loc, returned);
+  };
+}
+
+ValueRange
+mlir::edsc::conditionBuilder(TypeRange results, Value condition,
+                             function_ref<scf::ValueVector()> thenBody,
+                             function_ref<scf::ValueVector()> elseBody) {
+  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
+  assert(thenBody && "thenBody is mandatory");
+
+  auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+      ScopedContext::getLocation(), results, condition,
+      wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
+  return ifOp.getResults();
+}
+
+static std::function<void(OpBuilder &, Location)>
+wrapZeroResultIfBody(function_ref<void()> body) {
+  return [=](OpBuilder &builder, Location loc) {
+    ScopedContext context(builder, loc);
+    body();
+    builder.create<scf::YieldOp>(loc);
+  };
+}
+
+ValueRange mlir::edsc::conditionBuilder(Value condition,
+                                        function_ref<void()> thenBody,
+                                        function_ref<void()> elseBody) {
+  assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
+  assert(thenBody && "thenBody is mandatory");
+
+  ScopedContext::getBuilderRef().create<scf::IfOp>(
+      ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
+      elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
+                     wrapZeroResultIfBody(elseBody))
+               : llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
+  return {};
+}

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index edcb5aafbe4e..e7c890c17841 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -35,6 +35,11 @@ SCFDialect::SCFDialect(MLIRContext *context)
       >();
 }
 
+/// Default callback for IfOp builders. Inserts a yield without arguments.
+void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
+  builder.create<scf::YieldOp>(loc);
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -338,20 +343,43 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
 
 void IfOp::build(OpBuilder &builder, OperationState &result,
                  TypeRange resultTypes, Value cond, bool withElseRegion) {
+  auto addTerminator = [&](OpBuilder &nested, Location loc) {
+    if (resultTypes.empty())
+      IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
+                             loc);
+  };
+
+  build(builder, result, resultTypes, cond, addTerminator,
+        withElseRegion ? addTerminator
+                       : function_ref<void(OpBuilder &, Location)>());
+}
+
+void IfOp::build(OpBuilder &builder, OperationState &result,
+                 TypeRange resultTypes, Value cond,
+                 function_ref<void(OpBuilder &, Location)> thenBuilder,
+                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
+  assert(thenBuilder && "the builder callback for 'then' must be present");
+
   result.addOperands(cond);
   result.addTypes(resultTypes);
 
+  OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
-  thenRegion->push_back(new Block());
-  if (resultTypes.empty())
-    IfOp::ensureTerminator(*thenRegion, builder, result.location);
+  builder.createBlock(thenRegion);
+  thenBuilder(builder, result.location);
 
   Region *elseRegion = result.addRegion();
-  if (withElseRegion) {
-    elseRegion->push_back(new Block());
-    if (resultTypes.empty())
-      IfOp::ensureTerminator(*elseRegion, builder, result.location);
-  }
+  if (!elseBuilder)
+    return;
+
+  builder.createBlock(elseRegion);
+  elseBuilder(builder, result.location);
+}
+
+void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
+                 function_ref<void(OpBuilder &, Location)> thenBuilder,
+                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
+  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
 }
 
 static LogicalResult verify(IfOp op) {


        


More information about the Mlir-commits mailing list