[Mlir-commits] [mlir] 77533d7 - [mlir][SCF] Adding custom builder to SCF::WhileOp.
Benjamin Kramer
llvmlistbot at llvm.org
Tue Nov 15 09:16:55 PST 2022
Author: Mohammed Anany
Date: 2022-11-15T18:16:49+01:00
New Revision: 77533d79f79c6bbf9970234a9015714190218082
URL: https://github.com/llvm/llvm-project/commit/77533d79f79c6bbf9970234a9015714190218082
DIFF: https://github.com/llvm/llvm-project/commit/77533d79f79c6bbf9970234a9015714190218082.diff
LOG: [mlir][SCF] Adding custom builder to SCF::WhileOp.
This adds a builder similar to the one for SCF::IfOp: it lets users pass callbacks that populate the op's "before" and "after" regions. Refer to the IfOp builders for comparison.
Reviewed By: tpopp
Differential Revision: https://reviews.llvm.org/D137709
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 67a8e438d583b..2d880ac52d2c4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -935,7 +935,7 @@ def WhileOp : SCF_Op<"while",
Note that the types of region arguments need not to match with each other.
The op expects the operand types to match with argument types of the
- "before" region"; the result types to match with the trailing operand types
+ "before" region; the result types to match with the trailing operand types
of the terminator of the "before" region, and with the argument types of the
"after" region. The following scheme can be used to share the results of
some operations executed in the "before" region with the "after" region,
@@ -983,7 +983,16 @@ def WhileOp : SCF_Op<"while",
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
+ let builders = [
+ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+ ];
+
let extraClassDeclaration = [{
+ using BodyBuilderFn =
+ function_ref<void(OpBuilder &, Location, ValueRange)>;
+
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
ConditionOp getConditionOp();
YieldOp getYieldOp();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index ac055a71b5426..98d76f3771c8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -71,40 +71,32 @@ static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
SmallVector<Type> types = {elementTy, elementTy, elementTy};
SmallVector<Location> locations = {loc, loc, loc};
- auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
- Block *before =
- rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
- Block *after =
- rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
-
- // The conditional block of the while loop.
- {
- rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
- Value input = before->getArgument(0);
- Value zero = before->getArgument(2);
-
- Value inputNotZero = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, input, zero);
- rewriter.create<scf::ConditionOp>(loc, inputNotZero,
- before->getArguments());
- }
-
- // The body of the while loop: shift right until reaching a value of 0.
- {
- rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
- Value input = after->getArgument(0);
- Value leadingZeros = after->getArgument(1);
-
- auto one =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
- auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
- auto leadingZerosMinusOne =
- rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
-
- rewriter.create<scf::YieldOp>(
- loc,
- ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
- }
+ auto whileOp = rewriter.create<scf::WhileOp>(
+ loc, types, operands,
+ [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
+ // The conditional block of the while loop.
+ Value input = args[0];
+ Value zero = args[2];
+
+ Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, input, zero);
+ beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
+ },
+ [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
+ // The body of the while loop: shift right until reaching a value of 0.
+ Value input = args[0];
+ Value leadingZeros = args[1];
+
+ auto one = afterBuilder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 1));
+ auto shifted =
+ afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
+ auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
+ loc, resultTy, leadingZeros, one);
+
+ afterBuilder.create<scf::YieldOp>(
+ loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
+ });
rewriter.setInsertionPointAfter(whileOp);
rewriter.replaceOp(op, whileOp->getResult(1));
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b39edc0305002..118452aae10b3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2669,6 +2669,53 @@ LogicalResult ReduceReturnOp::verify() {
// WhileOp
//===----------------------------------------------------------------------===//
+/// Builds an `scf.while` op whose "before" and "after" regions each get a
+/// single block populated by the given callbacks.
+///
+/// The "before" block takes one argument per entry in `operands`, with the
+/// operand's type and location; the "after" block takes one argument per
+/// entry in `resultTypes`, located at the op's location. Each callback is
+/// invoked with the builder's insertion point inside its freshly created
+/// block and is responsible for creating the region terminator
+/// (`scf.condition` for "before", `scf.yield` for "after"). The builder's
+/// insertion point is restored on return.
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState, TypeRange resultTypes,
+                    ValueRange operands, BodyBuilderFn beforeBuilder,
+                    BodyBuilderFn afterBuilder) {
+  assert(beforeBuilder && "the builder callback for 'before' must be present");
+  assert(afterBuilder && "the builder callback for 'after' must be present");
+
+  odsState.addOperands(operands);
+  odsState.addTypes(resultTypes);
+
+  // Restore the caller's insertion point once both regions are populated.
+  OpBuilder::InsertionGuard guard(odsBuilder);
+
+  // The "before" region receives the loop-carried operands, so its block
+  // arguments must mirror the operand types (not the result types, which may
+  // differ in both type and count).
+  SmallVector<Location, 4> beforeArgLocs;
+  beforeArgLocs.reserve(operands.size());
+  for (Value operand : operands)
+    beforeArgLocs.push_back(operand.getLoc());
+
+  Region *beforeRegion = odsState.addRegion();
+  Block *beforeBlock = odsBuilder.createBlock(
+      beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
+  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+
+  // The "after" region receives the values forwarded by scf.condition, whose
+  // types match the op results. There is no operand to take a location from,
+  // so use the op's own location; the count must equal resultTypes.size().
+  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
+
+  Region *afterRegion = odsState.addRegion();
+  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
+                                             resultTypes, afterArgLocs);
+  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+}
+
OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 &&
"WhileOp is expected to branch only to the first region");
More information about the Mlir-commits
mailing list