[Mlir-commits] [mlir] [mlir][NFC] update `mlir/Dialect` create APIs (18/n) (PR #149925)
Maksim Levental
llvmlistbot at llvm.org
Thu Jul 24 12:20:20 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/149925
>From 1ff3495fc88aa4d46c3504c7421bab499ac9c7bf Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 21 Jul 2025 18:19:22 -0400
Subject: [PATCH] [mlir][NFC] update `mlir/Dialect` create APIs (18/n)
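This patch mechanically replaces `builder.create<OpTy>(loc, ...)` with the equivalent `OpTy::create(builder, loc, ...)` form throughout the MemRef dialect; there is no functional change. See https://github.com/llvm/llvm-project/pull/147168 for more information.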
---
.../Dialect/MemRef/IR/MemRefMemorySlot.cpp | 8 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 38 +++--
.../TransformOps/MemRefTransformOps.cpp | 7 +-
.../Transforms/AllocationOpInterfaceImpl.cpp | 11 +-
.../MemRef/Transforms/ComposeSubView.cpp | 4 +-
.../MemRef/Transforms/EmulateNarrowType.cpp | 62 ++++---
.../Dialect/MemRef/Transforms/ExpandOps.cpp | 17 +-
.../MemRef/Transforms/ExpandRealloc.cpp | 34 ++--
.../Transforms/ExpandStridedMetadata.cpp | 39 ++---
.../Transforms/ExtractAddressComputations.cpp | 34 ++--
.../MemRef/Transforms/FlattenMemRefs.cpp | 59 ++++---
.../Transforms/IndependenceTransforms.cpp | 14 +-
.../Dialect/MemRef/Transforms/MultiBuffer.cpp | 17 +-
.../MemRef/Transforms/NormalizeMemRefs.cpp | 4 +-
.../MemRef/Transforms/ReifyResultShapes.cpp | 6 +-
.../ResolveShapedTypeResultDims.cpp | 2 +-
.../Transforms/RuntimeOpVerification.cpp | 154 +++++++++---------
17 files changed, 257 insertions(+), 253 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index c5643f6e2f830..dfa2e4e0376ed 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -85,11 +85,11 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
// TODO: support more types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](MemRefType t) {
- return builder.create<memref::AllocaOp>(getLoc(), t);
+ return memref::AllocaOp::create(builder, getLoc(), t);
})
.Default([&](Type t) {
- return builder.create<arith::ConstantOp>(getLoc(), t,
- builder.getZeroAttr(t));
+ return arith::ConstantOp::create(builder, getLoc(), t,
+ builder.getZeroAttr(t));
});
}
@@ -135,7 +135,7 @@ DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
for (Attribute usedIndex : usedIndices) {
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
- auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+ auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr);
newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(), elemType});
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 51c813682ce25..74b968c27a62a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -213,9 +213,9 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
- auto newAlloc = rewriter.create<AllocLikeOp>(
- alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
- alloc.getAlignmentAttr());
+ auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
+ dynamicSizes, alloc.getSymbolOperands(),
+ alloc.getAlignmentAttr());
// Insert a cast so we have the same type as the old alloc.
rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
return success();
@@ -797,7 +797,7 @@ void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
int64_t index) {
auto loc = result.location;
- Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
+ Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
build(builder, result, source, indexValue);
}
@@ -1044,9 +1044,9 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
rewriter.setInsertionPointAfter(reshape);
Location loc = dim.getLoc();
Value load =
- rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
+ LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
if (load.getType() != dim.getType())
- load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+ load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
rewriter.replaceOp(dim, load);
return success();
}
@@ -1319,8 +1319,9 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
assert(isa<Attribute>(maybeConstant) &&
"The constified value should be either unchanged (i.e., == result) "
"or a constant");
- Value constantVal = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
+ Value constantVal = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// modifyOpInPlace: lambda cannot capture structured bindings in C++17
// yet.
@@ -2548,8 +2549,9 @@ struct CollapseShapeOpMemRefCastFolder
rewriter.modifyOpInPlace(
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
} else {
- Value newOp = rewriter.create<CollapseShapeOp>(
- op->getLoc(), cast.getSource(), op.getReassociationIndices());
+ Value newOp =
+ CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
+ op.getReassociationIndices());
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
return success();
@@ -3006,15 +3008,15 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
Value offset =
op.isDynamicOffset(idx)
? op.getDynamicOffset(idx)
- : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
+ : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
Value size =
op.isDynamicSize(idx)
? op.getDynamicSize(idx)
- : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
+ : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
Value stride =
op.isDynamicStride(idx)
? op.getDynamicStride(idx)
- : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
+ : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
res.emplace_back(Range{offset, size, stride});
}
return res;
@@ -3173,8 +3175,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
if (!resultType)
return failure();
- Value newSubView = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), resultType, castOp.getSource(),
+ Value newSubView = SubViewOp::create(
+ rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
subViewOp.getStaticStrides());
@@ -3495,9 +3497,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
return failure();
// Create new ViewOp.
- auto newViewOp = rewriter.create<ViewOp>(
- viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
- viewOp.getByteShift(), newOperands);
+ auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
+ viewOp.getOperand(0), viewOp.getByteShift(),
+ newOperands);
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 0c03670b4535f..95eb2a9a95bc1 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -155,9 +155,10 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
Type resultType = alloca.getResult().getType();
OpBuilder builder(rewriter.getContext());
// TODO: Add a better builder for this.
- globalOp = builder.create<memref::GlobalOp>(
- loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
- TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+ globalOp = memref::GlobalOp::create(
+ builder, loc, StringAttr::get(ctx, "alloca"),
+ StringAttr::get(ctx, "private"), TypeAttr::get(resultType),
+ Attribute{}, UnitAttr{}, IntegerAttr{});
symbolTable.insert(globalOp);
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
index c433415944323..75cc39e61656a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
@@ -22,11 +22,11 @@ struct DefaultAllocationInterface
DefaultAllocationInterface, memref::AllocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value alloc) {
- return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+ return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
.getOperation();
}
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
- return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
+ return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc)
.getResult();
}
static ::mlir::HoistingKind getHoistingKind() {
@@ -35,8 +35,9 @@ struct DefaultAllocationInterface
static ::std::optional<::mlir::Operation *>
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
Operation *definingOp = alloc.getDefiningOp();
- return builder.create<memref::AllocaOp>(
- definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
+ return memref::AllocaOp::create(
+ builder, definingOp->getLoc(),
+ cast<MemRefType>(definingOp->getResultTypes()[0]),
definingOp->getOperands(), definingOp->getAttrs());
}
};
@@ -52,7 +53,7 @@ struct DefaultReallocationInterface
DefaultAllocationInterface, memref::ReallocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value realloc) {
- return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
+ return memref::DeallocOp::create(builder, realloc.getLoc(), realloc)
.getOperation();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 7c777e807f08c..106c3b458dbac 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -124,8 +124,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}
AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
- Value result = rewriter.create<affine::AffineApplyOp>(
- op.getLoc(), map, affineApplyOperands);
+ Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map,
+ affineApplyOperands);
offsets.push_back(result);
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index ec2bc95291455..556ea1a8e9c40 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -99,7 +99,7 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
- return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
+ return arith::IndexCastOp::create(builder, loc, dstType, bitOffset);
}
/// When writing a subbyte size, masked bitwise operations are used to only
@@ -112,14 +112,14 @@ static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
auto dstIntegerType = builder.getIntegerType(dstBits);
auto maskRightAlignedAttr =
builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
- Value maskRightAligned = builder.create<arith::ConstantOp>(
- loc, dstIntegerType, maskRightAlignedAttr);
+ Value maskRightAligned = arith::ConstantOp::create(
+ builder, loc, dstIntegerType, maskRightAlignedAttr);
Value writeMaskInverse =
- builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset);
auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
Value flipVal =
- builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
- return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+ arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr);
+ return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal);
}
/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
@@ -141,7 +141,7 @@ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
const SmallVector<OpFoldResult> &indices,
Value memref) {
auto stridedMetadata =
- builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+ memref::ExtractStridedMetadataOp::create(builder, loc, memref);
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
@@ -298,16 +298,16 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Special case 0-rank memref loads.
Value bitsLoad;
if (convertedType.getRank() == 0) {
- bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
- ValueRange{});
+ bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
+ ValueRange{});
} else {
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
OpFoldResult linearizedIndices = getLinearizedSrcIndices(
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
- Value newLoad = rewriter.create<memref::LoadOp>(
- loc, adaptor.getMemref(),
+ Value newLoad = memref::LoadOp::create(
+ rewriter, loc, adaptor.getMemref(),
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
dstBits));
@@ -315,7 +315,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Note, currently only the big-endian is supported.
Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
srcBits, dstBits, rewriter);
- bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
+ bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset);
}
// Get the corresponding bits. If the arith computation bitwidth equals
@@ -331,17 +331,17 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
: IntegerType::get(rewriter.getContext(),
resultTy.getIntOrFloatBitWidth());
if (conversionTy == convertedElementType) {
- auto mask = rewriter.create<arith::ConstantOp>(
- loc, convertedElementType,
+ auto mask = arith::ConstantOp::create(
+ rewriter, loc, convertedElementType,
rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
- result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+ result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask);
} else {
- result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
+ result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad);
}
if (conversionTy != resultTy) {
- result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+ result = arith::BitcastOp::create(rewriter, loc, resultTy, result);
}
rewriter.replaceOp(op, result);
@@ -428,20 +428,20 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Pad the input value with 0s on the left.
Value input = adaptor.getValue();
if (!input.getType().isInteger()) {
- input = rewriter.create<arith::BitcastOp>(
- loc,
+ input = arith::BitcastOp::create(
+ rewriter, loc,
IntegerType::get(rewriter.getContext(),
input.getType().getIntOrFloatBitWidth()),
input);
}
Value extendedInput =
- rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
+ arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input);
// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
- extendedInput, adaptor.getMemref(),
- ValueRange{});
+ memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
+ extendedInput, adaptor.getMemref(),
+ ValueRange{});
rewriter.eraseOp(op);
return success();
}
@@ -456,16 +456,14 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
dstBits, bitwidthOffset, rewriter);
// Align the value to write with the destination bits
Value alignedVal =
- rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+ arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset);
// Clear destination bits
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
- writeMask, adaptor.getMemref(),
- storeIndices);
+ memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(), storeIndices);
// Write srcs bits to destination
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
- alignedVal, adaptor.getMemref(),
- storeIndices);
+ memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(), storeIndices);
rewriter.eraseOp(op);
return success();
}
@@ -525,8 +523,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
// Transform the offsets, sizes and strides according to the emulation.
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, subViewOp.getViewSource());
+ auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, subViewOp.getViewSource());
OpFoldResult linearizedIndices;
auto strides = stridedMetadata.getConstifiedMixedStrides();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index e6e4c3b07ecb8..17a148cc31dc0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -48,15 +48,15 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Value size;
// Load dynamic sizes from the shape input, use constants for static dims.
if (op.getType().isDynamicDim(i)) {
- Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
- size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
+ Value index = arith::ConstantIndexOp::create(rewriter, loc, i);
+ size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
if (!isa<IndexType>(size.getType()))
- size = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), size);
+ size = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), size);
sizes[i] = size;
} else {
auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
- size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
+ size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
sizes[i] = sizeAttr;
}
if (stride)
@@ -66,10 +66,11 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
if (i > 0) {
if (stride) {
- stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+ stride = arith::MulIOp::create(rewriter, loc, stride, size);
} else if (op.getType().isDynamicDim(i)) {
- stride = rewriter.create<arith::MulIOp>(
- loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
+ stride = arith::MulIOp::create(
+ rewriter, loc,
+ arith::ConstantIndexOp::create(rewriter, loc, staticStride),
size);
} else {
staticStride *= op.getType().getDimSize(i);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
index 7475d442b7b9a..01d32621b2055 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
@@ -73,7 +73,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
if (ShapedType::isDynamic(inputSize)) {
Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
rewriter.getIndexAttr(0));
- currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
+ currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
.getResult();
}
@@ -88,10 +88,10 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// the old buffer is smaller than the requested size.
Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
- Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
- lhs, rhs);
- auto ifOp = rewriter.create<scf::IfOp>(
- loc, cond,
+ Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
+ lhs, rhs);
+ auto ifOp = scf::IfOp::create(
+ rewriter, loc, cond,
[&](OpBuilder &builder, Location loc) {
// Allocate the new buffer. If it is a dynamic memref we need to pass
// an additional operand for the size at runtime, otherwise the static
@@ -100,25 +100,26 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
if (op.getDynamicResultSize())
dynamicSizeOperands.push_back(op.getDynamicResultSize());
- Value newAlloc = builder.create<memref::AllocOp>(
- loc, op.getResult().getType(), dynamicSizeOperands,
+ Value newAlloc = memref::AllocOp::create(
+ builder, loc, op.getResult().getType(), dynamicSizeOperands,
op.getAlignmentAttr());
// Take a subview of the new (bigger) buffer such that we can copy the
// old values over (the copy operation requires both operands to have
// the same shape).
- Value subview = builder.create<memref::SubViewOp>(
- loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
+ Value subview = memref::SubViewOp::create(
+ builder, loc, newAlloc,
+ ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
ArrayRef<OpFoldResult>{currSize},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
- builder.create<memref::CopyOp>(loc, op.getSource(), subview);
+ memref::CopyOp::create(builder, loc, op.getSource(), subview);
// Insert the deallocation of the old buffer only if requested
// (enabled by default).
if (emitDeallocs)
- builder.create<memref::DeallocOp>(loc, op.getSource());
+ memref::DeallocOp::create(builder, loc, op.getSource());
- builder.create<scf::YieldOp>(loc, newAlloc);
+ scf::YieldOp::create(builder, loc, newAlloc);
},
[&](OpBuilder &builder, Location loc) {
// We need to reinterpret-cast here because either the input or output
@@ -126,11 +127,12 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// dynamic or vice-versa. If both are static and the original buffer
// is already bigger than the requested size, the cast represents a
// subview operation.
- Value casted = builder.create<memref::ReinterpretCastOp>(
- loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
- rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
+ Value casted = memref::ReinterpretCastOp::create(
+ builder, loc, cast<MemRefType>(op.getResult().getType()),
+ op.getSource(), rewriter.getIndexAttr(0),
+ ArrayRef<OpFoldResult>{targetSize},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
- builder.create<scf::YieldOp>(loc, casted);
+ scf::YieldOp::create(builder, loc, casted);
});
rewriter.replaceOp(op, ifOp.getResult(0));
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 2ba798f48ac7c..9771bd2aaa143 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -66,7 +66,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+ memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
@@ -577,7 +577,7 @@ static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+ memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
// Collect statically known information.
auto [strides, offset] = sourceType.getStridesAndOffset();
@@ -828,14 +828,14 @@ struct ExtractStridedMetadataOpAllocFolder
if (allocLikeOp.getType() == baseBufferType)
results.push_back(allocLikeOp);
else
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, allocLikeOp, offset,
+ results.push_back(memref::ReinterpretCastOp::create(
+ rewriter, loc, baseBufferType, allocLikeOp, offset,
/*sizes=*/ArrayRef<int64_t>(),
/*strides=*/ArrayRef<int64_t>()));
}
// Offset.
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset));
for (OpFoldResult size : sizes)
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));
@@ -900,19 +900,19 @@ struct ExtractStridedMetadataOpGetGlobalFolder
if (getGlobalOp.getType() == baseBufferType)
results.push_back(getGlobalOp);
else
- results.push_back(rewriter.create<memref::ReinterpretCastOp>(
- loc, baseBufferType, getGlobalOp, offset,
+ results.push_back(memref::ReinterpretCastOp::create(
+ rewriter, loc, baseBufferType, getGlobalOp, offset,
/*sizes=*/ArrayRef<int64_t>(),
/*strides=*/ArrayRef<int64_t>()));
// Offset.
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset));
for (auto size : sizes)
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, size));
for (auto stride : strides)
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));
+ results.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride));
rewriter.replaceOp(op, results);
return success();
@@ -1008,9 +1008,8 @@ class ExtractStridedMetadataOpReinterpretCastFolder
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, reinterpretCastOp.getSource());
+ auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, reinterpretCastOp.getSource());
// Register the base_buffer.
results[0] = newExtractStridedMetadata.getBaseBuffer();
@@ -1082,9 +1081,8 @@ class ExtractStridedMetadataOpCastFolder
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc,
- castOp.getSource());
+ auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, castOp.getSource());
// Register the base_buffer.
results[0] = newExtractStridedMetadata.getBaseBuffer();
@@ -1142,9 +1140,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder
auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
if (!memSpaceCastOp)
return failure();
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, memSpaceCastOp.getSource());
+ auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, memSpaceCastOp.getSource());
SmallVector<Value> results(newExtractStridedMetadata.getResults());
// As with most other strided metadata rewrite patterns, don't introduce
// a use of the base pointer where non existed. This needs to happen here,
@@ -1158,8 +1155,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder
MemRefType::Builder newTypeBuilder(baseBufferType);
newTypeBuilder.setMemorySpace(
memSpaceCastOp.getResult().getType().getMemorySpace());
- results[0] = rewriter.create<memref::MemorySpaceCastOp>(
- loc, Type{newTypeBuilder}, baseBuffer);
+ results[0] = memref::MemorySpaceCastOp::create(
+ rewriter, loc, Type{newTypeBuilder}, baseBuffer);
} else {
results[0] = nullptr;
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 2f5c9436fb8c7..0946da8e4e919 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -42,8 +42,8 @@ static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
memref::LoadOp loadOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = loadOp.getLoc();
- return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
- loadOp.getNontemporal());
+ return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
+ loadOp.getNontemporal());
}
// Matches getViewSizeForEachDim specs for LoadOp.
@@ -72,9 +72,8 @@ static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
memref::StoreOp storeOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = storeOp.getLoc();
- return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
- srcMemRef, indices,
- storeOp.getNontemporal());
+ return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
+ srcMemRef, indices, storeOp.getNontemporal());
}
// Matches getViewSizeForEachDim specs for StoreOp.
@@ -104,8 +103,8 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = ldMatrixOp.getLoc();
- return rewriter.create<nvgpu::LdMatrixOp>(
- loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
+ return nvgpu::LdMatrixOp::create(
+ rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
}
@@ -132,8 +131,8 @@ rebuildTransferReadOp(RewriterBase &rewriter,
vector::TransferReadOp transferReadOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = transferReadOp.getLoc();
- return rewriter.create<vector::TransferReadOp>(
- loc, transferReadOp.getResult().getType(), srcMemRef, indices,
+ return vector::TransferReadOp::create(
+ rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices,
transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
}
@@ -150,8 +149,8 @@ rebuildTransferWriteOp(RewriterBase &rewriter,
vector::TransferWriteOp transferWriteOp, Value srcMemRef,
ArrayRef<Value> indices) {
Location loc = transferWriteOp.getLoc();
- return rewriter.create<vector::TransferWriteOp>(
- loc, transferWriteOp.getValue(), srcMemRef, indices,
+ return vector::TransferWriteOp::create(
+ rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices,
transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
transferWriteOp.getInBoundsAttr());
}
@@ -182,9 +181,8 @@ static SmallVector<OpFoldResult>
getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
LoadStoreLikeOp loadStoreLikeOp) {
Location loc = loadStoreLikeOp.getLoc();
- auto extractStridedMetadataOp =
- rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, getSrcMemRef(loadStoreLikeOp));
+ auto extractStridedMetadataOp = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, getSrcMemRef(loadStoreLikeOp));
SmallVector<OpFoldResult> srcSizes =
extractStridedMetadataOp.getConstifiedMixedSizes();
SmallVector<OpFoldResult> indices =
@@ -267,12 +265,12 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
// apply them properly to the input indices.
// Therefore the strides multipliers are simply ones.
auto subview =
- rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
- /*offsets=*/indices,
- /*sizes=*/sizes, /*strides=*/ones);
+ memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef,
+ /*offsets=*/indices,
+ /*sizes=*/sizes, /*strides=*/ones);
// Rewrite the load/store with the subview as the base pointer.
SmallVector<Value> zeros(loadStoreRank,
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ arith::ConstantIndexOp::create(rewriter, loc, 0));
LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
rewriter, loadStoreLikeOp, subview.getResult(), zeros);
rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 76f7788c4dcc5..42be847811d52 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -40,8 +40,8 @@ using namespace mlir;
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
OpFoldResult in) {
if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
- return rewriter.create<arith::ConstantIndexOp>(
- loc, cast<IntegerAttr>(offsetAttr).getInt());
+ return arith::ConstantIndexOp::create(
+ rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
}
return cast<Value>(in);
}
@@ -60,7 +60,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
}
memref::ExtractStridedMetadataOp stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
OpFoldResult linearizedIndices;
@@ -74,8 +74,8 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
getAsOpFoldResult(indices));
return std::make_pair(
- rewriter.create<memref::ReinterpretCastOp>(
- loc, source,
+ memref::ReinterpretCastOp::create(
+ rewriter, loc, source,
/* offset = */ linearizedInfo.linearizedOffset,
/* shapes = */
ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
@@ -111,7 +111,7 @@ template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
PatternRewriter &rewriter) {
memref::ExtractStridedMetadataOp stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
oper, cast<MemRefType>(oper.getType()), newOper,
/*offset=*/rewriter.getIndexAttr(0),
@@ -125,63 +125,68 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
Location loc = op->getLoc();
llvm::TypeSwitch<Operation *>(op.getOperation())
.template Case<memref::AllocOp>([&](auto oper) {
- auto newAlloc = rewriter.create<memref::AllocOp>(
- loc, cast<MemRefType>(flatMemref.getType()),
+ auto newAlloc = memref::AllocOp::create(
+ rewriter, loc, cast<MemRefType>(flatMemref.getType()),
oper.getAlignmentAttr());
castAllocResult(oper, newAlloc, loc, rewriter);
})
.template Case<memref::AllocaOp>([&](auto oper) {
- auto newAlloca = rewriter.create<memref::AllocaOp>(
- loc, cast<MemRefType>(flatMemref.getType()),
+ auto newAlloca = memref::AllocaOp::create(
+ rewriter, loc, cast<MemRefType>(flatMemref.getType()),
oper.getAlignmentAttr());
castAllocResult(oper, newAlloca, loc, rewriter);
})
.template Case<memref::LoadOp>([&](auto op) {
- auto newLoad = rewriter.create<memref::LoadOp>(
- loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+ auto newLoad =
+ memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
+ flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
})
.template Case<memref::StoreOp>([&](auto op) {
- auto newStore = rewriter.create<memref::StoreOp>(
- loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+ auto newStore =
+ memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
+ flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
})
.template Case<vector::LoadOp>([&](auto op) {
- auto newLoad = rewriter.create<vector::LoadOp>(
- loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+ auto newLoad =
+ vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
+ flatMemref, ValueRange{offset});
newLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newLoad.getResult());
})
.template Case<vector::StoreOp>([&](auto op) {
- auto newStore = rewriter.create<vector::StoreOp>(
- loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+ auto newStore =
+ vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
+ flatMemref, ValueRange{offset});
newStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newStore);
})
.template Case<vector::MaskedLoadOp>([&](auto op) {
- auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
- loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
- op.getPassThru());
+ auto newMaskedLoad = vector::MaskedLoadOp::create(
+ rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
+ op.getMask(), op.getPassThru());
newMaskedLoad->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedLoad.getResult());
})
.template Case<vector::MaskedStoreOp>([&](auto op) {
- auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
- loc, flatMemref, ValueRange{offset}, op.getMask(),
+ auto newMaskedStore = vector::MaskedStoreOp::create(
+ rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
op.getValueToStore());
newMaskedStore->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newMaskedStore);
})
.template Case<vector::TransferReadOp>([&](auto op) {
- auto newTransferRead = rewriter.create<vector::TransferReadOp>(
- loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
+ auto newTransferRead = vector::TransferReadOp::create(
+ rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
+ op.getPadding());
rewriter.replaceOp(op, newTransferRead.getResult());
})
.template Case<vector::TransferWriteOp>([&](auto op) {
- auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
- loc, op.getVector(), flatMemref, ValueRange{offset});
+ auto newTransferWrite = vector::TransferWriteOp::create(
+ rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);
})
.Default([&](auto op) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 35c661ecb886d..66c1aa6bf3fe1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -51,7 +51,7 @@ FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
// Create a new memref::AllocaOp.
Value newAllocaOp =
- b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());
+ AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType());
// Create a memref::SubViewOp.
SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
@@ -71,11 +71,11 @@ propagateSubViewOp(RewriterBase &rewriter,
MemRefType newResultType = SubViewOp::inferRankReducedResultType(
op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
op.getMixedSizes(), op.getMixedStrides());
- Value newSubview = rewriter.create<SubViewOp>(
- op.getLoc(), newResultType, conversionOp.getOperand(0),
+ Value newSubview = SubViewOp::create(
+ rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
- auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
- op.getLoc(), op.getType(), newSubview);
+ auto newConversionOp = UnrealizedConversionCastOp::create(
+ rewriter, op.getLoc(), op.getType(), newSubview);
rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
return newConversionOp;
}
@@ -106,8 +106,8 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
for (const auto &it :
llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
- unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
- to->getLoc(), std::get<0>(it.value()).getType(),
+ unrealizedConversions.push_back(UnrealizedConversionCastOp::create(
+ rewriter, to->getLoc(), std::get<0>(it.value()).getType(),
std::get<1>(it.value())));
rewriter.replaceAllUsesWith(from->getResult(it.index()),
unrealizedConversions.back()->getResult(0));
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 0a84962150ead..5d3cec402cab1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -63,9 +63,10 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
- Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
- subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
+ Value newSubview = memref::SubViewOp::create(
+ rewriter, subviewUse->getLoc(), newType, val,
+ subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+ subviewUse.getMixedStrides());
// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -177,8 +178,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
Location loc = allocOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(allocOp);
- auto mbAlloc = rewriter.create<memref::AllocOp>(
- loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
+ auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType,
+ ValueRange{}, allocOp->getAttrs());
LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");
// 3. Within the loop, build the modular leading index (i.e. each loop
@@ -211,8 +212,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
// Strides is [1, 1 ... 1 ].
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
originalShape, mbMemRefType, offsets, sizes, strides);
- Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
- offsets, sizes, strides);
+ Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc,
+ offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
// 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
@@ -224,7 +225,7 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(deallocOp);
auto newDeallocOp =
- rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
+ memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc);
(void)newDeallocOp;
LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
rewriter.eraseOp(deallocOp);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 4ec04321dd3e2..fa7991e6c6a80 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -276,8 +276,8 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
if (!callOp)
continue;
Operation *newCallOp =
- builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
- resultTypes, userOp->getOperands());
+ func::CallOp::create(builder, userOp->getLoc(), callOp.getCalleeAttr(),
+ resultTypes, userOp->getOperands());
bool replacingMemRefUsesFailed = false;
bool returnTypeChanged = false;
for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index 46f9d64ebeb15..d65825bbdf391 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -115,10 +115,12 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
// Update the type.
newRes.setType(reifiedTy);
if (isa<RankedTensorType>(reifiedTy)) {
- newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+ newResults.push_back(
+ tensor::CastOp::create(rewriter, loc, oldTy, newRes));
} else {
assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
- newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+ newResults.push_back(
+ memref::CastOp::create(rewriter, loc, oldTy, newRes));
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 89a3895d06ba5..6a81a15f30e47 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -69,7 +69,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
Location loc = dimOp->getLoc();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
dimOp, resultShape,
- rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
+ arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult());
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index d231516884c7d..1f03e9ae8d6a1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -40,19 +40,18 @@ struct AssumeAlignmentOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
- Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
- loc, assumeOp.getMemref());
- Value rest = builder.create<arith::RemUIOp>(
- loc, ptr,
- builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
- Value isAligned = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, rest,
- builder.create<arith::ConstantIndexOp>(loc, 0));
- builder.create<cf::AssertOp>(
- loc, isAligned,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "memref is not aligned to " +
- std::to_string(assumeOp.getAlignment())));
+ Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
+ assumeOp.getMemref());
+ Value rest = arith::RemUIOp::create(
+ builder, loc, ptr,
+ arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment()));
+ Value isAligned =
+ arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
+ arith::ConstantIndexOp::create(builder, loc, 0));
+ cf::AssertOp::create(builder, loc, isAligned,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "memref is not aligned to " +
+ std::to_string(assumeOp.getAlignment())));
}
};
@@ -71,15 +70,14 @@ struct CastOpInterface
if (isa<UnrankedMemRefType>(srcType)) {
// Check rank.
- Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
+ Value srcRank = RankOp::create(builder, loc, castOp.getSource());
Value resultRank =
- builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
- Value isSameRank = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcRank, resultRank);
- builder.create<cf::AssertOp>(
- loc, isSameRank,
- RuntimeVerifiableOpInterface::generateErrorMessage(op,
- "rank mismatch"));
+ arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
+ Value isSameRank = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
+ cf::AssertOp::create(builder, loc, isSameRank,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "rank mismatch"));
}
// Get source offset and strides. We do not have an op to get offsets and
@@ -95,8 +93,9 @@ struct CastOpInterface
MemRefType::get(dynamicShape, resultType.getElementType(),
stridedLayout, resultType.getMemorySpace());
Value helperCast =
- builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
- auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
+ CastOp::create(builder, loc, dynStridesType, castOp.getSource());
+ auto metadataOp =
+ ExtractStridedMetadataOp::create(builder, loc, helperCast);
// Check dimension sizes.
for (const auto &it : llvm::enumerate(resultType.getShape())) {
@@ -110,13 +109,13 @@ struct CastOpInterface
continue;
Value srcDimSz =
- builder.create<DimOp>(loc, castOp.getSource(), it.index());
+ DimOp::create(builder, loc, castOp.getSource(), it.index());
Value resultDimSz =
- builder.create<arith::ConstantIndexOp>(loc, it.value());
- Value isSameSz = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
- builder.create<cf::AssertOp>(
- loc, isSameSz,
+ arith::ConstantIndexOp::create(builder, loc, it.value());
+ Value isSameSz = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
+ cf::AssertOp::create(
+ builder, loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size mismatch of dim " + std::to_string(it.index())));
}
@@ -132,13 +131,12 @@ struct CastOpInterface
// Static/dynamic offset -> dynamic offset does not need verification.
Value srcOffset = metadataOp.getResult(1);
Value resultOffsetVal =
- builder.create<arith::ConstantIndexOp>(loc, resultOffset);
- Value isSameOffset = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
- builder.create<cf::AssertOp>(
- loc, isSameOffset,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset mismatch"));
+ arith::ConstantIndexOp::create(builder, loc, resultOffset);
+ Value isSameOffset = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
+ cf::AssertOp::create(builder, loc, isSameOffset,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "offset mismatch"));
}
// Check strides.
@@ -150,11 +148,11 @@ struct CastOpInterface
Value srcStride =
metadataOp.getResult(2 + resultType.getRank() + it.index());
Value resultStrideVal =
- builder.create<arith::ConstantIndexOp>(loc, it.value());
- Value isSameStride = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
- builder.create<cf::AssertOp>(
- loc, isSameStride,
+ arith::ConstantIndexOp::create(builder, loc, it.value());
+ Value isSameStride = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
+ cf::AssertOp::create(
+ builder, loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "stride mismatch of dim " + std::to_string(it.index())));
}
@@ -186,7 +184,7 @@ struct CopyOpInterface
auto getDimSize = [&](Value memRef, MemRefType type,
int64_t dim) -> Value {
return type.isDynamicDim(dim)
- ? builder.create<DimOp>(loc, memRef, dim).getResult()
+ ? DimOp::create(builder, loc, memRef, dim).getResult()
: builder
.create<arith::ConstantIndexOp>(loc,
type.getDimSize(dim))
@@ -194,13 +192,12 @@ struct CopyOpInterface
};
Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
- Value sameDimSize = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
- builder.create<cf::AssertOp>(
- loc, sameDimSize,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "size of " + std::to_string(i) +
- "-th source/target dim does not match"));
+ Value sameDimSize = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
+ cf::AssertOp::create(builder, loc, sameDimSize,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "size of " + std::to_string(i) +
+ "-th source/target dim does not match"));
}
}
};
@@ -211,10 +208,11 @@ struct DimOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto dimOp = cast<DimOp>(op);
- Value rank = builder.create<RankOp>(loc, dimOp.getSource());
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- builder.create<cf::AssertOp>(
- loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
+ Value rank = RankOp::create(builder, loc, dimOp.getSource());
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ cf::AssertOp::create(
+ builder, loc,
+ generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "index is out of bounds"));
}
@@ -237,7 +235,7 @@ struct LoadStoreOpInterface
}
auto indices = loadStoreOp.getIndices();
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
@@ -247,10 +245,9 @@ struct LoadStoreOpInterface
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
: inBounds;
}
- builder.create<cf::AssertOp>(
- loc, assertCond,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "out-of-bounds access"));
+ cf::AssertOp::create(builder, loc, assertCond,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "out-of-bounds access"));
}
};
@@ -265,10 +262,10 @@ struct SubViewOpInterface
// For each dimension, assert that:
// 0 <= offset < dim_size
// 0 <= offset + (size - 1) * stride < dim_size
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
- builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
+ ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
@@ -281,21 +278,21 @@ struct SubViewOpInterface
Value dimSize = metadataOp.getSizes()[i];
Value offsetInBounds =
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- builder.create<cf::AssertOp>(
- loc, offsetInBounds,
+ cf::AssertOp::create(
+ builder, loc, offsetInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "offset " + std::to_string(i) + " is out-of-bounds"));
// Verify that slice does not run out-of-bounds.
- Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
+ Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
- builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
+ arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
Value lastPos =
- builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
+ arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
- builder.create<cf::AssertOp>(
- loc, lastPosInBounds,
+ cf::AssertOp::create(
+ builder, loc, lastPosInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview runs out-of-bounds along dimension " +
std::to_string(i)));
@@ -315,7 +312,7 @@ struct ExpandShapeOpInterface
for (const auto &it :
llvm::enumerate(expandShapeOp.getReassociationIndices())) {
Value srcDimSz =
- builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
+ DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
int64_t groupSz = 1;
bool foundDynamicDim = false;
for (int64_t resultDim : it.value()) {
@@ -330,18 +327,17 @@ struct ExpandShapeOpInterface
groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
}
Value staticResultDimSz =
- builder.create<arith::ConstantIndexOp>(loc, groupSz);
+ arith::ConstantIndexOp::create(builder, loc, groupSz);
// staticResultDimSz must divide srcDimSz evenly.
Value mod =
- builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
- Value isModZero = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, mod,
- builder.create<arith::ConstantIndexOp>(loc, 0));
- builder.create<cf::AssertOp>(
- loc, isModZero,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "static result dims in reassoc group do not "
- "divide src dim evenly"));
+ arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
+ Value isModZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::eq, mod,
+ arith::ConstantIndexOp::create(builder, loc, 0));
+ cf::AssertOp::create(builder, loc, isModZero,
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "static result dims in reassoc group do not "
+ "divide src dim evenly"));
}
}
};
More information about the Mlir-commits mailing list