[Mlir-commits] [mlir] 3adced3 - [mlir] Introduce callback-based builders to SCF Parallel and Reduce ops
Alex Zinenko
llvmlistbot at llvm.org
Tue Jun 16 11:51:40 PDT 2020
Author: Alex Zinenko
Date: 2020-06-16T20:51:32+02:00
New Revision: 3adced3494d07ac6072a9336cb8ae3802f660c7a
URL: https://github.com/llvm/llvm-project/commit/3adced3494d07ac6072a9336cb8ae3802f660c7a
DIFF: https://github.com/llvm/llvm-project/commit/3adced3494d07ac6072a9336cb8ae3802f660c7a.diff
LOG: [mlir] Introduce callback-based builders to SCF Parallel and Reduce ops
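As with `scf::ForOp`, add optional `function_ref` callback arguments to the
`::build` functions of the SCF `ParallelOp` and `ReduceOp`. When provided, the
callback is invoked to construct the body of the respective operation while the
operation itself is being constructed. `ParallelOp` also gains a separate
builder overload without init values, whose callback receives only the
induction variables. Exercise the new builders in LoopUtils.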
Differential Revision: https://reviews.llvm.org/D81872
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index a57d862d44ff..420583161cbe 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -328,7 +328,15 @@ def ParallelOp : SCF_Op<"parallel",
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
"ValueRange lowerBounds, ValueRange upperBounds, "
- "ValueRange steps, ValueRange initVals = {}">,
+ "ValueRange steps, ValueRange initVals, "
+ "function_ref<void (OpBuilder &, Location, "
+ "ValueRange, ValueRange)>"
+ " bodyBuilderFn = nullptr">,
+ OpBuilder<"OpBuilder &builder, OperationState &result, "
+ "ValueRange lowerBounds, ValueRange upperBounds, "
+ "ValueRange steps, "
+ "function_ref<void (OpBuilder &, Location, ValueRange)>"
+ " bodyBuilderFn = nullptr">,
];
let extraClassDeclaration = [{
@@ -380,7 +388,9 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
- "Value operand">
+ "Value operand, "
+ "function_ref<void (OpBuilder &, Location, Value, Value)>"
+ " bodyBuilderFn = nullptr">
];
let arguments = (ins AnyType:$operand);
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index b883bad36f10..f980cdb96a6a 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -495,25 +495,56 @@ void IfOp::getSuccessorRegions(Optional<unsigned> index,
// ParallelOp
//===----------------------------------------------------------------------===//
-void ParallelOp::build(OpBuilder &builder, OperationState &result,
- ValueRange lbs, ValueRange ubs, ValueRange steps,
- ValueRange initVals) {
- result.addOperands(lbs);
- result.addOperands(ubs);
+void ParallelOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
+ ValueRange upperBounds, ValueRange steps, ValueRange initVals,
+ function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+ bodyBuilderFn) {
+ result.addOperands(lowerBounds);
+ result.addOperands(upperBounds);
result.addOperands(steps);
result.addOperands(initVals);
result.addAttribute(
ParallelOp::getOperandSegmentSizeAttr(),
- builder.getI32VectorAttr({static_cast<int32_t>(lbs.size()),
- static_cast<int32_t>(ubs.size()),
+ builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
+ static_cast<int32_t>(upperBounds.size()),
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(initVals.size())}));
+ result.addTypes(initVals.getTypes());
+
+ OpBuilder::InsertionGuard guard(builder);
+ unsigned numIVs = steps.size();
+ SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
Region *bodyRegion = result.addRegion();
+ Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
+
+ if (bodyBuilderFn) {
+ builder.setInsertionPointToStart(bodyBlock);
+ bodyBuilderFn(builder, result.location,
+ bodyBlock->getArguments().take_front(numIVs),
+ bodyBlock->getArguments().drop_front(numIVs));
+ }
ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
- for (size_t i = 0, e = steps.size(); i < e; ++i)
- bodyRegion->front().addArgument(builder.getIndexType());
- for (Value init : initVals)
- result.addTypes(init.getType());
+}
+
+void ParallelOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
+ ValueRange upperBounds, ValueRange steps,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
+ // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
+ // we don't capture a reference to a temporary by constructing the lambda at
+ // function level.
+ auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
+ Location nestedLoc, ValueRange ivs,
+ ValueRange) {
+ bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
+ };
+ function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
+ if (bodyBuilderFn)
+ wrapper = wrappedBuilderFn;
+
+ build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
+ wrapper);
}
static LogicalResult verify(ParallelOp op) {
@@ -679,15 +710,18 @@ ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
// ReduceOp
//===----------------------------------------------------------------------===//
-void ReduceOp::build(OpBuilder &builder, OperationState &result,
- Value operand) {
+void ReduceOp::build(
+ OpBuilder &builder, OperationState &result, Value operand,
+ function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
auto type = operand.getType();
result.addOperands(operand);
- Region *bodyRegion = result.addRegion();
- Block *b = new Block();
- b->addArguments(ArrayRef<Type>{type, type});
- bodyRegion->getBlocks().insert(bodyRegion->end(), b);
+ OpBuilder::InsertionGuard guard(builder);
+ Region *bodyRegion = result.addRegion();
+ Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type});
+ if (bodyBuilderFn)
+ bodyBuilderFn(builder, result.location, body->getArgument(0),
+ body->getArgument(1));
}
static LogicalResult verify(ReduceOp op) {
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 2d68f283e381..58e2d9b42043 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1472,33 +1472,34 @@ void mlir::collapseParallelLoops(
// value. The remainders then determine based on that range, which iteration
// of the original induction value this represents. This is a normalized value
// that is un-normalized already by the previous logic.
- auto newPloop = outsideBuilder.create<scf::ParallelOp>(loc, lowerBounds,
- upperBounds, steps);
- OpBuilder insideBuilder(newPloop.region());
- for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
- Value previous = newPloop.getBody()->getArgument(i);
- unsigned numberCombinedDimensions = combinedDimensions[i].size();
- // Iterate over all except the last induction value.
- for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) {
- unsigned idx = combinedDimensions[i][j];
-
- // Determine the current induction value's current loop iteration
- Value iv = insideBuilder.create<SignedRemIOp>(loc, previous,
- normalizedUpperBounds[idx]);
- replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
- loops.region());
-
- // Remove the effect of the current induction value to prepare for the
- // next value.
- previous = insideBuilder.create<SignedDivIOp>(
- loc, previous, normalizedUpperBounds[idx + 1]);
- }
+ auto newPloop = outsideBuilder.create<scf::ParallelOp>(
+ loc, lowerBounds, upperBounds, steps,
+ [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
+ for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
+ Value previous = ploopIVs[i];
+ unsigned numberCombinedDimensions = combinedDimensions[i].size();
+ // Iterate over all except the last induction value.
+ for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) {
+ unsigned idx = combinedDimensions[i][j];
+
+ // Determine the current induction value's current loop iteration
+ Value iv = insideBuilder.create<SignedRemIOp>(
+ loc, previous, normalizedUpperBounds[idx]);
+ replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
+ loops.region());
+
+ // Remove the effect of the current induction value to prepare for
+ // the next value.
+ previous = insideBuilder.create<SignedDivIOp>(
+ loc, previous, normalizedUpperBounds[idx + 1]);
+ }
- // The final induction value is just the remaining value.
- unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1];
- replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), previous,
- loops.region());
- }
+ // The final induction value is just the remaining value.
+ unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1];
+ replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
+ previous, loops.region());
+ }
+ });
// Replace the old loop with the new loop.
loops.getBody()->back().erase();
More information about the Mlir-commits mailing list