[Mlir-commits] [mlir] [mlir] Use `OpBuilder::createBlock` in op builders and patterns (PR #82770)

Matthias Springer llvmlistbot at llvm.org
Fri Feb 23 07:26:15 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82770

>From ac01c30980dcc031d03b802d9046299abf16c781 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 23 Feb 2024 15:24:46 +0000
Subject: [PATCH] [mlir] Use `OpBuilder::createBlock` in op builders and
 patterns

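When creating a new block in (conversion) rewrite patterns, `OpBuilder::createBlock` must be used; otherwise, no `notifyBlockInserted` notification is sent to the listener. This also applies to op builders and helpers such as `addEntryBlock`, because they may be called with a rewriter from within a pattern.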

Note: The dialect conversion relies on listener notifications to keep track of IR modifications. Blocks created without the builder API are not tracked, which can lead to memory leaks during rollback.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  2 +-
 .../Dialect/SPIRV/IR/SPIRVControlFlowOps.td   |  4 ++--
 .../mlir/Interfaces/FunctionInterfaces.td     |  4 ++--
 .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp    |  2 +-
 .../ControlFlowToSCF/ControlFlowToSCF.cpp     |  4 +---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |  4 ++--
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  4 +---
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 11 ++++-----
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 24 +++++++++++--------
 mlir/lib/Dialect/Async/IR/Async.cpp           | 17 +++++--------
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 10 ++++----
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    |  7 +++---
 .../Linalg/Transforms/DropUnitDims.cpp        |  6 ++---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  6 ++---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  9 ++++---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  3 ++-
 mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp  | 18 +++++++-------
 mlir/lib/Dialect/Shape/IR/Shape.cpp           | 16 ++++++-------
 .../Transforms/SparseTensorRewriting.cpp      |  4 +---
 .../SPIRV/Deserialization/Deserializer.cpp    |  4 ++--
 20 files changed, 71 insertions(+), 88 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3da5deeb4ec7e2..b523374f6c06b5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
   let extraClassDeclaration = [{
     // Add an entry block to an empty function, and set up the block arguments
     // to match the signature of the function.
-    Block *addEntryBlock();
+    Block *addEntryBlock(OpBuilder &builder);
 
     bool isVarArg() { return getFunctionType().isVarArg(); }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 36ad6755cab25e..991e753d1b3593 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -285,7 +285,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
 
     // Adds an empty entry block and loop merge block containing one
     // spirv.mlir.merge op.
-    void addEntryAndMergeBlock();
+    void addEntryAndMergeBlock(OpBuilder &builder);
   }];
 
   let hasOpcode = 0;
@@ -427,7 +427,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
     Block *getMergeBlock();
 
     /// Adds a selection merge block containing one spirv.mlir.merge op.
-    void addMergeBlock();
+    void addMergeBlock(OpBuilder &builder);
 
     /// Creates a spirv.mlir.selection op for `if (<condition>) then { <thenBody> }`
     /// with `builder`. `builder`'s insertion point will remain at after the
diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
index 970a781c998b98..873853eba0b175 100644
--- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td
+++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
@@ -131,6 +131,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
     static void buildWithEntryBlock(
         OpBuilder &builder, OperationState &state, StringRef name, Type type,
         ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
+      OpBuilder::InsertionGuard g(builder);
       state.addAttribute(SymbolTable::getSymbolAttrName(),
                         builder.getStringAttr(name));
       state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
@@ -139,8 +140,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
 
       // Add the function body.
       Region *bodyRegion = state.addRegion();
-      Block *body = new Block();
-      bodyRegion->push_back(body);
+      Block *body = builder.createBlock(bodyRegion);
       for (Type input : inputTypes)
         body->addArgument(input, state.location);
     }
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 0ab53ce7e3327e..77603739137614 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -259,7 +259,7 @@ static void addResumeFunction(ModuleOp module) {
       kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
   resumeOp.setPrivate();
 
-  auto *block = resumeOp.addEntryBlock();
+  auto *block = resumeOp.addEntryBlock(moduleBuilder);
   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
 
   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index 363e5f9b8cefe7..d3ee89743da9db 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -98,12 +98,10 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
       loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
       loopVariablesNextIter);
 
-  auto *afterBlock = new Block;
-  whileOp.getAfter().push_back(afterBlock);
+  Block *afterBlock = builder.createBlock(&whileOp.getAfter());
   afterBlock->addArguments(
       loopVariablesInit.getTypes(),
       SmallVector<Location>(loopVariablesInit.size(), loc));
-  builder.setInsertionPointToEnd(afterBlock);
   builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
 
   return whileOp.getOperation();
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index bd50c67fb87958..53b44aa3241bb1 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -135,7 +135,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
   propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
 
   OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
+  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
 
   SmallVector<Value, 8> args;
   size_t argOffset = resultStructType ? 1 : 0;
@@ -203,7 +203,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
 
   // The wrapper that we synthetize here should only be visible in this module.
   newFuncOp.setLinkage(LLVM::Linkage::Private);
-  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
+  builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
 
   // Get a ValueRange containing arguments.
   FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2bfca303b5fd48..2dc42f0a85e669 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -520,9 +520,7 @@ struct GlobalMemrefOpLowering
         global, arrayTy, global.getConstant(), linkage, global.getSymName(),
         initialValue, alignment, *addressSpace);
     if (!global.isExternal() && global.isUninitialized()) {
-      Block *blk = new Block();
-      newGlobal.getInitializerRegion().push_back(blk);
-      rewriter.setInsertionPointToStart(blk);
+      rewriter.createBlock(&newGlobal.getInitializerRegion());
       Value undef[] = {
           rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
       rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index febfe97f6c0a99..d90cf931385fcc 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -138,14 +138,13 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     // from header to merge.
     auto loc = forOp.getLoc();
     auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
-    loopOp.addEntryAndMergeBlock();
+    loopOp.addEntryAndMergeBlock(rewriter);
 
     OpBuilder::InsertionGuard guard(rewriter);
     // Create the block for the header.
-    auto *header = new Block();
-    // Insert the header.
-    loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1),
-                                        header);
+    Block *header = rewriter.createBlock(&loopOp.getBody(),
+                                         getBlockIt(loopOp.getBody(), 1));
+    rewriter.setInsertionPointAfter(loopOp);
 
     // Create the new induction variable to use.
     Value adapLowerBound = adaptor.getLowerBound();
@@ -342,7 +341,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = whileOp.getLoc();
     auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
-    loopOp.addEntryAndMergeBlock();
+    loopOp.addEntryAndMergeBlock(rewriter);
 
     Region &beforeRegion = whileOp.getBefore();
     Region &afterRegion = whileOp.getAfter();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c4b13193f4e773..a4df863ab08342 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1813,6 +1813,8 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result,
          "upper bound operand count does not match the affine map");
   assert(step > 0 && "step has to be a positive integer constant");
 
+  OpBuilder::InsertionGuard guard(builder);
+
   // Set variadic segment sizes.
   result.addAttribute(
       getOperandSegmentSizeAttr(),
@@ -1841,12 +1843,11 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result,
   // Create a region and a block for the body.  The argument of the region is
   // the loop induction variable.
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
+  Block *bodyBlock = builder.createBlock(bodyRegion);
   Value inductionVar =
-      bodyBlock.addArgument(builder.getIndexType(), result.location);
+      bodyBlock->addArgument(builder.getIndexType(), result.location);
   for (Value val : iterArgs)
-    bodyBlock.addArgument(val.getType(), val.getLoc());
+    bodyBlock->addArgument(val.getType(), val.getLoc());
 
   // Create the default terminator if the builder is not provided and if the
   // iteration arguments are not provided. Otherwise, leave this to the caller
@@ -1855,9 +1856,9 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result,
     ensureTerminator(*bodyRegion, builder, result.location);
   } else if (bodyBuilder) {
     OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(&bodyBlock);
+    builder.setInsertionPointToStart(bodyBlock);
     bodyBuilder(builder, result.location, inductionVar,
-                bodyBlock.getArguments().drop_front());
+                bodyBlock->getArguments().drop_front());
   }
 }
 
@@ -2895,18 +2896,20 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                        TypeRange resultTypes, IntegerSet set, ValueRange args,
                        bool withElseRegion) {
   assert(resultTypes.empty() || withElseRegion);
+  OpBuilder::InsertionGuard guard(builder);
+
   result.addTypes(resultTypes);
   result.addOperands(args);
   result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
 
   Region *thenRegion = result.addRegion();
-  thenRegion->push_back(new Block());
+  builder.createBlock(thenRegion);
   if (resultTypes.empty())
     AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
 
   Region *elseRegion = result.addRegion();
   if (withElseRegion) {
-    elseRegion->push_back(new Block());
+    builder.createBlock(elseRegion);
     if (resultTypes.empty())
       AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
   }
@@ -3693,6 +3696,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
          "expected upper bound maps to have as many inputs as upper bound "
          "operands");
 
+  OpBuilder::InsertionGuard guard(builder);
   result.addTypes(resultTypes);
 
   // Convert the reductions to integer attributes.
@@ -3738,11 +3742,11 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
 
   // Create a region and a block for the body.
   auto *bodyRegion = result.addRegion();
-  auto *body = new Block();
+  Block *body = builder.createBlock(bodyRegion);
+
   // Add all the block arguments.
   for (unsigned i = 0, e = steps.size(); i < e; ++i)
     body->addArgument(IndexType::get(builder.getContext()), result.location);
-  bodyRegion->push_back(body);
   if (resultTypes.empty())
     ensureTerminator(*bodyRegion, builder, result.location);
 }
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 5f583f36cd2cb8..a3e3f80954efce 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -68,7 +68,7 @@ void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, ValueRange dependencies,
                       ValueRange operands, BodyBuilderFn bodyBuilder) {
-
+  OpBuilder::InsertionGuard guard(builder);
   result.addOperands(dependencies);
   result.addOperands(operands);
 
@@ -87,26 +87,21 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
 
   // Add a body region with block arguments as unwrapped async value operands.
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
+  Block *bodyBlock = builder.createBlock(bodyRegion);
   for (Value operand : operands) {
     auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
-    bodyBlock.addArgument(valueType ? valueType.getValueType()
-                                    : operand.getType(),
-                          operand.getLoc());
+    bodyBlock->addArgument(valueType ? valueType.getValueType()
+                                     : operand.getType(),
+                           operand.getLoc());
   }
 
   // Create the default terminator if the builder is not provided and if the
   // expected result is empty. Otherwise, leave this to the caller
   // because we don't know which values to return from the execute op.
   if (resultTypes.empty() && !bodyBuilder) {
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(&bodyBlock);
     builder.create<async::YieldOp>(result.location, ValueRange());
   } else if (bodyBuilder) {
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(&bodyBlock);
-    bodyBuilder(builder, result.location, bodyBlock.getArguments());
+    bodyBuilder(builder, result.location, bodyBlock->getArguments());
   }
 }
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 0fe2c0dcfc7c53..4df8149b94c95f 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -261,20 +261,20 @@ LogicalResult ExpressionOp::verify() {
 
 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
                   Value ub, Value step, BodyBuilderFn bodyBuilder) {
+  OpBuilder::InsertionGuard g(builder);
   result.addOperands({lb, ub, step});
   Type t = lb.getType();
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
-  bodyBlock.addArgument(t, result.location);
+  Block *bodyBlock = builder.createBlock(bodyRegion);
+  bodyBlock->addArgument(t, result.location);
 
   // Create the default terminator if the builder is not provided.
   if (!bodyBuilder) {
     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
   } else {
     OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(&bodyBlock);
-    bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+    builder.setInsertionPointToStart(bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
   }
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f4042a60541a6a..3ba6ac6ccc8142 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2165,11 +2165,10 @@ LogicalResult ShuffleVectorOp::verify() {
 //===----------------------------------------------------------------------===//
 
 // Add the entry block to the function.
-Block *LLVMFuncOp::addEntryBlock() {
+Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
   assert(empty() && "function already has an entry block");
-
-  auto *entry = new Block;
-  push_back(entry);
+  OpBuilder::InsertionGuard g(builder);
+  Block *entry = builder.createBlock(&getBody());
 
   // FIXME: Allow passing in proper locations for the entry arguments.
   LLVMFunctionType type = getFunctionType();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 559ffda4494d2b..c46e3694b70ecd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -132,12 +132,10 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
         newIndexingMaps, genericOp.getIteratorTypesArray(),
         /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
 
+    OpBuilder::InsertionGuard guard(rewriter);
     Region &region = newOp.getRegion();
-    Block *block = new Block();
-    region.push_back(block);
+    Block *block = rewriter.createBlock(&region);
     IRMapping mapper;
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(block);
     for (auto bbarg : genericOp.getRegionInputArgs())
       mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 286b07669a47f5..0d8d670904f2a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -178,11 +178,9 @@ static void generateFusedElementwiseOpRegion(
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
-  Block *fusedBlock = new Block();
-  fusedOp.getRegion().push_back(fusedBlock);
-  IRMapping mapper;
   OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(fusedBlock);
+  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+  IRMapping mapper;
 
   // 2. Add an index operation for every fused loop dimension and use the
   // `consumerToProducerLoopsMap` to map the producer indices.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5d220c6cdd7e58..43c408a97687ce 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -275,14 +275,13 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
   auto transposeOp =
       b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                           indexingMaps, iteratorTypes);
-  Region &body = transposeOp.getRegion();
-  body.push_back(new Block());
-  body.front().addArguments({elementType, elementType}, {loc, loc});
 
   // Create the body of the transpose operation.
   OpBuilder::InsertionGuard g(b);
-  b.setInsertionPointToEnd(&body.front());
-  b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
+  Region &body = transposeOp.getRegion();
+  Block *bodyBlock = b.createBlock(&body, /*insertPt=*/{},
+                                   {elementType, elementType}, {loc, loc});
+  b.create<YieldOp>(loc, bodyBlock->getArgument(0));
   return transposeOp;
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index beb7e721ca53b8..248193481acfc6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1420,6 +1420,7 @@ OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
 
 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                                Value memref, ValueRange ivs) {
+  OpBuilder::InsertionGuard g(builder);
   result.addOperands(memref);
   result.addOperands(ivs);
 
@@ -1428,7 +1429,7 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
     result.addTypes(elementType);
 
     Region *bodyRegion = result.addRegion();
-    bodyRegion->push_back(new Block());
+    builder.createBlock(bodyRegion);
     bodyRegion->addArgument(elementType, memref.getLoc());
   }
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..7170a899069ee3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -365,12 +365,11 @@ Block *LoopOp::getMergeBlock() {
   return &getBody().back();
 }
 
-void LoopOp::addEntryAndMergeBlock() {
+void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
   assert(getBody().empty() && "entry and merge block already exist");
-  getBody().push_back(new Block());
-  auto *mergeBlock = new Block();
-  getBody().push_back(mergeBlock);
-  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+  OpBuilder::InsertionGuard g(builder);
+  builder.createBlock(&getBody());
+  builder.createBlock(&getBody());
 
   // Add a spirv.mlir.merge op into the merge block.
   builder.create<spirv::MergeOp>(getLoc());
@@ -525,11 +524,10 @@ Block *SelectionOp::getMergeBlock() {
   return &getBody().back();
 }
 
-void SelectionOp::addMergeBlock() {
+void SelectionOp::addMergeBlock(OpBuilder &builder) {
   assert(getBody().empty() && "entry and merge block already exist");
-  auto *mergeBlock = new Block();
-  getBody().push_back(mergeBlock);
-  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+  OpBuilder::InsertionGuard guard(builder);
+  builder.createBlock(&getBody());
 
   // Add a spirv.mlir.merge op into the merge block.
   builder.create<spirv::MergeOp>(getLoc());
@@ -542,7 +540,7 @@ SelectionOp::createIfThen(Location loc, Value condition,
   auto selectionOp =
       builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
 
-  selectionOp.addMergeBlock();
+  selectionOp.addMergeBlock(builder);
   Block *mergeBlock = selectionOp.getMergeBlock();
   Block *thenBlock = nullptr;
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4f829db1305c85..d9ee39a4e8dd32 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -375,15 +375,13 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
 void AssumingOp::build(
     OpBuilder &builder, OperationState &result, Value witness,
     function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
+  OpBuilder::InsertionGuard g(builder);
 
   result.addOperands(witness);
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
+  builder.createBlock(bodyRegion);
 
   // Build body.
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToStart(&bodyBlock);
   SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
   builder.create<AssumingYieldOp>(result.location, yieldValues);
 
@@ -1904,23 +1902,23 @@ bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 
 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                      ValueRange initVals) {
+  OpBuilder::InsertionGuard g(builder);
   result.addOperands(shape);
   result.addOperands(initVals);
 
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
-  Block &bodyBlock = bodyRegion->front();
-  bodyBlock.addArgument(builder.getIndexType(), result.location);
+  Block *bodyBlock = builder.createBlock(
+      bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
 
   Type elementType;
   if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
     elementType = tensorType.getElementType();
   else
     elementType = SizeType::get(builder.getContext());
-  bodyBlock.addArgument(elementType, shape.getLoc());
+  bodyBlock->addArgument(elementType, shape.getLoc());
 
   for (Value initVal : initVals) {
-    bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
+    bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
     result.addTypes(initVal.getType());
   }
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2ccb2361b5efe1..1bcc131781d34d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -299,8 +299,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     Block &prodBlock = prod.getRegion().front();
     Block &consBlock = op.getRegion().front();
     IRMapping mapper;
-    Block *fusedBlock = new Block();
-    fusedOp.getRegion().push_back(fusedBlock);
+    Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
     unsigned num = prodBlock.getNumArguments();
     for (unsigned i = 0; i < num - 1; i++)
       addArg(mapper, fusedBlock, prodBlock.getArgument(i));
@@ -309,7 +308,6 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     // Clone bodies of the producer and consumer in new evaluation order.
     auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
     auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
-    rewriter.setInsertionPointToStart(fusedBlock);
     Value last;
     for (auto &op : prodBlock.without_terminator())
       if (&op != acc) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 83ef01b4e3a467..bfd289eb350b59 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1829,7 +1829,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
 
   auto control = static_cast<spirv::SelectionControl>(selectionControl);
   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
-  selectionOp.addMergeBlock();
+  selectionOp.addMergeBlock(builder);
 
   return selectionOp;
 }
@@ -1841,7 +1841,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
 
   auto control = static_cast<spirv::LoopControl>(loopControl);
   auto loopOp = builder.create<spirv::LoopOp>(location, control);
-  loopOp.addEntryAndMergeBlock();
+  loopOp.addEntryAndMergeBlock(builder);
 
   return loopOp;
 }



More information about the Mlir-commits mailing list