[Mlir-commits] [mlir] 1fb24ce - [mlir][scf] Make whileOp builder funcs optional

Ivan Butygin llvmlistbot at llvm.org
Thu Apr 13 09:50:27 PDT 2023


Author: Ivan Butygin
Date: 2023-04-13T18:49:04+02:00
New Revision: 1fb24cef40d26d2533b2786b21ec795cb1953133

URL: https://github.com/llvm/llvm-project/commit/1fb24cef40d26d2533b2786b21ec795cb1953133
DIFF: https://github.com/llvm/llvm-project/commit/1fb24cef40d26d2533b2786b21ec795cb1953133.diff

LOG: [mlir][scf] Make whileOp builder funcs optional

When a builder callback is not provided (null), create an empty block without a terminator for that region.

Differential Revision: https://reviews.llvm.org/D148137

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b43615c9f1933..ed6d9f25e296c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3066,9 +3066,6 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
                     ::mlir::OperationState &odsState, TypeRange resultTypes,
                     ValueRange operands, BodyBuilderFn beforeBuilder,
                     BodyBuilderFn afterBuilder) {
-  assert(beforeBuilder && "the builder callback for 'before' must be present");
-  assert(afterBuilder && "the builder callback for 'after' must be present");
-
   odsState.addOperands(operands);
   odsState.addTypes(resultTypes);
 
@@ -3084,7 +3081,8 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
   Region *beforeRegion = odsState.addRegion();
   Block *beforeBlock = odsBuilder.createBlock(
       beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
-  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+  if (beforeBuilder)
+    beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
 
   // Build after region.
   SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
@@ -3092,7 +3090,9 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
   Region *afterRegion = odsState.addRegion();
   Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
                                              resultTypes, afterArgLocs);
-  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+
+  if (afterBuilder)
+    afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
 OperandRange WhileOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
@@ -3811,13 +3811,11 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
       return rewriter.notifyMatchFailure(op, "No results to remove");
 
     ValueRange argsRange(newArgs);
-    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
-      // Nothing
-    };
 
     Location loc = op.getLoc();
     auto newWhileOp = rewriter.create<scf::WhileOp>(
-        loc, argsRange.getTypes(), op.getInits(), emptyBuilder, emptyBuilder);
+        loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
+        /*afterBody*/ nullptr);
     Block &newBeforeBlock = newWhileOp.getBefore().front();
     Block &newAfterBlock = newWhileOp.getAfter().front();
 
@@ -3878,13 +3876,10 @@ struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern<WhileOp> {
 
     beforeBlock.eraseArguments(argsToRemove);
 
-    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
-      // Nothing
-    };
-
     Location loc = op.getLoc();
-    auto newWhileOp = rewriter.create<WhileOp>(
-        loc, op->getResultTypes(), newInits, emptyBuilder, emptyBuilder);
+    auto newWhileOp =
+        rewriter.create<WhileOp>(loc, op->getResultTypes(), newInits,
+                                 /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
     Block &newBeforeBlock = newWhileOp.getBefore().front();
     Block &newAfterBlock = newWhileOp.getAfter().front();
 


        


More information about the Mlir-commits mailing list