[Mlir-commits] [mlir] [mlir] Use `OpBuilder::createBlock` in op builders and patterns (PR #82770)
Matthias Springer
llvmlistbot at llvm.org
Fri Feb 23 06:21:34 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/82770
When creating a new block in (conversion) rewrite patterns, `OpBuilder::createBlock` must be used. Otherwise, no `notifyBlockInserted` notification is sent to the listener.
Note: The dialect conversion relies on listener notifications to keep track of IR modifications. Creating blocks without the builder API can lead to memory leaks during rollback.
>From dffe3a9b9d52756fac42bbb477247397f15b43cb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 23 Feb 2024 14:20:26 +0000
Subject: [PATCH] [mlir] Use `OpBuilder::createBlock` in op builders and
patterns
When creating a new block in (conversion) rewrite patterns, `OpBuilder::createBlock` must be used. Otherwise, no `notifyBlockInserted` notification is sent to the listener.
Note: The dialect conversion relies on listener notifications to keep track of IR modifications. Creating blocks without the builder API can lead to memory leaks during rollback.
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 2 +-
.../Dialect/SPIRV/IR/SPIRVControlFlowOps.td | 4 ++--
.../mlir/Interfaces/FunctionInterfaces.td | 4 ++--
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 2 +-
.../ControlFlowToSCF/ControlFlowToSCF.cpp | 4 +---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 4 ++--
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 4 +---
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 11 ++++-----
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 24 +++++++++++--------
mlir/lib/Dialect/Async/IR/Async.cpp | 17 +++++--------
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 10 ++++----
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 7 +++---
.../Linalg/Transforms/DropUnitDims.cpp | 6 ++---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 6 ++---
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 9 ++++---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 3 ++-
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 18 +++++++-------
mlir/lib/Dialect/Shape/IR/Shape.cpp | 16 ++++++-------
.../Transforms/SparseTensorRewriting.cpp | 4 +---
.../SPIRV/Deserialization/Deserializer.cpp | 4 ++--
20 files changed, 71 insertions(+), 88 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3da5deeb4ec7e2..b523374f6c06b5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1456,7 +1456,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
let extraClassDeclaration = [{
// Add an entry block to an empty function, and set up the block arguments
// to match the signature of the function.
- Block *addEntryBlock();
+ Block *addEntryBlock(OpBuilder &builder);
bool isVarArg() { return getFunctionType().isVarArg(); }
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 36ad6755cab25e..991e753d1b3593 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -285,7 +285,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
// Adds an empty entry block and loop merge block containing one
// spirv.mlir.merge op.
- void addEntryAndMergeBlock();
+ void addEntryAndMergeBlock(OpBuilder &builder);
}];
let hasOpcode = 0;
@@ -427,7 +427,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
Block *getMergeBlock();
/// Adds a selection merge block containing one spirv.mlir.merge op.
- void addMergeBlock();
+ void addMergeBlock(OpBuilder &builder);
/// Creates a spirv.mlir.selection op for `if (<condition>) then { <thenBody> }`
/// with `builder`. `builder`'s insertion point will remain at after the
diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
index 970a781c998b98..873853eba0b175 100644
--- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td
+++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
@@ -131,6 +131,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
static void buildWithEntryBlock(
OpBuilder &builder, OperationState &state, StringRef name, Type type,
ArrayRef<NamedAttribute> attrs, TypeRange inputTypes) {
+ OpBuilder::InsertionGuard g(builder);
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name),
@@ -139,8 +140,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
// Add the function body.
Region *bodyRegion = state.addRegion();
- Block *body = new Block();
- bodyRegion->push_back(body);
+ Block *body = builder.createBlock(bodyRegion);
for (Type input : inputTypes)
body->addArgument(input, state.location);
}
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 0ab53ce7e3327e..77603739137614 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -259,7 +259,7 @@ static void addResumeFunction(ModuleOp module) {
kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
resumeOp.setPrivate();
- auto *block = resumeOp.addEntryBlock();
+ auto *block = resumeOp.addEntryBlock(moduleBuilder);
auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index 363e5f9b8cefe7..d3ee89743da9db 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -98,12 +98,10 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
loopVariablesNextIter);
- auto *afterBlock = new Block;
- whileOp.getAfter().push_back(afterBlock);
+ Block *afterBlock = builder.createBlock(&whileOp.getAfter());
afterBlock->addArguments(
loopVariablesInit.getTypes(),
SmallVector<Location>(loopVariablesInit.size(), loc));
- builder.setInsertionPointToEnd(afterBlock);
builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
return whileOp.getOperation();
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index bd50c67fb87958..53b44aa3241bb1 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -135,7 +135,7 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
+ rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
SmallVector<Value, 8> args;
size_t argOffset = resultStructType ? 1 : 0;
@@ -203,7 +203,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
// The wrapper that we synthetize here should only be visible in this module.
newFuncOp.setLinkage(LLVM::Linkage::Private);
- builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
+ builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
// Get a ValueRange containing arguments.
FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2bfca303b5fd48..2dc42f0a85e669 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -520,9 +520,7 @@ struct GlobalMemrefOpLowering
global, arrayTy, global.getConstant(), linkage, global.getSymName(),
initialValue, alignment, *addressSpace);
if (!global.isExternal() && global.isUninitialized()) {
- Block *blk = new Block();
- newGlobal.getInitializerRegion().push_back(blk);
- rewriter.setInsertionPointToStart(blk);
+ rewriter.createBlock(&newGlobal.getInitializerRegion());
Value undef[] = {
rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index febfe97f6c0a99..d90cf931385fcc 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -138,14 +138,13 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// from header to merge.
auto loc = forOp.getLoc();
auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
- loopOp.addEntryAndMergeBlock();
+ loopOp.addEntryAndMergeBlock(rewriter);
OpBuilder::InsertionGuard guard(rewriter);
// Create the block for the header.
- auto *header = new Block();
- // Insert the header.
- loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1),
- header);
+ Block *header = rewriter.createBlock(&loopOp.getBody(),
+ getBlockIt(loopOp.getBody(), 1));
+ rewriter.setInsertionPointAfter(loopOp);
// Create the new induction variable to use.
Value adapLowerBound = adaptor.getLowerBound();
@@ -342,7 +341,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = whileOp.getLoc();
auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
- loopOp.addEntryAndMergeBlock();
+ loopOp.addEntryAndMergeBlock(rewriter);
Region &beforeRegion = whileOp.getBefore();
Region &afterRegion = whileOp.getAfter();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c4b13193f4e773..a4df863ab08342 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1813,6 +1813,8 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result,
"upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
+ OpBuilder::InsertionGuard guard(builder);
+
// Set variadic segment sizes.
result.addAttribute(
getOperandSegmentSizeAttr(),
@@ -1841,12 +1843,11 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result,
// Create a region and a block for the body. The argument of the region is
// the loop induction variable.
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
- Block &bodyBlock = bodyRegion->front();
+ Block *bodyBlock = builder.createBlock(bodyRegion);
Value inductionVar =
- bodyBlock.addArgument(builder.getIndexType(), result.location);
+ bodyBlock->addArgument(builder.getIndexType(), result.location);
for (Value val : iterArgs)
- bodyBlock.addArgument(val.getType(), val.getLoc());
+ bodyBlock->addArgument(val.getType(), val.getLoc());
// Create the default terminator if the builder is not provided and if the
// iteration arguments are not provided. Otherwise, leave this to the caller
@@ -1855,9 +1856,9 @@ void AffineForOp::build(OpBuilder &builder, OperationState &result,
ensureTerminator(*bodyRegion, builder, result.location);
} else if (bodyBuilder) {
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&bodyBlock);
+ builder.setInsertionPointToStart(bodyBlock);
bodyBuilder(builder, result.location, inductionVar,
- bodyBlock.getArguments().drop_front());
+ bodyBlock->getArguments().drop_front());
}
}
@@ -2895,18 +2896,20 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, IntegerSet set, ValueRange args,
bool withElseRegion) {
assert(resultTypes.empty() || withElseRegion);
+ OpBuilder::InsertionGuard guard(builder);
+
result.addTypes(resultTypes);
result.addOperands(args);
result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
Region *thenRegion = result.addRegion();
- thenRegion->push_back(new Block());
+ builder.createBlock(thenRegion);
if (resultTypes.empty())
AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
Region *elseRegion = result.addRegion();
if (withElseRegion) {
- elseRegion->push_back(new Block());
+ builder.createBlock(elseRegion);
if (resultTypes.empty())
AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
}
@@ -3693,6 +3696,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
"expected upper bound maps to have as many inputs as upper bound "
"operands");
+ OpBuilder::InsertionGuard guard(builder);
result.addTypes(resultTypes);
// Convert the reductions to integer attributes.
@@ -3738,11 +3742,11 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
// Create a region and a block for the body.
auto *bodyRegion = result.addRegion();
- auto *body = new Block();
+ Block *body = builder.createBlock(bodyRegion);
+
// Add all the block arguments.
for (unsigned i = 0, e = steps.size(); i < e; ++i)
body->addArgument(IndexType::get(builder.getContext()), result.location);
- bodyRegion->push_back(body);
if (resultTypes.empty())
ensureTerminator(*bodyRegion, builder, result.location);
}
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 5f583f36cd2cb8..a3e3f80954efce 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -68,7 +68,7 @@ void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
void ExecuteOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, ValueRange dependencies,
ValueRange operands, BodyBuilderFn bodyBuilder) {
-
+ OpBuilder::InsertionGuard guard(builder);
result.addOperands(dependencies);
result.addOperands(operands);
@@ -87,26 +87,21 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
// Add a body region with block arguments as unwrapped async value operands.
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
- Block &bodyBlock = bodyRegion->front();
+ Block *bodyBlock = builder.createBlock(bodyRegion);
for (Value operand : operands) {
auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
- bodyBlock.addArgument(valueType ? valueType.getValueType()
- : operand.getType(),
- operand.getLoc());
+ bodyBlock->addArgument(valueType ? valueType.getValueType()
+ : operand.getType(),
+ operand.getLoc());
}
// Create the default terminator if the builder is not provided and if the
// expected result is empty. Otherwise, leave this to the caller
// because we don't know which values to return from the execute op.
if (resultTypes.empty() && !bodyBuilder) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&bodyBlock);
builder.create<async::YieldOp>(result.location, ValueRange());
} else if (bodyBuilder) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&bodyBlock);
- bodyBuilder(builder, result.location, bodyBlock.getArguments());
+ bodyBuilder(builder, result.location, bodyBlock->getArguments());
}
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 0fe2c0dcfc7c53..4df8149b94c95f 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -261,20 +261,20 @@ LogicalResult ExpressionOp::verify() {
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, BodyBuilderFn bodyBuilder) {
+ OpBuilder::InsertionGuard g(builder);
result.addOperands({lb, ub, step});
Type t = lb.getType();
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
- Block &bodyBlock = bodyRegion->front();
- bodyBlock.addArgument(t, result.location);
+ Block *bodyBlock = builder.createBlock(bodyRegion);
+ bodyBlock->addArgument(t, result.location);
// Create the default terminator if the builder is not provided.
if (!bodyBuilder) {
ForOp::ensureTerminator(*bodyRegion, builder, result.location);
} else {
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&bodyBlock);
- bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+ builder.setInsertionPointToStart(bodyBlock);
+ bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f4042a60541a6a..3ba6ac6ccc8142 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2165,11 +2165,10 @@ LogicalResult ShuffleVectorOp::verify() {
//===----------------------------------------------------------------------===//
// Add the entry block to the function.
-Block *LLVMFuncOp::addEntryBlock() {
+Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
assert(empty() && "function already has an entry block");
-
- auto *entry = new Block;
- push_back(entry);
+ OpBuilder::InsertionGuard g(builder);
+ Block *entry = builder.createBlock(&getBody());
// FIXME: Allow passing in proper locations for the entry arguments.
LLVMFunctionType type = getFunctionType();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 559ffda4494d2b..c46e3694b70ecd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -132,12 +132,10 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
newIndexingMaps, genericOp.getIteratorTypesArray(),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ OpBuilder::InsertionGuard guard(rewriter);
Region &region = newOp.getRegion();
- Block *block = new Block();
- region.push_back(block);
+ Block *block = rewriter.createBlock(&region);
IRMapping mapper;
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(block);
for (auto bbarg : genericOp.getRegionInputArgs())
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 286b07669a47f5..0d8d670904f2a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -178,11 +178,9 @@ static void generateFusedElementwiseOpRegion(
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
- Block *fusedBlock = new Block();
- fusedOp.getRegion().push_back(fusedBlock);
- IRMapping mapper;
OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(fusedBlock);
+ Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+ IRMapping mapper;
// 2. Add an index operation for every fused loop dimension and use the
// `consumerToProducerLoopsMap` to map the producer indices.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5d220c6cdd7e58..6392c83bde24f5 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -275,14 +275,13 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
auto transposeOp =
b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
indexingMaps, iteratorTypes);
- Region &body = transposeOp.getRegion();
- body.push_back(new Block());
- body.front().addArguments({elementType, elementType}, {loc, loc});
// Create the body of the transpose operation.
OpBuilder::InsertionGuard g(b);
- b.setInsertionPointToEnd(&body.front());
- b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
+ Region &body = transposeOp.getRegion();
+ Block *bodyBlock = b.createBlock(&body);
+ bodyBlock->addArguments({elementType, elementType}, {loc, loc});
+ b.create<YieldOp>(loc, bodyBlock->getArgument(0));
return transposeOp;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index beb7e721ca53b8..248193481acfc6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1420,6 +1420,7 @@ OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange ivs) {
+ OpBuilder::InsertionGuard g(builder);
result.addOperands(memref);
result.addOperands(ivs);
@@ -1428,7 +1429,7 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(elementType);
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block());
+ builder.createBlock(bodyRegion);
bodyRegion->addArgument(elementType, memref.getLoc());
}
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..7170a899069ee3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -365,12 +365,11 @@ Block *LoopOp::getMergeBlock() {
return &getBody().back();
}
-void LoopOp::addEntryAndMergeBlock() {
+void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
assert(getBody().empty() && "entry and merge block already exist");
- getBody().push_back(new Block());
- auto *mergeBlock = new Block();
- getBody().push_back(mergeBlock);
- OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+ OpBuilder::InsertionGuard g(builder);
+ builder.createBlock(&getBody());
+ builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
builder.create<spirv::MergeOp>(getLoc());
@@ -525,11 +524,10 @@ Block *SelectionOp::getMergeBlock() {
return &getBody().back();
}
-void SelectionOp::addMergeBlock() {
+void SelectionOp::addMergeBlock(OpBuilder &builder) {
assert(getBody().empty() && "entry and merge block already exist");
- auto *mergeBlock = new Block();
- getBody().push_back(mergeBlock);
- OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+ OpBuilder::InsertionGuard guard(builder);
+ builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
builder.create<spirv::MergeOp>(getLoc());
@@ -542,7 +540,7 @@ SelectionOp::createIfThen(Location loc, Value condition,
auto selectionOp =
builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
- selectionOp.addMergeBlock();
+ selectionOp.addMergeBlock(builder);
Block *mergeBlock = selectionOp.getMergeBlock();
Block *thenBlock = nullptr;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4f829db1305c85..c9d979474137b2 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -375,15 +375,13 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
void AssumingOp::build(
OpBuilder &builder, OperationState &result, Value witness,
function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
+ OpBuilder::InsertionGuard g(builder);
result.addOperands(witness);
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
- Block &bodyBlock = bodyRegion->front();
+ builder.createBlock(bodyRegion);
// Build body.
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&bodyBlock);
SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
builder.create<AssumingYieldOp>(result.location, yieldValues);
@@ -1904,23 +1902,23 @@ bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
ValueRange initVals) {
+ OpBuilder::InsertionGuard g(builder);
result.addOperands(shape);
result.addOperands(initVals);
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
- Block &bodyBlock = bodyRegion->front();
- bodyBlock.addArgument(builder.getIndexType(), result.location);
+ Block *bodyBlock = builder.createBlock(bodyRegion);
+ bodyBlock->addArgument(builder.getIndexType(), result.location);
Type elementType;
if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
elementType = tensorType.getElementType();
else
elementType = SizeType::get(builder.getContext());
- bodyBlock.addArgument(elementType, shape.getLoc());
+ bodyBlock->addArgument(elementType, shape.getLoc());
for (Value initVal : initVals) {
- bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
+ bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
result.addTypes(initVal.getType());
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2ccb2361b5efe1..1bcc131781d34d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -299,8 +299,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
Block &prodBlock = prod.getRegion().front();
Block &consBlock = op.getRegion().front();
IRMapping mapper;
- Block *fusedBlock = new Block();
- fusedOp.getRegion().push_back(fusedBlock);
+ Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
unsigned num = prodBlock.getNumArguments();
for (unsigned i = 0; i < num - 1; i++)
addArg(mapper, fusedBlock, prodBlock.getArgument(i));
@@ -309,7 +308,6 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
// Clone bodies of the producer and consumer in new evaluation order.
auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
- rewriter.setInsertionPointToStart(fusedBlock);
Value last;
for (auto &op : prodBlock.without_terminator())
if (&op != acc) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 83ef01b4e3a467..bfd289eb350b59 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1829,7 +1829,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
auto control = static_cast<spirv::SelectionControl>(selectionControl);
auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
- selectionOp.addMergeBlock();
+ selectionOp.addMergeBlock(builder);
return selectionOp;
}
@@ -1841,7 +1841,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
auto control = static_cast<spirv::LoopControl>(loopControl);
auto loopOp = builder.create<spirv::LoopOp>(location, control);
- loopOp.addEntryAndMergeBlock();
+ loopOp.addEntryAndMergeBlock(builder);
return loopOp;
}
More information about the Mlir-commits
mailing list