[Mlir-commits] [mlir] [mlir][NFC] update `mlir/Dialect` create APIs (32/n) (PR #150657)
Maksim Levental
llvmlistbot at llvm.org
Fri Jul 25 10:16:19 PDT 2025
makslevental created https://github.com/llvm/llvm-project/pull/150657
See https://github.com/llvm/llvm-project/pull/147168 for more info.
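For reference, the mechanical change applied throughout this series replaces the member-function builder form `rewriter.create<OpTy>(...)` with the static `OpTy::create(rewriter, ...)` form, which takes the builder as its first argument. A minimal before/after sketch of the pattern, lifted from the first Linalg hunk below (the surrounding names `padValue`, `emptyTensor`, and `fillOp` come from that hunk):

  // Before: create<OpTy> is a member template on the rewriter.
  Value replacement =
      rewriter
          .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                          ValueRange{emptyTensor})
          .getResult(0);

  // After: OpTy::create is a static method taking the builder first.
  Value replacement =
      FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                     ValueRange{emptyTensor})
          .getResult(0);

Behavior is unchanged (NFC); only the call syntax differs.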
From e91af0b056bbbd6f32bd0817d086830b7f7dcc6d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 25 Jul 2025 13:15:38 -0400
Subject: [PATCH] [mlir][NFC] update `mlir/Dialect` create APIs (32/n)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +--
.../TransformOps/LinalgTransformOps.cpp | 3 +-
.../Transforms/DataLayoutPropagation.cpp | 3 +-
.../Linalg/Transforms/DropUnitDims.cpp | 3 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 6 +--
.../Transforms/PackAndUnpackPatterns.cpp | 3 +-
.../lib/Dialect/Linalg/Transforms/Padding.cpp | 6 +--
.../Dialect/Linalg/Transforms/Transforms.cpp | 3 +-
.../Linalg/Transforms/TransposeConv2D.cpp | 3 +-
.../Linalg/Transforms/Vectorization.cpp | 3 +-
.../Linalg/Transforms/WinogradConv2D.cpp | 24 ++++--------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +-
.../Vector/Transforms/LowerVectorGather.cpp | 7 ++--
.../Vector/Transforms/LowerVectorTransfer.cpp | 8 ++--
.../Vector/Transforms/VectorDistribute.cpp | 39 +++++++++----------
.../Transforms/VectorDropLeadUnitDim.cpp | 6 +--
.../Transforms/VectorEmulateNarrowType.cpp | 26 ++++++-------
.../Transforms/VectorTransferOpTransforms.cpp | 4 +-
.../VectorTransferSplitRewritePatterns.cpp | 35 ++++++++---------
.../Vector/Transforms/VectorTransforms.cpp | 3 +-
20 files changed, 82 insertions(+), 112 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4fee81aa2ef67..b154c69d28148 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -791,8 +791,7 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
padOp.getResultType().getElementType());
Value replacement =
- rewriter
- .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
+ FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
ValueRange{emptyTensor})
.getResult(0);
if (replacement.getType() != padOp.getResultType()) {
@@ -2154,8 +2153,7 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
// Create broadcast(transpose(input)).
Value transposeResult =
- rewriter
- .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
+ TransposeOp::create(rewriter, loc, broadcastOp.getInput(), transposeInit,
resultPerms)
->getResult(0);
rewriter.replaceOpWithNewOp<BroadcastOp>(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bb0861340ad92..6625267f07d68 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4133,8 +4133,7 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
Value extracted = tensor::ExtractSliceOp::create(
rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
target.getMixedSizes(), target.getMixedStrides());
- Value copied = rewriter
- .create<linalg::CopyOp>(target.getLoc(),
+ Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
target.getSource(), extracted)
.getResult(0);
// Reset the insertion point.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 91a297f7b9db7..6dc5bf3a15da4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1143,8 +1143,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
// Insert an unPackOp right after the packed generic.
Value unPackOpRes =
- rewriter
- .create<linalg::UnPackOp>(genericOp.getLoc(), newResult,
+ linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
destPack.getSource(), innerDimsPos,
mixedTiles, outerDimsPerm)
.getResult();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 745a40dbc4eea..d3af23b62215d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -267,8 +267,7 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
assert(rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
- return rewriter
- .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
+ return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result, reassociation)
.getResult();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4a66b8b9619f4..92342abcc5af3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1572,12 +1572,10 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
// Insert a reshape to collapse the dimensions.
if (isa<MemRefType>(operand.getType())) {
- return builder
- .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+ return memref::CollapseShapeOp::create(builder, loc, operand, operandReassociation)
.getResult();
}
- return builder
- .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+ return tensor::CollapseShapeOp::create(builder, loc, operand, operandReassociation)
.getResult();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index a45a4e314e511..091266e49db4a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -81,8 +81,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
ArrayRef<ReassociationIndices> reassociation) const {
if (operand.getType() == newOperandType)
return operand;
- return rewriter
- .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+ return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
reassociation)
.getResult();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index b5c5aea56a998..e4182b1451751 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -333,16 +333,14 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
for (auto it :
llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
- replacements.push_back(rewriter
- .create<linalg::CopyOp>(loc, std::get<0>(it),
+ replacements.push_back(linalg::CopyOp::create(rewriter, loc, std::get<0>(it),
std::get<1>(it).get())
.getResult(0));
} else if (options.copyBackOp ==
LinalgPaddingOptions::CopyBackOp::
BufferizationMaterializeInDestination) {
replacements.push_back(
- rewriter
- .create<bufferization::MaterializeInDestinationOp>(
+ bufferization::MaterializeInDestinationOp::create(rewriter,
loc, std::get<0>(it), std::get<1>(it).get())
->getResult(0));
} else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 1f1e617738981..475b0f94779c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -947,8 +947,7 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
auto getIdxValue = [&](OpFoldResult ofr) {
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
return val;
- return rewriter
- .create<arith::ConstantIndexOp>(
+ return arith::ConstantIndexOp::create(rewriter,
padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
.getResult();
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
index 99fb8c796cf06..20fb22334dd38 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -70,8 +70,7 @@ FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy)
.getResult();
} else {
- input = rewriter
- .create<memref::AllocOp>(
+ input = memref::AllocOp::create(rewriter,
loc, MemRefType::get(newFilterShape, elementTy))
.getResult();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae627da5445a8..4733d617f0dd4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3714,8 +3714,7 @@ struct Conv1DGenerator
}
}
- return rewriter
- .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
+ return vector::TransferWriteOp::create(rewriter, loc, res, resShaped, resPadding)
.getOperation();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 669fefcd86de1..da8ff88ccebfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -399,8 +399,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
retRows = GMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ tensor::EmptyOp::create(builder, loc, matmulType.getShape(), elementType)
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
@@ -423,8 +422,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
auto matmulType =
RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ tensor::EmptyOp::create(builder, loc, matmulType.getShape(), elementType)
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
@@ -548,8 +546,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
retRows = BTMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ tensor::EmptyOp::create(builder, loc, matmulType.getShape(), elementType)
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
@@ -573,8 +570,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
retCols = BMatrix.cols;
auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
auto empty =
- builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
+ tensor::EmptyOp::create(builder, loc, matmulType.getShape(), elementType)
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
@@ -661,8 +657,7 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
{inputShape[0] * inputShape[1],
inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
outputElementType);
- Value empty = rewriter
- .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
outputElementType)
.getResult();
Value zero = arith::ConstantOp::create(
@@ -782,8 +777,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
Value init = outInitVal;
if (rightTransform || scalarFactor != 1) {
- auto empty = builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
elementType)
.getResult();
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
@@ -802,8 +796,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
RankedTensorType::get({retRows, AMatrix.cols}, elementType);
Value init = outInitVal;
if (scalarFactor != 1) {
- auto empty = builder
- .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+ auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
elementType)
.getResult();
init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
@@ -827,8 +820,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
matmulRetValue =
- rewriter
- .create<linalg::GenericOp>(
+ linalg::GenericOp::create(rewriter,
loc, matmulType,
ValueRange{scalarFactorValue, matmulRetValue},
ValueRange{outInitVal}, affineMaps,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4e9f93b9cae6f..1a3f972a43fce 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -372,8 +372,7 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
llvm::transform(foldResults, std::back_inserter(values),
[&](OpFoldResult foldResult) {
if (auto attr = dyn_cast<Attribute>(foldResult))
- return builder
- .create<arith::ConstantIndexOp>(
+ return arith::ConstantIndexOp::create(builder,
loc, cast<IntegerAttr>(attr).getInt())
.getResult();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 2484670c39caa..e062f55f87679 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -248,11 +248,10 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
scf::YieldOp::create(b, loc, result);
};
- result =
- rewriter
- .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
+ result = scf::IfOp::create(rewriter, loc, condition,
+ /*thenBuilder=*/loadBuilder,
/*elseBuilder=*/passThruBuilder)
- .getResult(0);
+ .getResult(0);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index e9109322ed3d8..4baeb1145d25b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -142,8 +142,8 @@ struct TransferReadPermutationLowering
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- return rewriter
- .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+ return vector::TransposeOp::create(rewriter, op.getLoc(), newRead,
+ transposePerm)
.getResult();
}
};
@@ -371,8 +371,8 @@ struct TransferOpReduceRank
rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
- return rewriter
- .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+ return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
+ newRead)
.getVector();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 58e94ea00189f..bb0f339a26e43 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -451,10 +451,9 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
}
SmallVector<Value> delinearized;
if (map.getNumResults() > 1) {
- delinearized = rewriter
- .create<mlir::affine::AffineDelinearizeIndexOp>(
- newWarpOp.getLoc(), newWarpOp.getLaneid(),
- delinearizedIdSizes)
+ delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
+ rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
+ delinearizedIdSizes)
.getResults();
} else {
// If there is only one map result, we can elide the delinearization
@@ -1538,19 +1537,18 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
newWarpOp.getLaneid(), insertingLane);
Value newResult =
- rewriter
- .create<scf::IfOp>(
- loc, isInsertingLane,
- /*thenBuilder=*/
- [&](OpBuilder &builder, Location loc) {
- Value newInsert = vector::InsertOp::create(
- builder, loc, newSource, distributedVec, newPos);
- scf::YieldOp::create(builder, loc, newInsert);
- },
- /*elseBuilder=*/
- [&](OpBuilder &builder, Location loc) {
- scf::YieldOp::create(builder, loc, distributedVec);
- })
+ scf::IfOp::create(
+ rewriter, loc, isInsertingLane,
+ /*thenBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ Value newInsert = vector::InsertOp::create(
+ builder, loc, newSource, distributedVec, newPos);
+ scf::YieldOp::create(builder, loc, newInsert);
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &builder, Location loc) {
+ scf::YieldOp::create(builder, loc, distributedVec);
+ })
.getResult(0);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
return success();
@@ -1661,10 +1659,9 @@ struct WarpOpInsert : public WarpDistributionPattern {
auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
scf::YieldOp::create(builder, loc, distributedDest);
};
- newResult = rewriter
- .create<scf::IfOp>(loc, isInsertingLane,
- /*thenBuilder=*/insertingBuilder,
- /*elseBuilder=*/nonInsertingBuilder)
+ newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
+ /*thenBuilder=*/insertingBuilder,
+ /*elseBuilder=*/nonInsertingBuilder)
.getResult(0);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 73388a5da3e4f..9889d7f221fe6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -466,9 +466,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
}
- return rewriter
- .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
- newOp->getResults()[0])
+ return vector::BroadcastOp::create(rewriter, loc,
+ contractOp->getResultTypes()[0],
+ newOp->getResults()[0])
.getResult();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 91a484f7d463c..f78e579d6c099 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -132,17 +132,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
newMaskOperands);
})
- .Case<vector::ConstantMaskOp>(
- [&](auto constantMaskOp) -> std::optional<Operation *> {
- // Take the shape of mask, compress its trailing dimension:
- SmallVector<int64_t> maskDimSizes(
- constantMaskOp.getMaskDimSizes());
- int64_t &maskIndex = maskDimSizes.back();
- maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
- numSrcElemsPerDest);
- return vector::ConstantMaskOp::create(
- rewriter, loc, newMaskType, maskDimSizes);
- })
+ .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+ -> std::optional<Operation *> {
+ // Take the shape of mask, compress its trailing dimension:
+ SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());
+ int64_t &maskIndex = maskDimSizes.back();
+ maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+ numSrcElemsPerDest);
+ return vector::ConstantMaskOp::create(rewriter, loc, newMaskType,
+ maskDimSizes);
+ })
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
// TODO: Support multiple dimensions.
@@ -229,9 +228,8 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto resultVectorType =
VectorType::get({numElemsToExtract}, vectorType.getElementType());
- return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
- offsets, sizes, strides)
+ return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType,
+ src, offsets, sizes, strides)
->getResult(0);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2676d254c9b64..48d680c03489b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -330,8 +330,8 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
}
reducedOperands.push_back(operand);
}
- return rewriter
- .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
+ return vector::CreateMaskOp::create(rewriter, loc, reducedType,
+ reducedOperands)
.getResult();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 05b00744beea2..5e12dc486e595 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -348,24 +348,23 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
Location loc = xferOp.getLoc();
Value zero = arith::ConstantIndexOp::create(b, loc, 0);
Value memref = xferOp.getBase();
- return b
- .create<scf::IfOp>(
- loc, inBoundsCond,
- [&](OpBuilder &b, Location loc) {
- Value res =
- castToCompatibleMemRefType(b, memref, compatibleMemRefType);
- scf::ValueVector viewAndIndices{res};
- llvm::append_range(viewAndIndices, xferOp.getIndices());
- scf::YieldOp::create(b, loc, viewAndIndices);
- },
- [&](OpBuilder &b, Location loc) {
- Value casted =
- castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
- scf::ValueVector viewAndIndices{casted};
- viewAndIndices.insert(viewAndIndices.end(),
- xferOp.getTransferRank(), zero);
- scf::YieldOp::create(b, loc, viewAndIndices);
- })
+ return scf::IfOp::create(
+ b, loc, inBoundsCond,
+ [&](OpBuilder &b, Location loc) {
+ Value res =
+ castToCompatibleMemRefType(b, memref, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{res};
+ llvm::append_range(viewAndIndices, xferOp.getIndices());
+ scf::YieldOp::create(b, loc, viewAndIndices);
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value casted =
+ castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
+ scf::ValueVector viewAndIndices{casted};
+ viewAndIndices.insert(viewAndIndices.end(),
+ xferOp.getTransferRank(), zero);
+ scf::YieldOp::create(b, loc, viewAndIndices);
+ })
->getResults();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 73ca327bb49c5..647fe8c78d9da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -410,8 +410,7 @@ FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
VectorType maskOpType =
VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
- mask = rewriter
- .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
+ mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(), maskOpType,
maskingOp.getMask())
.getResult();
}