[Mlir-commits] [mlir] 1fb24ce - [mlir][scf] Make whileOp builder funcs optional
Ivan Butygin
llvmlistbot at llvm.org
Thu Apr 13 09:50:27 PDT 2023
Author: Ivan Butygin
Date: 2023-04-13T18:49:04+02:00
New Revision: 1fb24cef40d26d2533b2786b21ec795cb1953133
URL: https://github.com/llvm/llvm-project/commit/1fb24cef40d26d2533b2786b21ec795cb1953133
DIFF: https://github.com/llvm/llvm-project/commit/1fb24cef40d26d2533b2786b21ec795cb1953133.diff
LOG: [mlir][scf] Make whileOp builder funcs optional
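When a builder callback is null, WhileOp::build still creates the corresponding region's block (with its arguments) but leaves it empty, without a terminator.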
Differential Revision: https://reviews.llvm.org/D148137
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b43615c9f1933..ed6d9f25e296c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3066,9 +3066,6 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState, TypeRange resultTypes,
ValueRange operands, BodyBuilderFn beforeBuilder,
BodyBuilderFn afterBuilder) {
- assert(beforeBuilder && "the builder callback for 'before' must be present");
- assert(afterBuilder && "the builder callback for 'after' must be present");
-
odsState.addOperands(operands);
odsState.addTypes(resultTypes);
@@ -3084,7 +3081,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
Region *beforeRegion = odsState.addRegion();
Block *beforeBlock = odsBuilder.createBlock(
beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
- beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+ if (beforeBuilder)
+ beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
// Build after region.
SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
@@ -3092,7 +3090,9 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
Region *afterRegion = odsState.addRegion();
Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
resultTypes, afterArgLocs);
- afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+
+ if (afterBuilder)
+ afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
OperandRange WhileOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
@@ -3811,13 +3811,11 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
return rewriter.notifyMatchFailure(op, "No results to remove");
ValueRange argsRange(newArgs);
- auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
- // Nothing
- };
Location loc = op.getLoc();
auto newWhileOp = rewriter.create<scf::WhileOp>(
- loc, argsRange.getTypes(), op.getInits(), emptyBuilder, emptyBuilder);
+ loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
+ /*afterBody*/ nullptr);
Block &newBeforeBlock = newWhileOp.getBefore().front();
Block &newAfterBlock = newWhileOp.getAfter().front();
@@ -3878,13 +3876,10 @@ struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern<WhileOp> {
beforeBlock.eraseArguments(argsToRemove);
- auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
- // Nothing
- };
-
Location loc = op.getLoc();
- auto newWhileOp = rewriter.create<WhileOp>(
- loc, op->getResultTypes(), newInits, emptyBuilder, emptyBuilder);
+ auto newWhileOp =
+ rewriter.create<WhileOp>(loc, op->getResultTypes(), newInits,
+ /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
Block &newBeforeBlock = newWhileOp.getBefore().front();
Block &newAfterBlock = newWhileOp.getAfter().front();
More information about the Mlir-commits mailing list