[Mlir-commits] [mlir] [mlir][NFC] update `mlir/Dialect` create APIs (23/n) (PR #149930)
llvmlistbot at llvm.org
Mon Jul 21 15:47:11 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-tensor
Author: Maksim Levental (makslevental)
Changes:
See https://github.com/llvm/llvm-project/pull/147168 for more info.
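The change is mechanical: every op construction through the builder's member template, `builder.create<OpTy>(loc, ...)`, is rewritten to the static form `OpTy::create(builder, loc, ...)`, with no functional change. A minimal before/after sketch, lifted from the `TensorOps.cpp` hunk below (the names `b`, `loc`, `mixedSizes`, and `tensorType` come from that surrounding code and are shown here only for illustration):

```cpp
// Before: op created via the builder's member template.
Value emptyTensor =
    b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());

// After: static create() taking the OpBuilder as its first argument.
Value emptyTensor =
    tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
```

The remaining arguments are unchanged; only the callee and the position of the builder move, so each hunk below follows this same shape.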
---
Patch is 55.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149930.diff
15 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+5-5)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+71-67)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+15-15)
- (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+2-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+41-40)
- (modified) mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (+3-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp (+8-7)
- (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+3-3)
- (modified) mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (+10-9)
- (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+11-11)
- (modified) mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (+1-1)
- (modified) mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp (+31-32)
- (modified) mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp (+2-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (+3-3)
- (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+3-3)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 26406ceef082c..7e4a5acb9867d 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -74,12 +74,12 @@ struct CreatorOpShardingInterface
if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
if (!newSharding) {
newSharding =
- builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
+ ShardingOp::create(builder, op->getLoc(), resultShardings[0]);
device =
- builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
+ mesh::ProcessMultiIndexOp::create(builder, op->getLoc(), mesh)
.getResults();
- shapeForDevice = builder.create<mesh::ShardShapeOp>(
- op->getLoc(), oldType.getShape(), spmdizedOperands,
+ shapeForDevice = mesh::ShardShapeOp::create(
+ builder, op->getLoc(), oldType.getShape(), spmdizedOperands,
newSharding->getResult(0), device);
}
newOperands.emplace_back(shapeForDevice.getResult()[i]);
@@ -88,7 +88,7 @@ struct CreatorOpShardingInterface
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
}
}
- newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
+ newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b035a53692dcf..7d4b1127a08be 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -53,8 +53,8 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
return op;
if (complex::ConstantOp::isBuildableWith(value, type))
- return builder.create<complex::ConstantOp>(loc, type,
- llvm::cast<ArrayAttr>(value));
+ return complex::ConstantOp::create(builder, loc, type,
+ llvm::cast<ArrayAttr>(value));
return nullptr;
}
@@ -107,7 +107,7 @@ FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
// Create empty tensor.
Value emptyTensor =
- b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
+ tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
return emptyTensor;
}
@@ -678,8 +678,8 @@ FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
inputShapes.emplace_back(std::move(inputShape));
}
- Value replacement = builder.create<tensor::EmptyOp>(
- loc, outputShape, getType().getElementType());
+ Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
+ getType().getElementType());
int64_t rank = getType().getRank();
OpFoldResult one = builder.getIndexAttr(1);
@@ -687,12 +687,12 @@ FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
SmallVector<OpFoldResult> offsets(rank, zero);
for (auto [index, input] : llvm::enumerate(getInputs())) {
offsets[concatDim] = concatOffsets[index];
- auto insertSlice = builder.create<tensor::InsertSliceOp>(
- loc, input, replacement, offsets, inputShapes[index], strides);
+ auto insertSlice = tensor::InsertSliceOp::create(
+ builder, loc, input, replacement, offsets, inputShapes[index], strides);
replacement = insertSlice.getResult();
}
if (replacement.getType() != getType()) {
- replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+ replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
}
return SmallVector<Value>{replacement};
}
@@ -723,7 +723,7 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
builder.getIndexAttr(inferredResultType.getDimSize(i)));
} else {
reifiedReturnShapes[0][i] =
- builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
+ tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
}
}
@@ -823,8 +823,8 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
// Use refined operand type and create cast from original operand.
auto castOp =
- rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
- concatOp.getOperand(operandIdx));
+ CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
+ concatOp.getOperand(operandIdx));
rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
concatOp->setOperand(operandIdx, castOp->getResult(0));
});
@@ -864,8 +864,9 @@ struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
return failure();
}
- auto newConcatOp = rewriter.create<ConcatOp>(
- concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+ auto newConcatOp =
+ ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
+ concatOp->getOperands());
rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
newConcatOp);
@@ -892,7 +893,7 @@ void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
int64_t index) {
auto loc = result.location;
- Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
+ Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
build(builder, result, source, indexValue);
}
@@ -1036,10 +1037,10 @@ struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
rewriter.setInsertionPointAfter(dim);
Location loc = dim.getLoc();
Value extract =
- rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+ ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
if (extract.getType() != dim.getType())
extract =
- rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+ arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
rewriter.replaceOp(dim, extract);
return success();
}
@@ -1150,8 +1151,8 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
if (foldedTensorType == op.getType())
return failure();
- auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
- foldedDynamicSizes);
+ auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
+ foldedDynamicSizes);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
@@ -1326,8 +1327,8 @@ struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
SmallVector<int64_t> basis =
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
- auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
- extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
+ auto delinearize = affine::AffineDelinearizeIndexOp::create(
+ rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseOp.getReassociationIndices().empty()) {
@@ -1498,8 +1499,8 @@ struct ExtractElementFromIndexCast
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
- auto newExtract = rewriter.create<tensor::ExtractOp>(
- loc, elementTy, indexCast.getIn(), extract.getIndices());
+ auto newExtract = tensor::ExtractOp::create(
+ rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
newExtract);
@@ -1736,7 +1737,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
auto loc = generateOp.getLoc();
auto newOp =
- rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
+ GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
newOp.getBody().begin());
rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
@@ -2161,9 +2162,9 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
collapseShapeOp.getSrcMutable().assign(castOp.getSource());
});
} else {
- auto newOp = rewriter.create<CollapseShapeOp>(
- collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
- collapseShapeOp.getReassociation());
+ auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
+ newResultType, castOp.getSource(),
+ collapseShapeOp.getReassociation());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
collapseShapeOp, collapseShapeOp.getResultType(), newOp);
}
@@ -2240,10 +2241,10 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
newInputShape, expandOp.getSrcType().getElementType());
auto outputType = RankedTensorType::get(
newOutputShape, expandOp.getSrcType().getElementType());
- auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
- expandOp.getSrc());
- auto newExpand = rewriter.create<ExpandShapeOp>(
- expandOp.getLoc(), outputType, inputCast.getResult(),
+ auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
+ expandOp.getSrc());
+ auto newExpand = ExpandShapeOp::create(
+ rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
expandOp.getReassociationIndices(), outputOfr);
rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
newExpand.getResult());
@@ -2555,10 +2556,11 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
// Create folded extract.
Location loc = sliceOp.getLoc();
- Value newResult = rewriter.create<ExtractSliceOp>(
- loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
- sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
- sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
+ Value newResult = ExtractSliceOp::create(
+ rewriter, loc, sliceOp.getType(), castOp.getSource(),
+ sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
+ sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
+ sliceOp.getStaticStrides());
rewriter.replaceOp(sliceOp, newResult);
return success();
}
@@ -2709,8 +2711,8 @@ struct SliceCanonicalizer {
ExtractSliceOp newOp) {
Value replacement = newOp.getResult();
if (replacement.getType() != op.getType())
- replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
- replacement);
+ replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
+ replacement);
rewriter.replaceOp(op, replacement);
}
};
@@ -2978,8 +2980,8 @@ class InsertSliceOpConstantArgumentFolder final
// the parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
- toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
- sourceType, toInsert);
+ toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
+ sourceType, toInsert);
}
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
@@ -3075,17 +3077,18 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
if (!sliceResult.isValid)
return failure();
- Operation *replacement = rewriter.create<InsertOpTy>(
- insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
- mixedSizes, insertSliceOp.getMixedStrides());
+ Operation *replacement =
+ InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
+ insertSliceOp.getMixedOffsets(), mixedSizes,
+ insertSliceOp.getMixedStrides());
// In the parallel case there is no result and so nothing to cast.
bool isParallelInsert =
std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
- replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
- insertSliceOp.getDestType(),
- replacement->getResult(0));
+ replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
+ insertSliceOp.getDestType(),
+ replacement->getResult(0));
}
rewriter.replaceOp(insertSliceOp, replacement->getResults());
return success();
@@ -3154,8 +3157,8 @@ struct InsertSliceOpSourceCastInserter final
// parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
- Value cast = rewriter.create<tensor::CastOp>(
- insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
+ Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
+ newSrcType, insertSliceOp.getSource());
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
@@ -3353,7 +3356,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
// a guard to reset the insertion point of the builder after it is destroyed.
OpBuilder::InsertionGuard guard(b);
b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
- b.create<tensor::YieldOp>(result.location, constantPadValue);
+ tensor::YieldOp::create(b, result.location, constantPadValue);
}
llvm::SmallBitVector PadOp::getPaddedDims() {
@@ -3407,10 +3410,11 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
padTensorOp.getSourceMutable().assign(castOp.getSource());
});
} else {
- auto newOp = rewriter.create<PadOp>(
- padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
- padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
- padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+ auto newOp = PadOp::create(
+ rewriter, padTensorOp->getLoc(), newResultType,
+ padTensorOp.getSource(), padTensorOp.getStaticLow(),
+ padTensorOp.getStaticHigh(), padTensorOp.getLow(),
+ padTensorOp.getHigh(), padTensorOp.getNofold(),
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
@@ -3439,8 +3443,8 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
tensorCastOp.getDest().getType()))
return failure();
- auto replacementOp = rewriter.create<PadOp>(
- padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
+ auto replacementOp = PadOp::create(
+ rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
padTensorOp.getSource(), padTensorOp.getStaticLow(),
padTensorOp.getStaticHigh(), padTensorOp.getLow(),
padTensorOp.getHigh(), padTensorOp.getNofold(),
@@ -3597,11 +3601,11 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
// Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
// the two paddings in one step.
- auto newSliceOp = rewriter.create<ExtractSliceOp>(
- padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
- innerSliceOp.getMixedStrides());
- auto newPadOp = rewriter.create<PadOp>(
- padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
+ auto newSliceOp = ExtractSliceOp::create(
+ rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
+ newSizes, innerSliceOp.getMixedStrides());
+ auto newPadOp = PadOp::create(
+ rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
@@ -3697,9 +3701,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
// Rewrite the op using the new static type.
auto newResultType = RankedTensorType::get(
newOutDims, padTensorOp.getType().getElementType());
- auto newOp = rewriter.create<PadOp>(
- padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
- newLows, newHighs, padTensorOp.getNofold(),
+ auto newOp = PadOp::create(
+ rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
+ staticHigh, newLows, newHighs, padTensorOp.getNofold(),
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
IRMapping mapper;
@@ -3777,9 +3781,9 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
SmallVector<OpFoldResult> newLowPad =
addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
- auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
- newLowPad, newHighPad, padOp.getNofold(),
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, padOp.getLoc(), padOp.getResultType(),
+ producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
newPadOp.getRegion().begin());
@@ -3803,7 +3807,7 @@ PadOp::reifyResultShapes(OpBuilder &b,
}
Location loc = getLoc();
Value dim = b.createOrFold<tensor::DimOp>(
- loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+ loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
AffineExpr d0, d1, d2;
bindDims(b.getContext(), d0, d1, d2);
@@ -4108,8 +4112,8 @@ struct FoldTensorCastProducerOp
for (auto [oldResult, newResult] :
llvm::zip(op->getResults(), newOp->getResults())) {
if (newResult.getType() != oldResult.getType()) {
- replacements.push_back(rewriter.create<tensor::CastOp>(
- op->getLoc(), oldResult.getType(), newResult));
+ replacements.push_back(tensor::CastOp::create(
+ rewriter, op->getLoc(), oldResult.getType(), newResult));
} else {
replacements.push_back(newResult);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 437bc5d00faa8..124a63281a37c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -207,13 +207,13 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
if (isZeroInteger(new...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/149930