[Mlir-commits] [mlir] 3adced3 - [mlir] Introduce callback-based builders to SCF Parallel and Reduce ops

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 16 11:51:40 PDT 2020


Author: Alex Zinenko
Date: 2020-06-16T20:51:32+02:00
New Revision: 3adced3494d07ac6072a9336cb8ae3802f660c7a

URL: https://github.com/llvm/llvm-project/commit/3adced3494d07ac6072a9336cb8ae3802f660c7a
DIFF: https://github.com/llvm/llvm-project/commit/3adced3494d07ac6072a9336cb8ae3802f660c7a.diff

LOG: [mlir] Introduce callback-based builders to SCF Parallel and Reduce ops

Similarly to `scf::ForOp`, introduce additional `function_ref` arguments to
the `::build` functions of SCF `ParallelOp` and `ReduceOp`. When provided,
these functions are called to populate the body of the respective operation
while the operation itself is being constructed; for `ParallelOp`, the
terminator is added after the callback returns. Exercise the new
`ParallelOp` builder in `collapseParallelLoops` (LoopUtils).

Differential Revision: https://reviews.llvm.org/D81872

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index a57d862d44ff..420583161cbe 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -328,7 +328,15 @@ def ParallelOp : SCF_Op<"parallel",
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
               "ValueRange lowerBounds, ValueRange upperBounds, "
-              "ValueRange steps, ValueRange initVals = {}">,
+              "ValueRange steps, ValueRange initVals, "
+              "function_ref<void (OpBuilder &, Location, "
+                                 "ValueRange, ValueRange)>"
+              "  bodyBuilderFn = nullptr">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "ValueRange lowerBounds, ValueRange upperBounds, "
+              "ValueRange steps, "
+              "function_ref<void (OpBuilder &, Location, ValueRange)>"
+              "  bodyBuilderFn = nullptr">,
   ];
 
   let extraClassDeclaration = [{
@@ -380,7 +388,9 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value operand">
+              "Value operand, "
+              "function_ref<void (OpBuilder &, Location, Value, Value)>"
+              "  bodyBuilderFn = nullptr">
   ];
 
   let arguments = (ins AnyType:$operand);

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index b883bad36f10..f980cdb96a6a 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -495,25 +495,56 @@ void IfOp::getSuccessorRegions(Optional<unsigned> index,
 // ParallelOp
 //===----------------------------------------------------------------------===//
 
-void ParallelOp::build(OpBuilder &builder, OperationState &result,
-                       ValueRange lbs, ValueRange ubs, ValueRange steps,
-                       ValueRange initVals) {
-  result.addOperands(lbs);
-  result.addOperands(ubs);
+void ParallelOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
+    ValueRange upperBounds, ValueRange steps, ValueRange initVals,
+    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+        bodyBuilderFn) {
+  result.addOperands(lowerBounds);
+  result.addOperands(upperBounds);
   result.addOperands(steps);
   result.addOperands(initVals);
   result.addAttribute(
       ParallelOp::getOperandSegmentSizeAttr(),
-      builder.getI32VectorAttr({static_cast<int32_t>(lbs.size()),
-                                static_cast<int32_t>(ubs.size()),
+      builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
+                                static_cast<int32_t>(upperBounds.size()),
                                 static_cast<int32_t>(steps.size()),
                                 static_cast<int32_t>(initVals.size())}));
+  result.addTypes(initVals.getTypes());
+
+  OpBuilder::InsertionGuard guard(builder);
+  unsigned numIVs = steps.size();
+  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
   Region *bodyRegion = result.addRegion();
+  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
+
+  if (bodyBuilderFn) {
+    builder.setInsertionPointToStart(bodyBlock);
+    bodyBuilderFn(builder, result.location,
+                  bodyBlock->getArguments().take_front(numIVs),
+                  bodyBlock->getArguments().drop_front(numIVs));
+  }
   ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
-  for (size_t i = 0, e = steps.size(); i < e; ++i)
-    bodyRegion->front().addArgument(builder.getIndexType());
-  for (Value init : initVals)
-    result.addTypes(init.getType());
+}
+
+void ParallelOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
+    ValueRange upperBounds, ValueRange steps,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
+  // we don't capture a reference to a temporary by constructing the lambda at
+  // function level.
+  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
+                                           Location nestedLoc, ValueRange ivs,
+                                           ValueRange) {
+    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
+  };
+  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
+  if (bodyBuilderFn)
+    wrapper = wrappedBuilderFn;
+
+  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
+        wrapper);
 }
 
 static LogicalResult verify(ParallelOp op) {
@@ -679,15 +710,18 @@ ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
 // ReduceOp
 //===----------------------------------------------------------------------===//
 
-void ReduceOp::build(OpBuilder &builder, OperationState &result,
-                     Value operand) {
+void ReduceOp::build(
+    OpBuilder &builder, OperationState &result, Value operand,
+    function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
   auto type = operand.getType();
   result.addOperands(operand);
-  Region *bodyRegion = result.addRegion();
 
-  Block *b = new Block();
-  b->addArguments(ArrayRef<Type>{type, type});
-  bodyRegion->getBlocks().insert(bodyRegion->end(), b);
+  OpBuilder::InsertionGuard guard(builder);
+  Region *bodyRegion = result.addRegion();
+  Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type});
+  if (bodyBuilderFn)
+    bodyBuilderFn(builder, result.location, body->getArgument(0),
+                  body->getArgument(1));
 }
 
 static LogicalResult verify(ReduceOp op) {

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 2d68f283e381..58e2d9b42043 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1472,33 +1472,34 @@ void mlir::collapseParallelLoops(
   // value. The remainders then determine based on that range, which iteration
   // of the original induction value this represents. This is a normalized value
   // that is un-normalized already by the previous logic.
-  auto newPloop = outsideBuilder.create<scf::ParallelOp>(loc, lowerBounds,
-                                                         upperBounds, steps);
-  OpBuilder insideBuilder(newPloop.region());
-  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
-    Value previous = newPloop.getBody()->getArgument(i);
-    unsigned numberCombinedDimensions = combinedDimensions[i].size();
-    // Iterate over all except the last induction value.
-    for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) {
-      unsigned idx = combinedDimensions[i][j];
-
-      // Determine the current induction value's current loop iteration
-      Value iv = insideBuilder.create<SignedRemIOp>(loc, previous,
-                                                    normalizedUpperBounds[idx]);
-      replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
-                                 loops.region());
-
-      // Remove the effect of the current induction value to prepare for the
-      // next value.
-      previous = insideBuilder.create<SignedDivIOp>(
-          loc, previous, normalizedUpperBounds[idx + 1]);
-    }
+  auto newPloop = outsideBuilder.create<scf::ParallelOp>(
+      loc, lowerBounds, upperBounds, steps,
+      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
+        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
+          Value previous = ploopIVs[i];
+          unsigned numberCombinedDimensions = combinedDimensions[i].size();
+          // Iterate over all except the last induction value.
+          for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) {
+            unsigned idx = combinedDimensions[i][j];
+
+            // Determine the current induction value's current loop iteration
+            Value iv = insideBuilder.create<SignedRemIOp>(
+                loc, previous, normalizedUpperBounds[idx]);
+            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
+                                       loops.region());
+
+            // Remove the effect of the current induction value to prepare for
+            // the next value.
+            previous = insideBuilder.create<SignedDivIOp>(
+                loc, previous, normalizedUpperBounds[idx + 1]);
+          }
 
-    // The final induction value is just the remaining value.
-    unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1];
-    replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), previous,
-                               loops.region());
-  }
+          // The final induction value is just the remaining value.
+          unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1];
+          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
+                                     previous, loops.region());
+        }
+      });
 
   // Replace the old loop with the new loop.
   loops.getBody()->back().erase();


        


More information about the Mlir-commits mailing list