[Mlir-commits] [mlir] f904cdd - [mlir][NFC] update `mlir/Dialect` create APIs (24/n) (#149931)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 22 05:16:18 PDT 2025
Author: Maksim Levental
Date: 2025-07-22T08:16:15-04:00
New Revision: f904cdd6c3049e605d24ed17680e80e7133908a0
URL: https://github.com/llvm/llvm-project/commit/f904cdd6c3049e605d24ed17680e80e7133908a0
DIFF: https://github.com/llvm/llvm-project/commit/f904cdd6c3049e605d24ed17680e80e7133908a0.diff
LOG: [mlir][NFC] update `mlir/Dialect` create APIs (24/n) (#149931)
See https://github.com/llvm/llvm-project/pull/147168 for more info.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 56f748fbbe1d6..4c00fb58e4d30 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -124,7 +124,7 @@ static MaskFormat getMaskFormat(Value mask) {
/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
- builder.create<vector::YieldOp>(loc);
+ vector::YieldOp::create(builder, loc);
}
// Helper for verifying combining kinds in contractions and reductions.
@@ -596,16 +596,16 @@ struct ElideUnitDimsInMultiDimReduction
VectorType newMaskType =
VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
dstVecType.getScalableDims());
- mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+ mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
}
- cast = rewriter.create<vector::ShapeCastOp>(
- loc, reductionOp.getDestType(), reductionOp.getSource());
+ cast = vector::ShapeCastOp::create(
+ rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
} else {
// This means we are reducing all the dimensions, and all reduction
// dimensions are of size 1. So a simple extraction would do.
if (mask)
- mask = rewriter.create<vector::ExtractOp>(loc, mask);
- cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
+ mask = vector::ExtractOp::create(rewriter, loc, mask);
+ cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
}
Value result =
@@ -672,36 +672,36 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
switch (op) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::addi:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::ADD, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::ADD, vector);
case arith::AtomicRMWKind::mulf:
case arith::AtomicRMWKind::muli:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MUL, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MUL, vector);
case arith::AtomicRMWKind::minimumf:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MINIMUMF, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINIMUMF, vector);
case arith::AtomicRMWKind::mins:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MINSI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINSI, vector);
case arith::AtomicRMWKind::minu:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MINUI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MINUI, vector);
case arith::AtomicRMWKind::maximumf:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MAXIMUMF, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXIMUMF, vector);
case arith::AtomicRMWKind::maxs:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MAXSI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXSI, vector);
case arith::AtomicRMWKind::maxu:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::MAXUI, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::MAXUI, vector);
case arith::AtomicRMWKind::andi:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::AND, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::AND, vector);
case arith::AtomicRMWKind::ori:
- return builder.create<vector::ReductionOp>(vector.getLoc(),
- CombiningKind::OR, vector);
+ return vector::ReductionOp::create(builder, vector.getLoc(),
+ CombiningKind::OR, vector);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
@@ -740,8 +740,8 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
Location loc = reductionOp.getLoc();
if (mask)
- mask = rewriter.create<ExtractOp>(loc, mask);
- Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());
+ mask = ExtractOp::create(rewriter, loc, mask);
+ Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -4172,9 +4172,9 @@ class StridedSliceCreateMaskFolder final
// greater than the vector dim size.
IntegerAttr offsetAttr =
rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
- Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+ Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
Value sliceMaskDimSize =
- rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+ arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
sliceMaskDimSizes.push_back(sliceMaskDimSize);
}
// Add unchanged dimensions.
@@ -4289,8 +4289,8 @@ class StridedSliceBroadcast final
sizes[i] = 1;
}
}
- source = rewriter.create<ExtractStridedSliceOp>(
- op->getLoc(), source, offsets, sizes,
+ source = ExtractStridedSliceOp::create(
+ rewriter, op->getLoc(), source, offsets, sizes,
getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
@@ -4382,8 +4382,8 @@ class ContiguousExtractStridedSliceToExtract final
SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
- Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
- extractOffsets);
+ Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
+ extractOffsets);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
return success();
}
@@ -4413,7 +4413,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
- padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding, /*mask=*/Value(), inBoundsAttr);
}
@@ -4431,7 +4431,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
- padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, *padding,
permutationMapAttr, inBoundsAttr);
}
@@ -4450,7 +4450,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
- padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ padding = ub::PoisonOp::create(builder, result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
*padding,
/*mask=*/Value(), inBoundsAttr);
@@ -4975,7 +4975,7 @@ struct TransferReadAfterWriteToBroadcast
VectorType broadcastedType = VectorType::get(
broadcastShape, defWrite.getVectorType().getElementType(),
broadcastScalableFlags);
- vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
+ vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
transposePerm);
@@ -5453,13 +5453,14 @@ struct SwapExtractSliceOfTransferWrite
// Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
// Set all in_bounds to false and let the folder infer them.
SmallVector<bool> newInBounds(vectorShape.size(), false);
- auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
- extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
- insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
- insertOp.getMixedStrides());
- auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
- transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
- transferOp.getIndices(), transferOp.getPermutationMapAttr(),
+ auto newExtractOp = tensor::ExtractSliceOp::create(
+ rewriter, extractOp.getLoc(), insertOp.getSourceType(),
+ insertOp.getDest(), insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ auto newTransferWriteOp = TransferWriteOp::create(
+ rewriter, transferOp.getLoc(), transferOp.getVector(),
+ newExtractOp.getResult(), transferOp.getIndices(),
+ transferOp.getPermutationMapAttr(),
rewriter.getBoolArrayAttr(newInBounds));
rewriter.modifyOpInPlace(insertOp, [&]() {
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
@@ -6983,7 +6984,7 @@ void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
OpBuilder opBuilder(builder.getContext());
Operation *maskedOp = &block.front();
opBuilder.setInsertionPointToEnd(&block);
- opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
+ vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
LogicalResult MaskOp::verify() {
@@ -7318,7 +7319,7 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
// Create a block and move the op to that block.
insBlock->getOperations().splice(
insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
- builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
+ YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}
/// Creates a vector.mask operation around a maskable operation. Returns the
@@ -7330,12 +7331,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
if (!mask)
return maskableOp;
if (passthru)
- return builder.create<MaskOp>(maskableOp->getLoc(),
- maskableOp->getResultTypes(), mask, passthru,
- maskableOp, createMaskOpRegion);
- return builder.create<MaskOp>(maskableOp->getLoc(),
- maskableOp->getResultTypes(), mask, maskableOp,
- createMaskOpRegion);
+ return MaskOp::create(builder, maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, passthru,
+ maskableOp, createMaskOpRegion);
+ return MaskOp::create(builder, maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, maskableOp,
+ createMaskOpRegion);
}
/// Creates a vector select operation that picks values from `newValue` or
@@ -7350,8 +7351,8 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
if (!mask)
return newValue;
- return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
- mask, newValue, passthru);
+ return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
+ mask, newValue, passthru);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 9da051150e409..66196194b0585 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -116,8 +116,8 @@ struct TransferWriteOpInterface
getBuffer(rewriter, writeOp.getBase(), options, state);
if (failed(resultBuffer))
return failure();
- rewriter.create<vector::TransferWriteOp>(
- writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
+ vector::TransferWriteOp::create(
+ rewriter, writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
writeOp.getIndices(), writeOp.getPermutationMapAttr(),
writeOp.getMask(), writeOp.getInBoundsAttr());
replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
@@ -241,8 +241,9 @@ struct MaskOpInterface
// Create a new vector.mask op.
ValueRange newYieldedValuesRange(newYieldedValues);
TypeRange newResultTypes(newYieldedValuesRange);
- auto newOp = rewriter.create<vector::MaskOp>(
- op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
+ auto newOp = vector::MaskOp::create(
+ rewriter, op->getLoc(), newResultTypes, maskOp.getMask(),
+ maskOp.getPassthru(),
/*maskableOp=*/nullptr,
/*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
newOp.getRegion().takeBody(maskOp.getMaskRegion());
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index 89930a6bd35fa..4c3a04cfb5bfa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -64,14 +64,14 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
VectorType::get(shape, resultType.getElementType(), scalableDims);
Location loc = op.getLoc();
- Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ Value result = ub::PoisonOp::create(rewriter, loc, resultType);
for (auto position : *unrollIterator) {
Value extract =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+ vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
Value bitcast =
- rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
+ vector::BitCastOp::create(rewriter, loc, bitcastResType, extract);
result =
- rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
+ vector::InsertOp::create(rewriter, loc, bitcast, result, position);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 11dcfe421e0c4..cb8e566869cfd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -52,7 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
@@ -70,10 +70,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// Duplication.
VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
- rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
- Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
+ vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
+ Value result = ub::PoisonOp::create(rewriter, loc, dstType);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
rewriter.replaceOp(op, result);
return success();
}
@@ -111,13 +111,13 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType,
dstType.getScalableDims().drop_front());
- Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
+ Value result = ub::PoisonOp::create(rewriter, loc, dstType);
if (m == 0) {
// Stetch at start.
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
- Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
+ Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
} else {
// Stetch not at start.
if (dstType.getScalableDims()[0]) {
@@ -125,9 +125,9 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
return failure();
}
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
- Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
- result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
+ Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
+ result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
}
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index fc6c90f5132c7..65702ffa152d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -81,17 +81,17 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
// At extraction dimension?
if (index == 0)
- return rewriter.create<vector::ExtractOp>(loc, val, pos);
+ return vector::ExtractOp::create(rewriter, loc, val, pos);
// Unroll leading dimensions.
VectorType vType = VectorType::Builder(type).dropDim(0);
VectorType resType = VectorType::Builder(type).dropDim(index);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resType,
+ rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
+ Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, load, result, d);
+ result = vector::InsertOp::create(rewriter, loc, load, result, d);
}
return result;
}
@@ -106,15 +106,15 @@ static Value reshapeStore(Location loc, Value val, Value result,
return val;
// At insertion dimension?
if (index == 0)
- return rewriter.create<vector::InsertOp>(loc, val, result, pos);
+ return vector::InsertOp::create(rewriter, loc, val, result, pos);
// Unroll leading dimensions.
VectorType vType = VectorType::Builder(type).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
- Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
- Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
+ Value ext = vector::ExtractOp::create(rewriter, loc, result, d);
+ Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
+ result = vector::InsertOp::create(rewriter, loc, sto, result, d);
}
return result;
}
@@ -132,7 +132,7 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
// Only valid for floating point types.
return std::nullopt;
- mul = rewriter.create<arith::MulIOp>(loc, x, y);
+ mul = arith::MulIOp::create(rewriter, loc, x, y);
} else {
// Float case.
if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
@@ -143,14 +143,14 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
return std::nullopt;
// Special case for fused multiply-add.
if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
- Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+ Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc);
if (mask)
// The fma op doesn't need explicit masking. However, fma ops used in
// reductions must preserve previous 'acc' values for masked-out lanes.
fma = selectPassthru(rewriter, mask, fma, acc);
return fma;
}
- mul = rewriter.create<arith::MulFOp>(loc, x, y);
+ mul = arith::MulFOp::create(rewriter, loc, x, y);
}
if (!acc)
@@ -186,8 +186,8 @@ static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
static Value createAdd(Location loc, Value x, Value y, bool isInt,
PatternRewriter &rewriter) {
if (isInt)
- return rewriter.create<arith::AddIOp>(loc, x, y);
- return rewriter.create<arith::AddFOp>(loc, x, y);
+ return arith::AddIOp::create(rewriter, loc, x, y);
+ return arith::AddFOp::create(rewriter, loc, x, y);
}
/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
@@ -195,8 +195,8 @@ static Value createAdd(Location loc, Value x, Value y, bool isInt,
static Value createMul(Location loc, Value x, Value y, bool isInt,
PatternRewriter &rewriter) {
if (isInt)
- return rewriter.create<arith::MulIOp>(loc, x, y);
- return rewriter.create<arith::MulFOp>(loc, x, y);
+ return arith::MulIOp::create(rewriter, loc, x, y);
+ return arith::MulFOp::create(rewriter, loc, x, y);
}
namespace {
@@ -359,7 +359,7 @@ struct UnrolledOuterProductGenerator
Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
if (!v)
return v;
- return rewriter.create<vector::TransposeOp>(loc, v, perm);
+ return vector::TransposeOp::create(rewriter, loc, v, perm);
}
Value promote(Value v, Type dstElementType) {
@@ -373,8 +373,8 @@ struct UnrolledOuterProductGenerator
if (vecType)
promotedType = vecType.clone(promotedType);
if (isa<FloatType>(dstElementType))
- return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
- return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
+ return arith::ExtFOp::create(rewriter, loc, promotedType, v);
+ return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
}
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
@@ -386,17 +386,17 @@ struct UnrolledOuterProductGenerator
Type resElementType = cast<VectorType>(res.getType()).getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
- Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
- Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+ Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k);
+ Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k);
extractA = promote(extractA, resElementType);
extractB = promote(extractB, resElementType);
Value extractMask;
if (maybeMask.has_value() && maybeMask.value())
extractMask =
- rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+ vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);
- Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
- loc, res.getType(), extractA, extractB, res, kind);
+ Operation *outerProdOp = vector::OuterProductOp::create(
+ rewriter, loc, res.getType(), extractA, extractB, res, kind);
res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
}
return res;
@@ -646,28 +646,28 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
// Two outer parallel, one inner reduction (matmat flavor).
//
if (maps == infer({{m, k}, {k, n}, {m, n}})) {
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
} else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
// No need to permute anything.
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
- rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
// This is the classical row-major matmul. Just permute the lhs.
Value tmp = lhs;
- lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
rhs = tmp;
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
std::swap(lhs, rhs);
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
Value tmp = lhs;
- lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
- rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
Value tmp = rhs;
- rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
lhs = tmp;
} else {
return failure();
@@ -680,12 +680,12 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
if (maps == infer({{m, n}, {n}, {m}})) {
// No need to permute anything.
} else if (maps == infer({{n, m}, {n}, {m}})) {
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else if (maps == infer({{n}, {m, n}, {m}})) {
std::swap(lhs, rhs);
} else if (maps == infer({{n}, {n, m}, {m}})) {
std::swap(lhs, rhs);
- lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
} else {
return failure();
}
@@ -702,31 +702,32 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
// ExtractOp does not allow dynamic indexing, we must unroll explicitly.
- Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
- rewriter.getZeroAttr(dstType));
+ Value res = arith::ConstantOp::create(rewriter, loc, dstType,
+ rewriter.getZeroAttr(dstType));
bool isInt = isa<IntegerType>(dstType.getElementType());
llvm::SmallVector<Value> extractedCols;
extractedCols.reserve(dstColumns);
for (unsigned r = 0; r < dstRows; ++r) {
- Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
+ Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
// Extract each respective row and column of the LHS and RHS once to
// avoid having duplicate SSA values pointing to the same rows/columns.
if (r == 0) {
Value colRhs =
- rank == 1 ? rhs
- : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
+ rank == 1
+ ? rhs
+ : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c);
extractedCols.push_back(colRhs);
}
Value extractedColRhs = extractedCols[c];
Value product =
createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
- Value sum = rewriter.create<vector::ReductionOp>(
- op.getLoc(), vector::CombiningKind::ADD, product);
+ Value sum = vector::ReductionOp::create(
+ rewriter, op.getLoc(), vector::CombiningKind::ADD, product);
SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
: SmallVector<int64_t, 2>{r, c};
- res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);
+ res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
}
}
if (auto acc = op.getAcc())
@@ -827,21 +828,21 @@ struct ContractOpToElementwise
lhsDims.append(lhsShape.begin(), lhsShape.end());
auto expandedType =
VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
- newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
+ newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
}
if (!rhsDims.empty()) {
rhsDims.append(rhsShape.begin(), rhsShape.end());
auto expandedType =
VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
- newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
+ newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
}
bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
- newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
- newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
+ newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
+ newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
- newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
- newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
+ newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
+ newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
@@ -1039,8 +1040,8 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
// Unroll into a series of lower dimensional vector.contract ops.
Location loc = op.getLoc();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resType,
+ rewriter.getZeroAttr(resType));
for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
@@ -1052,8 +1053,8 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
iterIndex, d, rewriter);
- Operation *lowContract = rewriter.create<vector::ContractionOp>(
- loc, lhs, rhs, acc, lowAffine, lowIter);
+ Operation *lowContract = vector::ContractionOp::create(
+ rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
lowContract = maskOperation(rewriter, lowContract, lowMask);
result = reshapeStore(loc, lowContract->getResult(0), result, resType,
resIndex, d, rewriter);
@@ -1103,8 +1104,8 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
Value acc = op.getAcc();
Operation *reductionOp =
- acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
- : rewriter.create<vector::ReductionOp>(loc, kind, m);
+ acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc)
+ : vector::ReductionOp::create(rewriter, loc, kind, m);
return maskOperation(rewriter, reductionOp, mask)->getResult(0);
}
// Construct new iterator types and affine map array attribute.
@@ -1128,8 +1129,8 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
iterIndex, d, rewriter);
- Operation *newContract = rewriter.create<vector::ContractionOp>(
- loc, lhs, rhs, result, lowAffine, lowIter);
+ Operation *newContract = vector::ContractionOp::create(
+ rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
result = maskOperation(rewriter, newContract, newMask)->getResult(0);
}
return result;
@@ -1182,7 +1183,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
if (!rhsType) {
// Special case: AXPY operation.
- Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
+ Value b =
+ vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
std::optional<Value> mult = createContractArithOp(
loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
if (!mult.has_value())
@@ -1191,23 +1193,23 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
return success();
}
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resType,
+ rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
- Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
- Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
+ Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
+ Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
Value r = nullptr;
if (acc)
- r = rewriter.create<vector::ExtractOp>(loc, acc, d);
+ r = vector::ExtractOp::create(rewriter, loc, acc, d);
Value extrMask;
if (mask)
- extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
+ extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
std::optional<Value> m = createContractArithOp(
loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
- result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
+ result = vector::InsertOp::create(rewriter, loc, *m, result, d);
}
rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index f4ad56b4178db..2484670c39caa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -68,8 +68,8 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultTy, rewriter.getZeroAttr(resultTy));
+ Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
+ rewriter.getZeroAttr(resultTy));
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
@@ -77,16 +77,16 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
int64_t thisIdx[1] = {i};
Value indexSubVec =
- rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+ vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
Value maskSubVec =
- rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+ vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
Value passThruSubVec =
- rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
- Value subGather = rewriter.create<vector::GatherOp>(
- loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
- passThruSubVec);
+ vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
+ Value subGather = vector::GatherOp::create(
+ rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
+ maskSubVec, passThruSubVec);
result =
- rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+ vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
}
rewriter.replaceOp(op, result);
@@ -152,24 +152,24 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 1. Collapse the input memref so that it's "flat".
SmallVector<ReassociationIndices> reassoc = {{0, 1}};
- Value collapsed = rewriter.create<memref::CollapseShapeOp>(
- op.getLoc(), subview.getSource(), reassoc);
+ Value collapsed = memref::CollapseShapeOp::create(
+ rewriter, op.getLoc(), subview.getSource(), reassoc);
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
VectorType vType = op.getIndexVec().getType();
- Value mulCst = rewriter.create<arith::ConstantOp>(
- op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+ Value mulCst = arith::ConstantOp::create(
+ rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
Value newIdxs =
- rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+ arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
- Value newGather = rewriter.create<vector::GatherOp>(
- op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
- newIdxs, op.getMask(), op.getPassThru());
+ Value newGather = vector::GatherOp::create(
+ rewriter, op.getLoc(), op.getResult().getType(), collapsed,
+ op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);
return success();
@@ -222,8 +222,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
int64_t thisIdx[1] = {i};
Value condition =
- rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
- Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+ vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
+ Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
baseOffsets.back() =
rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
@@ -233,19 +233,19 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
// `vector.load` does not support scalar result; emit a vector load
// and extract the single result instead.
Value load =
- b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
+ vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets);
int64_t zeroIdx[1] = {0};
- extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
+ extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
} else {
- extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
+ extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
}
Value newResult =
- b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
- b.create<scf::YieldOp>(loc, newResult);
+ vector::InsertOp::create(b, loc, extracted, result, thisIdx);
+ scf::YieldOp::create(b, loc, newResult);
};
auto passThruBuilder = [result](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, result);
+ scf::YieldOp::create(b, loc, result);
};
result =
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index cab0f213b14a9..9d6a865a9301f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -60,14 +60,16 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
return failure();
auto loc = op.getLoc();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultType, rewriter.getZeroAttr(resultType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
for (auto position : *unrollIterator) {
- Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position);
- Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position);
+ Value extractLhs =
+ ExtractOp::create(rewriter, loc, op.getLhs(), position);
+ Value extractRhs =
+ ExtractOp::create(rewriter, loc, op.getRhs(), position);
Value interleave =
- rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
- result = rewriter.create<InsertOp>(loc, interleave, result, position);
+ InterleaveOp::create(rewriter, loc, extractLhs, extractRhs);
+ result = InsertOp::create(rewriter, loc, interleave, result, position);
}
rewriter.replaceOp(op, result);
@@ -123,20 +125,20 @@ class UnrollDeinterleaveOp final
return failure();
auto loc = op.getLoc();
- Value emptyResult = rewriter.create<arith::ConstantOp>(
- loc, resultType, rewriter.getZeroAttr(resultType));
+ Value emptyResult = arith::ConstantOp::create(
+ rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
Value evenResult = emptyResult;
Value oddResult = emptyResult;
for (auto position : *unrollIterator) {
auto extractSrc =
- rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+ vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
auto deinterleave =
- rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
- evenResult = rewriter.create<vector::InsertOp>(
- loc, deinterleave.getRes1(), evenResult, position);
- oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
- oddResult, position);
+ vector::DeinterleaveOp::create(rewriter, loc, extractSrc);
+ evenResult = vector::InsertOp::create(
+ rewriter, loc, deinterleave.getRes1(), evenResult, position);
+ oddResult = vector::InsertOp::create(
+ rewriter, loc, deinterleave.getRes2(), oddResult, position);
}
rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index ba21092d2af3c..45ef7f01a85f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -67,19 +67,20 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
Value idx = op.getOperand(0);
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
- Value trueVal = rewriter.create<vector::CreateMaskOp>(
- loc, lowType, op.getOperands().drop_front());
- Value falseVal = rewriter.create<arith::ConstantOp>(
- loc, lowType, rewriter.getZeroAttr(lowType));
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
+ Value trueVal = vector::CreateMaskOp::create(rewriter, loc, lowType,
+ op.getOperands().drop_front());
+ Value falseVal = arith::ConstantOp::create(rewriter, loc, lowType,
+ rewriter.getZeroAttr(lowType));
+ Value result = arith::ConstantOp::create(rewriter, loc, dstType,
+ rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < dim; d++) {
Value bnd =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
- Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
- bnd, idx);
- Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
- result = rewriter.create<vector::InsertOp>(loc, sel, result, d);
+ arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(d));
+ Value val = arith::CmpIOp::create(rewriter, loc,
+ arith::CmpIPredicate::slt, bnd, idx);
+ Value sel =
+ arith::SelectOp::create(rewriter, loc, val, trueVal, falseVal);
+ result = vector::InsertOp::create(rewriter, loc, sel, result, d);
}
rewriter.replaceOp(op, result);
return success();
@@ -146,12 +147,12 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
op, "Cannot unroll leading scalable dim in dstType");
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
- Value trueVal = rewriter.create<vector::ConstantMaskOp>(
- loc, lowType, dimSizes.drop_front());
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
+ Value trueVal = vector::ConstantMaskOp::create(rewriter, loc, lowType,
+ dimSizes.drop_front());
+ Value result = arith::ConstantOp::create(rewriter, loc, dstType,
+ rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)
- result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d);
+ result = vector::InsertOp::create(rewriter, loc, trueVal, result, d);
rewriter.replaceOp(op, result);
return success();
@@ -261,8 +262,8 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
PatternRewriter &rewriter) const override {
Value passthru = maskingOp.hasPassthru()
? maskingOp.getPassthru()
- : rewriter.create<arith::ConstantOp>(
- gatherOp.getLoc(),
+ : arith::ConstantOp::create(
+ rewriter, gatherOp.getLoc(),
rewriter.getZeroAttr(gatherOp.getVectorType()));
// Replace the `vector.mask` operation.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ce524b259d8d4..4773732d8d9a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -103,12 +103,12 @@ class InnerOuterDimReductionConversion
// If masked, transpose the original mask.
Value transposedMask;
if (maskableOp.isMasked()) {
- transposedMask = rewriter.create<vector::TransposeOp>(
- loc, maskableOp.getMaskingOp().getMask(), indices);
+ transposedMask = vector::TransposeOp::create(
+ rewriter, loc, maskableOp.getMaskingOp().getMask(), indices);
}
// Transpose reduction source.
- auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+ auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices);
SmallVector<bool> reductionMask(srcRank, false);
for (int i = 0; i < reductionSize; ++i) {
if (useInnerDimsForReduction)
@@ -117,8 +117,8 @@ class InnerOuterDimReductionConversion
reductionMask[i] = true;
}
- Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
- multiReductionOp.getLoc(), transposeOp.getResult(),
+ Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
+ rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
newMultiRedOp =
mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
@@ -255,15 +255,15 @@ class ReduceMultiDimReductionRank
auto maskCastedType = VectorType::get(
vectorShape,
llvm::cast<VectorType>(vectorMask.getType()).getElementType());
- newVectorMask =
- rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
+ newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
+ vectorMask);
}
auto castedType = VectorType::get(
vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
scalableDims);
- Value cast = rewriter.create<vector::ShapeCastOp>(
- loc, castedType, multiReductionOp.getSource());
+ Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
+ multiReductionOp.getSource());
Value acc = multiReductionOp.getAcc();
if (flattenedParallelDim) {
@@ -271,12 +271,12 @@ class ReduceMultiDimReductionRank
{flattenedParallelDim},
multiReductionOp.getSourceVectorType().getElementType(),
/*scalableDims=*/{isParallelDimScalable});
- acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
+ acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc);
}
// 6. Creates the flattened form of vector.multi_reduction with inner/outer
// most dim as reduction.
- Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, acc, mask, multiReductionOp.getKind());
+ Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, cast, acc, mask, multiReductionOp.getKind());
newMultiDimRedOp =
mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
@@ -339,11 +339,11 @@ struct TwoDimMultiReductionToElementWise
Value result = multiReductionOp.getAcc();
for (int64_t i = 0; i < srcShape[0]; i++) {
- auto operand = rewriter.create<vector::ExtractOp>(
- loc, multiReductionOp.getSource(), i);
+ auto operand = vector::ExtractOp::create(rewriter, loc,
+ multiReductionOp.getSource(), i);
Value extractMask = nullptr;
if (mask) {
- extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
+ extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
}
result =
makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
@@ -383,28 +383,29 @@ struct TwoDimMultiReductionToReduction
}
auto loc = multiReductionOp.getLoc();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, multiReductionOp.getDestType(),
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, multiReductionOp.getDestType(),
rewriter.getZeroAttr(multiReductionOp.getDestType()));
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
for (int i = 0; i < outerDim; ++i) {
- auto v = rewriter.create<vector::ExtractOp>(
- loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
- auto acc = rewriter.create<vector::ExtractOp>(
- loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
- Operation *reductionOp = rewriter.create<vector::ReductionOp>(
- loc, multiReductionOp.getKind(), v, acc);
+ auto v = vector::ExtractOp::create(
+ rewriter, loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
+ auto acc = vector::ExtractOp::create(
+ rewriter, loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
+ Operation *reductionOp = vector::ReductionOp::create(
+ rewriter, loc, multiReductionOp.getKind(), v, acc);
// If masked, slice the mask and mask the new reduction operation.
if (maskableOp.isMasked()) {
- Value mask = rewriter.create<vector::ExtractOp>(
- loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
+ Value mask = vector::ExtractOp::create(
+ rewriter, loc, maskableOp.getMaskingOp().getMask(),
+ ArrayRef<int64_t>{i});
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
}
- result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
- result, i);
+ result = vector::InsertOp::create(rewriter, loc,
+ reductionOp->getResult(0), result, i);
}
rewriter.replaceOp(rootOp, result);
@@ -459,10 +460,10 @@ struct OneDimMultiReductionToTwoDim
SmallVector<bool, 2> reductionMask{false, true};
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
- Value cast = rewriter.create<vector::ShapeCastOp>(
- loc, castedType, multiReductionOp.getSource());
- Value castAcc = rewriter.create<vector::BroadcastOp>(
- loc, accType, multiReductionOp.getAcc());
+ Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
+ multiReductionOp.getSource());
+ Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
+ multiReductionOp.getAcc());
Value castMask;
if (maskableOp.isMasked()) {
auto maskType = llvm::cast<VectorType>(mask.getType());
@@ -470,11 +471,12 @@ struct OneDimMultiReductionToTwoDim
ArrayRef<int64_t>{1, maskType.getShape().back()},
maskType.getElementType(),
ArrayRef<bool>{false, maskType.getScalableDims().back()});
- castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
+ castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
}
- Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
- loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
+ Operation *newOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, cast, castAcc, reductionMask,
+ multiReductionOp.getKind());
newOp = vector::maskOperation(rewriter, newOp, castMask);
rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index 6f3955f522775..af4851eb5f158 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -112,8 +112,8 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
return failure();
VectorType resType = VectorType::get(destShape, elType);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resType,
+ rewriter.getZeroAttr(resType));
int64_t reductionDim = scanOp.getReductionDim();
bool inclusive = scanOp.getInclusive();
int64_t destRank = destType.getRank();
@@ -134,9 +134,9 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
for (int i = 0; i < destShape[reductionDim]; i++) {
offsets[reductionDim] = i;
ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
- Value input = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
- scanStrides);
+ Value input = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, reductionType, scanOp.getSource(), scanOffsets,
+ scanSizes, scanStrides);
Value output;
if (i == 0) {
if (inclusive) {
@@ -144,11 +144,11 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
} else {
if (initialValueRank == 0) {
// ShapeCastOp cannot handle 0-D vectors
- output = rewriter.create<vector::BroadcastOp>(
- loc, input.getType(), scanOp.getInitialValue());
+ output = vector::BroadcastOp::create(rewriter, loc, input.getType(),
+ scanOp.getInitialValue());
} else {
- output = rewriter.create<vector::ShapeCastOp>(
- loc, input.getType(), scanOp.getInitialValue());
+ output = vector::ShapeCastOp::create(rewriter, loc, input.getType(),
+ scanOp.getInitialValue());
}
}
} else {
@@ -156,20 +156,20 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
lastOutput, y);
}
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, output, result, offsets, strides);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, output,
+ result, offsets, strides);
lastOutput = output;
lastInput = input;
}
Value reduction;
if (initialValueRank == 0) {
- Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
+ Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0);
reduction =
- rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
+ vector::BroadcastOp::create(rewriter, loc, initialValueType, v);
} else {
- reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
- lastOutput);
+ reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType,
+ lastOutput);
}
rewriter.replaceOp(scanOp, {result, reduction});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 39c16fab21c4e..603ea41d43360 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -137,11 +137,12 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
const int64_t resultLeading = delta > 0 ? 0 : -delta;
const Value source = shapeCast.getSource();
- const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
- const Value extracted = rewriter.create<vector::ExtractOp>(
- loc, source, SmallVector<int64_t>(sourceLeading, 0));
- const Value result = rewriter.create<vector::InsertOp>(
- loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
+ const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
+ const Value extracted = vector::ExtractOp::create(
+ rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0));
+ const Value result =
+ vector::InsertOp::create(rewriter, loc, extracted, poison,
+ SmallVector<int64_t>(resultLeading, 0));
rewriter.replaceOp(shapeCast, result);
return success();
@@ -171,14 +172,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
SmallVector<int64_t> extractIndex(sourceDim, 0);
SmallVector<int64_t> insertIndex(resultDim, 0);
- Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ Value result = ub::PoisonOp::create(rewriter, loc, resultType);
for (int i = 0; i < nSlices; ++i) {
Value extracted =
- rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+ vector::ExtractOp::create(rewriter, loc, source, extractIndex);
- result = rewriter.create<vector::InsertOp>(loc, extracted, result,
- insertIndex);
+ result = vector::InsertOp::create(rewriter, loc, extracted, result,
+ insertIndex);
inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
@@ -276,9 +277,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
Value extracted = {};
Value extractedStrided = {};
Value insertedSlice = {};
- Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
+ Value result = ub::PoisonOp::create(rewriter, loc, resultType);
const Value partResult =
- rewriter.create<ub::PoisonOp>(loc, insertStridedType);
+ ub::PoisonOp::create(rewriter, loc, insertStridedType);
for (size_t i = 0; i < nAtomicSlices; ++i) {
@@ -288,28 +289,28 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
// vector.extract
if (extractStridedPhase == 0) {
extracted =
- rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
+ vector::ExtractOp::create(rewriter, loc, source, extractIndex);
inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
extractIndex);
}
// vector.extract_strided_slice
extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
- extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, extracted, extractOffsets, atomicShape, sizes);
+ extractedStrided = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, extracted, extractOffsets, atomicShape, sizes);
// vector.insert_strided_slice
if (insertStridedPhase == 0) {
insertedSlice = partResult;
}
insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
- insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
- loc, extractedStrided, insertedSlice, insertOffsets, sizes);
+ insertedSlice = vector::InsertStridedSliceOp::create(
+ rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);
// vector.insert
if (insertStridedPhase + 1 == insertPeriod) {
- result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
- insertIndex);
+ result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
+ insertIndex);
inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
insertIndex);
}
@@ -394,7 +395,7 @@ class ScalableShapeCastOpRewritePattern
auto extractionVectorType = VectorType::get(
{minExtractionSize}, sourceVectorType.getElementType(), {true});
- Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
+ Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank, 0);
@@ -406,16 +407,18 @@ class ScalableShapeCastOpRewritePattern
// 1. Extract a scalable subvector from the source vector.
if (!currentSourceScalableVector) {
if (srcRank != 1) {
- currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
- loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
+ currentSourceScalableVector =
+ vector::ExtractOp::create(rewriter, loc, op.getSource(),
+ llvm::ArrayRef(srcIdx).drop_back());
} else {
currentSourceScalableVector = op.getSource();
}
}
Value sourceSubVector = currentSourceScalableVector;
if (minExtractionSize < minSourceTrailingSize) {
- sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
- loc, extractionVectorType, sourceSubVector, srcIdx.back());
+ sourceSubVector = vector::ScalableExtractOp::create(
+ rewriter, loc, extractionVectorType, sourceSubVector,
+ srcIdx.back());
}
// 2. Insert the scalable subvector into the result vector.
@@ -423,15 +426,16 @@ class ScalableShapeCastOpRewritePattern
if (minExtractionSize == minResultTrailingSize) {
currentResultScalableVector = sourceSubVector;
} else if (resRank != 1) {
- currentResultScalableVector = rewriter.create<vector::ExtractOp>(
- loc, result, llvm::ArrayRef(resIdx).drop_back());
+ currentResultScalableVector = vector::ExtractOp::create(
+ rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back());
} else {
currentResultScalableVector = result;
}
}
if (minExtractionSize < minResultTrailingSize) {
- currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
- loc, sourceSubVector, currentResultScalableVector, resIdx.back());
+ currentResultScalableVector = vector::ScalableInsertOp::create(
+ rewriter, loc, sourceSubVector, currentResultScalableVector,
+ resIdx.back());
}
// 3. Update the source and result scalable vectors if needed.
@@ -439,9 +443,9 @@ class ScalableShapeCastOpRewritePattern
currentResultScalableVector != result) {
// Finished row of result. Insert complete scalable vector into result
// (n-D) vector.
- result = rewriter.create<vector::InsertOp>(
- loc, currentResultScalableVector, result,
- llvm::ArrayRef(resIdx).drop_back());
+ result = vector::InsertOp::create(rewriter, loc,
+ currentResultScalableVector, result,
+ llvm::ArrayRef(resIdx).drop_back());
currentResultScalableVector = {};
}
if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 475528289f01f..6407a868abd85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -629,8 +629,8 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
nextLevelVectorSize);
}
- Value shuffleVal = rewriter.create<vector::ShuffleOp>(
- loc, lhsVector, rhsVector, shuffleMask);
+ Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector,
+ rhsVector, shuffleMask);
levelOutputs.push_back(shuffleVal);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index fb040bc51a993..e9109322ed3d8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -44,7 +44,7 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
originalVecType.getScalableDims().end());
VectorType newVecType = VectorType::get(
newShape, originalVecType.getElementType(), newScalableDims);
- return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
+ return vector::BroadcastOp::create(builder, loc, newVecType, vec);
}
/// Extend the rank of a vector Value by `addedRanks` by adding inner unit
@@ -59,7 +59,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
permutation.push_back(i);
for (int64_t i = 0; i < addedRank; ++i)
permutation.push_back(i);
- return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
+ return vector::TransposeOp::create(builder, loc, broadcasted, permutation);
}
//===----------------------------------------------------------------------===//
@@ -135,8 +135,8 @@ struct TransferReadPermutationLowering
// Generate new transfer_read operation.
VectorType newReadType = VectorType::get(
newVectorShape, op.getVectorType().getElementType(), newScalableDims);
- Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+ Value newRead = vector::TransferReadOp::create(
+ rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
@@ -206,12 +206,12 @@ struct TransferWritePermutationLowering
inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
// Generate new transfer_write operation.
- Value newVec = rewriter.create<vector::TransposeOp>(
- op.getLoc(), op.getVector(), indices);
+ Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(),
+ op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
- auto newWrite = rewriter.create<vector::TransferWriteOp>(
- op.getLoc(), newVec, op.getBase(), op.getIndices(),
+ auto newWrite = vector::TransferWriteOp::create(
+ rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
if (newWrite.hasPureTensorSemantics())
return newWrite.getResult();
@@ -296,8 +296,8 @@ struct TransferWriteNonPermutationLowering
newInBoundsValues.push_back(op.isDimInBounds(i));
}
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
- auto newWrite = rewriter.create<vector::TransferWriteOp>(
- op.getLoc(), newVec, op.getBase(), op.getIndices(),
+ auto newWrite = vector::TransferWriteOp::create(
+ rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
if (newWrite.hasPureTensorSemantics())
return newWrite.getResult();
@@ -367,8 +367,8 @@ struct TransferOpReduceRank
? rewriter.getArrayAttr(
op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
: ArrayAttr();
- Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+ Value newRead = vector::TransferReadOp::create(
+ rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
return rewriter
@@ -468,21 +468,21 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");
- Value fill = rewriter.create<vector::SplatOp>(
- read.getLoc(), unbroadcastedVectorType, read.getPadding());
- res = rewriter.create<vector::MaskedLoadOp>(
- read.getLoc(), unbroadcastedVectorType, read.getBase(),
+ Value fill = vector::SplatOp::create(
+ rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
+ res = vector::MaskedLoadOp::create(
+ rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
read.getIndices(), read.getMask(), fill);
} else {
- res = rewriter.create<vector::LoadOp>(read.getLoc(),
- unbroadcastedVectorType,
- read.getBase(), read.getIndices());
+ res = vector::LoadOp::create(rewriter, read.getLoc(),
+ unbroadcastedVectorType, read.getBase(),
+ read.getIndices());
}
// Insert a broadcasting op if required.
if (!broadcastedDims.empty())
- res = rewriter.create<vector::BroadcastOp>(
- read.getLoc(), read.getVectorType(), res->getResult(0));
+ res = vector::BroadcastOp::create(
+ rewriter, read.getLoc(), read.getVectorType(), res->getResult(0));
return res->getResult(0);
}
@@ -566,12 +566,12 @@ struct TransferWriteToVectorStoreLowering
<< write;
});
- rewriter.create<vector::MaskedStoreOp>(
- write.getLoc(), write.getBase(), write.getIndices(), write.getMask(),
- write.getVector());
+ vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(),
+ write.getIndices(), write.getMask(),
+ write.getVector());
} else {
- rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
- write.getBase(), write.getIndices());
+ vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(),
+ write.getBase(), write.getIndices());
}
// There's no return value for StoreOps. Use Value() to signal success to
// matchAndRewrite.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index bb9a6832146e8..e14f96e7eec59 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -79,8 +79,8 @@ getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
int numBits) {
int numElem = numBits / 32;
- return b.create<vector::ShuffleOp>(
- v1, v2,
+ return vector::ShuffleOp::create(
+ b, v1, v2,
getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
}
@@ -93,8 +93,8 @@ static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
int numBits) {
int numElem = numBits / 32;
- return b.create<vector::ShuffleOp>(
- v1, v2,
+ return vector::ShuffleOp::create(
+ b, v1, v2,
getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
numBits));
}
@@ -108,8 +108,8 @@ static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
int numBits) {
int numElem = numBits / 32;
- auto shuffle = b.create<vector::ShuffleOp>(
- v1, v2,
+ auto shuffle = vector::ShuffleOp::create(
+ b, v1, v2,
getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
return shuffle;
}
@@ -123,8 +123,8 @@ static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
int numBits) {
int numElem = numBits / 32;
- return b.create<vector::ShuffleOp>(
- v1, v2,
+ return vector::ShuffleOp::create(
+ b, v1, v2,
getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
numBits));
}
@@ -180,7 +180,7 @@ static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
appendToMask(0, b23);
appendToMask(16, b45);
appendToMask(16, b67);
- return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+ return vector::ShuffleOp::create(b, v1, v2, shuffleMask);
}
/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
@@ -191,7 +191,7 @@ static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
for (int64_t j = 0; j < n; ++j)
for (int64_t i = 0; i < m; ++i)
mask.push_back(i * n + j);
- return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
+ return vector::ShuffleOp::create(b, source.getLoc(), source, source, mask);
}
/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
@@ -283,9 +283,9 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
auto reshInputType = VectorType::get(
{m, n}, cast<VectorType>(source.getType()).getElementType());
- Value res = b.create<ub::PoisonOp>(reshInputType);
+ Value res = ub::PoisonOp::create(b, reshInputType);
for (int64_t i = 0; i < m; ++i)
- res = b.create<vector::InsertOp>(vs[i], res, i);
+ res = vector::InsertOp::create(b, vs[i], res, i);
return res;
}
@@ -343,7 +343,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// of the leftmost transposed dimensions. We traverse every transpose
// element using a linearized index that we delinearize to generate the
// appropriate indices for the extract/insert operations.
- Value result = rewriter.create<ub::PoisonOp>(loc, resType);
+ Value result = ub::PoisonOp::create(rewriter, loc, resType);
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
@@ -466,14 +466,14 @@ class TransposeOp2DToShuffleLowering
Location loc = op.getLoc();
auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
- auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
- op.getVector());
+ auto reshInput = vector::ShapeCastOp::create(rewriter, loc, flattenedType,
+ op.getVector());
Value res;
if (vectorTransposeLowering == VectorTransposeLowering::Shuffle16x16 &&
m == 16 && n == 16) {
reshInput =
- rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
+ vector::ShapeCastOp::create(rewriter, loc, reshInputType, reshInput);
res = transposeToShuffle16x16(rewriter, reshInput, m, n);
} else {
// Fallback to shuffle on 1D approach.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 436029c31e7f8..58e94ea00189f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -114,7 +114,7 @@ struct DistributedLoadStoreHelper {
"preregistered sequential value.");
// Scalar case can directly use memref.store.
if (!isa<VectorType>(val.getType()))
- return b.create<memref::StoreOp>(loc, val, buffer, zero);
+ return memref::StoreOp::create(b, loc, val, buffer, zero);
// Vector case must use vector::TransferWriteOp which will later lower to
// vector.store of memref.store depending on further lowerings.
@@ -127,8 +127,8 @@ struct DistributedLoadStoreHelper {
}
}
SmallVector<bool> inBounds(indices.size(), true);
- return b.create<vector::TransferWriteOp>(
- loc, val, buffer, indices,
+ return vector::TransferWriteOp::create(
+ b, loc, val, buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
@@ -156,7 +156,7 @@ struct DistributedLoadStoreHelper {
// Scalar case can directly use memref.store.
if (!isa<VectorType>(type))
- return b.create<memref::LoadOp>(loc, buffer, zero);
+ return memref::LoadOp::create(b, loc, buffer, zero);
// Other cases must be vector atm.
// Vector case must use vector::TransferReadOp which will later lower to
@@ -172,8 +172,9 @@ struct DistributedLoadStoreHelper {
}
}
SmallVector<bool> inBounds(indices.size(), true);
- return b.create<vector::TransferReadOp>(
- loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
+ return vector::TransferReadOp::create(
+ b, loc, cast<VectorType>(type), buffer, indices,
+ /*padding=*/std::nullopt,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
@@ -243,11 +244,11 @@ struct WarpOpToScfIfPattern : public WarpDistributionPattern {
rewriter.setInsertionPoint(warpOp);
// Step 1: Create scf.if op.
- Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value isLane0 = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
- auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
- /*withElseRegion=*/false);
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value isLane0 = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
+ auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
+ /*withElseRegion=*/false);
rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
// Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
@@ -325,7 +326,7 @@ struct WarpOpToScfIfPattern : public WarpDistributionPattern {
// Step 7. Delete terminator and add empty scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.setInsertionPointToEnd(ifOp.thenBlock());
- rewriter.create<scf::YieldOp>(yieldLoc);
+ scf::YieldOp::create(rewriter, yieldLoc);
// Compute replacements for WarpOp results.
rewriter.replaceOp(warpOp, replacements);
@@ -512,8 +513,9 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
// Create a second warp op that contains only writeOp.
- auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
- loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
+ auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
+ newWarpOp.getLaneid(),
+ newWarpOp.getWarpSize());
Block &body = secondWarpOp.getBodyRegion().front();
rewriter.setInsertionPointToStart(&body);
auto newWriteOp =
@@ -521,7 +523,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
newWriteOp.getValueToStoreMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
rewriter.eraseOp(writeOp);
- rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
+ gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
return success();
}
@@ -698,7 +700,7 @@ struct WarpOpConstant : public WarpDistributionPattern {
cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
Location loc = warpOp.getLoc();
rewriter.setInsertionPointAfter(warpOp);
- Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
+ Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
rewriter.finalizeOpModification(warpOp);
return success();
@@ -823,9 +825,9 @@ struct WarpOpTransferRead : public WarpDistributionPattern {
Value newMask =
hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
: Value();
- auto newRead = rewriter.create<vector::TransferReadOp>(
- read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
- read.getPermutationMapAttr(), newPadding, newMask,
+ auto newRead = vector::TransferReadOp::create(
+ rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
+ newIndices, read.getPermutationMapAttr(), newPadding, newMask,
read.getInBoundsAttr());
rewriter.replaceAllUsesWith(distributedVal, newRead);
@@ -965,8 +967,8 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- Value broadcasted = rewriter.create<vector::BroadcastOp>(
- loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
+ Value broadcasted = vector::BroadcastOp::create(
+ rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
broadcasted);
return success();
@@ -1008,8 +1010,8 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
- Value newCast = rewriter.create<vector::ShapeCastOp>(
- oldCastOp.getLoc(), castResultType,
+ Value newCast = vector::ShapeCastOp::create(
+ rewriter, oldCastOp.getLoc(), castResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
return success();
@@ -1091,7 +1093,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
}
auto newMask =
- rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
+ vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
rewriter.finalizeOpModification(warpOp);
return success();
@@ -1182,9 +1184,10 @@ struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
// Create a new insert strided slice op that inserts distributed source into
// distributed dest.
- Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
- insertOp.getLoc(), distributedDest.getType(), distributedSource,
- distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+ Value newInsert = vector::InsertStridedSliceOp::create(
+ rewriter, insertOp.getLoc(), distributedDest.getType(),
+ distributedSource, distributedDest, insertOp.getOffsets(),
+ insertOp.getStrides());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
return success();
}
@@ -1277,8 +1280,8 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
// Create a new extract strided slice op that extracts from the
// distributed vector.
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
- Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
- extractOp.getLoc(), distributedType, distributedVec,
+ Value newExtract = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), distributedType, distributedVec,
extractOp.getOffsets(),
ArrayAttr::get(rewriter.getContext(), distributedSizes),
extractOp.getStrides());
@@ -1323,8 +1326,8 @@ struct WarpOpExtract : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
- Value newExtract = rewriter.create<vector::ExtractOp>(
- loc, distributedVec, extractOp.getMixedPosition());
+ Value newExtract = vector::ExtractOp::create(
+ rewriter, loc, distributedVec, extractOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
@@ -1352,8 +1355,8 @@ struct WarpOpExtract : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
- Value newExtract = rewriter.create<vector::ExtractOp>(
- loc, distributedVec, extractOp.getMixedPosition());
+ Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
+ extractOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
@@ -1422,7 +1425,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
Value newExtract;
SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
newExtract =
- rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
+ vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
@@ -1442,11 +1445,11 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
// Extract at position: pos % elementsPerLane
Value newPos =
elementsPerLane == 1
- ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+ ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
: affine::makeComposedAffineApply(rewriter, loc,
sym0 % elementsPerLane, pos);
Value extracted =
- rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
+ vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
// Shuffle the extracted value to all lanes.
Value shuffled = warpShuffleFromIdxFn(
@@ -1514,8 +1517,8 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
if (pos) {
indices.push_back(pos);
}
- newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
- distributedVec, indices);
+ newInsert = vector::InsertOp::create(rewriter, loc, newSource,
+ distributedVec, indices);
// Broadcast: Simply move the vector.insert op out.
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newInsert);
@@ -1531,21 +1534,22 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
// Insert position: pos % elementsPerLane
OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
rewriter, loc, sym0 % elementsPerLane, pos);
- Value isInsertingLane = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+ Value isInsertingLane =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ newWarpOp.getLaneid(), insertingLane);
Value newResult =
rewriter
.create<scf::IfOp>(
loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
- Value newInsert = builder.create<vector::InsertOp>(
- loc, newSource, distributedVec, newPos);
- builder.create<scf::YieldOp>(loc, newInsert);
+ Value newInsert = vector::InsertOp::create(
+ builder, loc, newSource, distributedVec, newPos);
+ scf::YieldOp::create(builder, loc, newInsert);
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, distributedVec);
+ scf::YieldOp::create(builder, loc, distributedVec);
})
.getResult(0);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
@@ -1582,8 +1586,9 @@ struct WarpOpInsert : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
- Value newResult = rewriter.create<vector::InsertOp>(
- loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
+ Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
+ distributedDest,
+ insertOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newResult);
return success();
@@ -1632,27 +1637,29 @@ struct WarpOpInsert : public WarpDistributionPattern {
Value newResult;
if (distrSrcDim >= 0) {
// Every lane inserts a small piece.
- newResult = rewriter.create<vector::InsertOp>(
- loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
+ newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
+ distributedDest,
+ insertOp.getMixedPosition());
} else {
// One lane inserts the entire source vector.
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
SmallVector<int64_t> newPos = getAsIntegers(pos);
// tid of inserting lane: pos / elementsPerLane
- Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
- loc, newPos[distrDestDim] / elementsPerLane);
- Value isInsertingLane = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+ Value insertingLane = arith::ConstantIndexOp::create(
+ rewriter, loc, newPos[distrDestDim] / elementsPerLane);
+ Value isInsertingLane =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+ newWarpOp.getLaneid(), insertingLane);
// Insert position: pos % elementsPerLane
newPos[distrDestDim] %= elementsPerLane;
auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
- Value newInsert = builder.create<vector::InsertOp>(
- loc, distributedSrc, distributedDest, newPos);
- builder.create<scf::YieldOp>(loc, newInsert);
+ Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
+ distributedDest, newPos);
+ scf::YieldOp::create(builder, loc, newInsert);
};
auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, distributedDest);
+ scf::YieldOp::create(builder, loc, distributedDest);
};
newResult = rewriter
.create<scf::IfOp>(loc, isInsertingLane,
@@ -1820,8 +1827,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
- auto newForOp = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ auto newForOp = scf::ForOp::create(
+ rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newForOpOperands);
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
@@ -1845,9 +1852,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueInputTypes[i - escapingValuesStartIdx]);
}
// Create the inner `WarpOp` with the new input values and types.
- auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
- newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
- newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
+ auto innerWarp = WarpExecuteOnLane0Op::create(
+ rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
+ newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
+ innerWarpInputType);
// Inline the `ForOp` body into the inner `WarpOp` body.
SmallVector<Value> argMapping;
@@ -1866,12 +1874,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
// original `ForOp` results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
- rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
+ gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
// Insert a scf.yield op at the end of the new `ForOp` body that yields
// the inner `WarpOp` results.
if (!innerWarp.getResults().empty())
- rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+ scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
// Update the users of original `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 067d4e3491391..73388a5da3e4f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -77,8 +77,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
Location loc = extractOp.getLoc();
- Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, extractOp.getVector(), splatZero(dropCount));
+ Value newSrcVector = vector::ExtractOp::create(
+ rewriter, loc, extractOp.getVector(), splatZero(dropCount));
// The offsets/sizes/strides attribute can have a less number of elements
// than the input vector's rank: it is meant for the leading dimensions.
@@ -89,8 +89,9 @@ struct CastAwayExtractStridedSliceLeadingOneDim
auto newStrides = rewriter.getArrayAttr(
extractOp.getStrides().getValue().drop_front(dropCount));
- auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+ auto newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
+ newStrides);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
newExtractOp);
@@ -120,18 +121,19 @@ struct CastAwayInsertStridedSliceLeadingOneDim
// Trim leading one dimensions from both operands.
Location loc = insertOp.getLoc();
- Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.getValueToStore(), splatZero(srcDropCount));
- Value newDstVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.getDest(), splatZero(dstDropCount));
+ Value newSrcVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+ Value newDstVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
auto newOffsets = rewriter.getArrayAttr(
insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
auto newStrides = rewriter.getArrayAttr(
insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
- auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
- loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+ auto newInsertOp = vector::InsertStridedSliceOp::create(
+ rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
+ newStrides);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
@@ -169,11 +171,11 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
Value newSrcVector = insertOp.getValueToStore();
if (oldSrcRank != 0) {
- newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.getValueToStore(), splatZero(srcDropCount));
+ newSrcVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
}
- Value newDstVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.getDest(), splatZero(dstDropCount));
+ Value newDstVector = vector::ExtractOp::create(
+ rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));
// New position rank needs to be computed in two steps: (1) if destination
// type has leading unit dims, we also trim the position array accordingly,
@@ -187,8 +189,8 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
newPosition.resize(newDstType.getRank() - newSrcRank,
rewriter.getI64IntegerAttr(0));
- auto newInsertOp = rewriter.create<vector::InsertOp>(
- loc, newSrcVector, newDstVector, newPosition);
+ auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
+ newDstVector, newPosition);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
@@ -209,9 +211,9 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
BroadcastableToResult::Success) {
int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
- return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
+ return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
}
- return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
+ return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
}
// Turns vector.transfer_read on vector with leading 1 dimensions into
@@ -259,8 +261,8 @@ struct CastAwayTransferReadLeadingOneDim
newType, newMap, maskType);
}
- auto newRead = rewriter.create<vector::TransferReadOp>(
- read.getLoc(), newType, read.getBase(), read.getIndices(),
+ auto newRead = vector::TransferReadOp::create(
+ rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
@@ -306,8 +308,8 @@ struct CastAwayTransferWriteLeadingOneDim
inBoundsAttr = rewriter.getArrayAttr(
write.getInBoundsAttr().getValue().take_back(newType.getRank()));
- auto newVector = rewriter.create<vector::ExtractOp>(
- write.getLoc(), write.getVector(), splatZero(dropDim));
+ auto newVector = vector::ExtractOp::create(
+ rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));
if (write.getMask()) {
VectorType maskType = write.getMaskType();
@@ -443,22 +445,23 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
contractOp.getContext()));
// Extract if its a valid extraction, otherwise use the operand
// without extraction.
- newOperands.push_back(
- validExtract ? rewriter.create<vector::ExtractOp>(
- loc, operands[it.index()], splatZero(dropDim))
- : operands[it.index()]);
+ newOperands.push_back(validExtract
+ ? vector::ExtractOp::create(rewriter, loc,
+ operands[it.index()],
+ splatZero(dropDim))
+ : operands[it.index()]);
}
// Depending on whether this vector.contract is masked, the replacing Op
// should either be a new vector.contract Op or vector.mask Op.
- Operation *newOp = rewriter.create<vector::ContractionOp>(
- loc, newOperands[0], newOperands[1], newOperands[2],
+ Operation *newOp = vector::ContractionOp::create(
+ rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
if (maskingOp) {
- auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
- splatZero(dropDim));
+ auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
+ splatZero(dropDim));
newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
}
@@ -519,8 +522,8 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
SmallVector<Value, 4> newOperands;
for (Value operand : op->getOperands()) {
if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
- newOperands.push_back(rewriter.create<vector::ExtractOp>(
- op->getLoc(), operand, splatZero(dropDim)));
+ newOperands.push_back(vector::ExtractOp::create(
+ rewriter, op->getLoc(), operand, splatZero(dropDim)));
} else {
newOperands.push_back(operand);
}
@@ -559,8 +562,8 @@ struct CastAwayConstantMaskLeadingOneDim
SmallVector<int64_t> newDimSizes = {flatLeadingSize};
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
- auto newMask = rewriter.create<vector::ConstantMaskOp>(
- mask.getLoc(), newType, newDimSizes);
+ auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
+ newType, newDimSizes);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index 8cc7008d80b3e..cb3e8dc67a1ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -65,26 +65,27 @@ struct VectorMaskedLoadOpConverter final
Value base = maskedLoadOp.getBase();
Value iValue = maskedLoadOp.getPassThru();
auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
- Value one = rewriter.create<arith::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, 1));
+ Value one = arith::ConstantOp::create(rewriter, loc, indexType,
+ IntegerAttr::get(indexType, 1));
for (int64_t i = 0; i < maskLength; ++i) {
- auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+ auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
- auto ifOp = rewriter.create<scf::IfOp>(
- loc, maskBit,
+ auto ifOp = scf::IfOp::create(
+ rewriter, loc, maskBit,
[&](OpBuilder &builder, Location loc) {
auto loadedValue =
- builder.create<memref::LoadOp>(loc, base, indices);
+ memref::LoadOp::create(builder, loc, base, indices);
auto combinedValue =
- builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
- builder.create<scf::YieldOp>(loc, combinedValue.getResult());
+ vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
+ scf::YieldOp::create(builder, loc, combinedValue.getResult());
},
[&](OpBuilder &builder, Location loc) {
- builder.create<scf::YieldOp>(loc, iValue);
+ scf::YieldOp::create(builder, loc, iValue);
});
iValue = ifOp.getResult(0);
- indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+ indices.back() =
+ arith::AddIOp::create(rewriter, loc, indices.back(), one);
}
rewriter.replaceOp(maskedLoadOp, iValue);
@@ -132,18 +133,19 @@ struct VectorMaskedStoreOpConverter final
Value base = maskedStoreOp.getBase();
Value value = maskedStoreOp.getValueToStore();
auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
- Value one = rewriter.create<arith::ConstantOp>(
- loc, indexType, IntegerAttr::get(indexType, 1));
+ Value one = arith::ConstantOp::create(rewriter, loc, indexType,
+ IntegerAttr::get(indexType, 1));
for (int64_t i = 0; i < maskLength; ++i) {
- auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
+ auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
- auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
+ auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
- rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
+ auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
+ memref::StoreOp::create(rewriter, loc, extractedValue, base, indices);
rewriter.setInsertionPointAfter(ifOp);
- indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
+ indices.back() =
+ arith::AddIOp::create(rewriter, loc, indices.back(), one);
}
rewriter.eraseOp(maskedStoreOp);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0fe08417f818f..e6bb96f453fbc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -132,8 +132,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
SmallVector<Value> newMaskOperands(maskOperands.drop_back());
newMaskOperands.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
- return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
- newMaskOperands);
+ return vector::CreateMaskOp::create(rewriter, loc, newMaskType,
+ newMaskOperands);
})
.Case<vector::ConstantMaskOp>(
[&](auto constantMaskOp) -> std::optional<Operation *> {
@@ -143,8 +143,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int64_t &maskIndex = maskDimSizes.back();
maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
numSrcElemsPerDest);
- return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- maskDimSizes);
+ return vector::ConstantMaskOp::create(
+ rewriter, loc, newMaskType, maskDimSizes);
})
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
@@ -182,16 +182,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
}
compressedMaskValues.push_back(combinedValue);
}
- return rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+ return arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(newMaskType, compressedMaskValues));
});
if (!newMask)
return failure();
while (!extractOps.empty()) {
- newMask = rewriter.create<vector::ExtractOp>(
- loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
+ newMask =
+ vector::ExtractOp::create(rewriter, loc, (*newMask)->getResults()[0],
+ extractOps.back().getMixedPosition());
extractOps.pop_back();
}
@@ -258,8 +260,8 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
- return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
- dest, offsets, strides);
+ return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
+ dest, offsets, strides);
}
/// Extracts 1-D subvector from a 1-D vector.
@@ -301,11 +303,12 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
for (int i = 0; i < numElemsToExtract; ++i) {
Value extractLoc =
(i == 0) ? dyn_cast<Value>(offset)
- : rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), dyn_cast<Value>(offset),
- rewriter.create<arith::ConstantIndexOp>(loc, i));
- auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
- dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
+ : arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(),
+ dyn_cast<Value>(offset),
+ arith::ConstantIndexOp::create(rewriter, loc, i));
+ auto extractOp = vector::ExtractOp::create(rewriter, loc, src, extractLoc);
+ dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, i);
}
return dest;
}
@@ -344,13 +347,13 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
for (int64_t i = 0; i < numElemsToInsert; ++i) {
- auto insertLoc = i == 0
- ? destOffsetVal
- : rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), destOffsetVal,
- rewriter.create<arith::ConstantIndexOp>(loc, i));
- auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
- dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+ auto insertLoc =
+ i == 0 ? destOffsetVal
+ : arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), destOffsetVal,
+ arith::ConstantIndexOp::create(rewriter, loc, i));
+ auto extractOp = vector::ExtractOp::create(rewriter, loc, src, i);
+ dest = vector::InsertOp::create(rewriter, loc, extractOp, dest, insertLoc);
}
return dest;
}
@@ -369,11 +372,11 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
Type containerElemTy) {
auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
emulatedElemTy.getIntOrFloatBitWidth();
- auto newLoad = rewriter.create<vector::LoadOp>(
- loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
- return rewriter.create<vector::BitCastOp>(
- loc,
+ auto newLoad = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(numContainerElemsToLoad, containerElemTy),
+ base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ return vector::BitCastOp::create(
+ rewriter, loc,
VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
emulatedElemTy),
newLoad);
@@ -390,16 +393,17 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
"expected input and output number of bits to match");
if (trueValue.getType() != downcastType) {
- trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
+ trueValue =
+ vector::BitCastOp::create(builder, loc, downcastType, trueValue);
}
if (falseValue.getType() != downcastType) {
falseValue =
- builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
+ vector::BitCastOp::create(builder, loc, downcastType, falseValue);
}
Value selectedType =
- builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
+ arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
// Upcast the selected value to the new type.
- return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
+ return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
}
/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
@@ -422,8 +426,8 @@ static void atomicRMW(OpBuilder &builder, Location loc,
// Create an atomic load-modify-write region using
// `memref.generic_atomic_rmw`.
- auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
- loc, linearizedMemref, ValueRange{storeIdx});
+ auto atomicOp = memref::GenericAtomicRMWOp::create(
+ builder, loc, linearizedMemref, ValueRange{storeIdx});
Value origValue = atomicOp.getCurrentValue();
OpBuilder::InsertionGuard guard(builder);
@@ -432,16 +436,16 @@ static void atomicRMW(OpBuilder &builder, Location loc,
// Load the original value from memory, and cast it to the original element
// type.
auto oneElemVecType = VectorType::get({1}, origValue.getType());
- Value origVecValue = builder.create<vector::FromElementsOp>(
- loc, oneElemVecType, ValueRange{origValue});
+ Value origVecValue = vector::FromElementsOp::create(
+ builder, loc, oneElemVecType, ValueRange{origValue});
// Construct the final masked value and yield it.
Value maskedValue =
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
oneElemVecType, mask, valueToStore, origVecValue);
auto scalarMaskedValue =
- builder.create<vector::ExtractOp>(loc, maskedValue, 0);
- builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
+ vector::ExtractOp::create(builder, loc, maskedValue, 0);
+ memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
}
/// Generate a non-atomic read-modify-write sequence for storing to the emulated
@@ -453,16 +457,17 @@ static void nonAtomicRMW(OpBuilder &builder, Location loc,
auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
- Value origVecValue = builder.create<vector::LoadOp>(
- loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
- origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
- origVecValue);
+ Value origVecValue =
+ vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
+ ValueRange{linearizedIndex});
+ origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
+ origVecValue);
Value maskedValue =
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
oneElemVecType, mask, valueToStore, origVecValue);
- builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
- linearizedIndex);
+ vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
+ linearizedIndex);
}
/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
@@ -489,8 +494,9 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
"vector element must be a valid sub-byte type");
auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
- auto emptyByteVector = rewriter.create<arith::ConstantOp>(
- loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
+ auto emptyByteVector = arith::ConstantOp::create(
+ rewriter, loc,
+ VectorType::get({emulatedPerContainerElem}, vectorElementType),
rewriter.getZeroAttr(
VectorType::get({emulatedPerContainerElem}, vectorElementType)));
auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
@@ -602,7 +608,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
// FIXME: ATM, we do not test cases where offsets, sizes, or strides are
// non-zero. As such, this is not needed.
@@ -664,8 +670,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / emulatedPerContainerElem;
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, containerElemTy),
+ auto bitCast = vector::BitCastOp::create(
+ rewriter, loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), memrefBase,
@@ -732,8 +738,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
*foldedNumFrontPadElems, true);
}
- auto frontMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+ auto frontMask = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
auto value =
@@ -751,9 +758,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Increment the destination index by 1 to align to the emulated width
// boundary.
- auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- currentDestIndex = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+ auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ currentDestIndex = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
// 2. Full width store for the inner output bytes.
// After the previous step, the store address is aligned to the emulated
@@ -772,15 +779,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto storeType = VectorType::get(
{originType.getNumElements() / emulatedPerContainerElem},
memrefElemType);
- auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
- fullWidthStorePart);
- rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
- currentDestIndex);
+ auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
+ fullWidthStorePart);
+ vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
+ currentDestIndex);
currentSourceIndex += numNonFullWidthElements;
- currentDestIndex = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), currentDestIndex,
- rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+ currentDestIndex = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), currentDestIndex,
+ arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
}
// 3. Partial width store for the trailing output byte.
@@ -795,8 +802,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Generate back mask.
auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
- auto backMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+ auto backMask = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
storeFunc(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
@@ -848,7 +856,7 @@ struct ConvertVectorMaskedStore final
return failure();
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
OpFoldResult linearizedIndicesOfr;
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
@@ -901,21 +909,21 @@ struct ConvertVectorMaskedStore final
auto numElements = (origElements + emulatedPerContainerElem - 1) /
emulatedPerContainerElem;
auto newType = VectorType::get(numElements, containerElemTy);
- auto passThru = rewriter.create<arith::ConstantOp>(
- loc, newType, rewriter.getZeroAttr(newType));
+ auto passThru = arith::ConstantOp::create(rewriter, loc, newType,
+ rewriter.getZeroAttr(newType));
- auto newLoad = rewriter.create<vector::MaskedLoadOp>(
- loc, newType, adaptor.getBase(), linearizedIndices,
+ auto newLoad = vector::MaskedLoadOp::create(
+ rewriter, loc, newType, adaptor.getBase(), linearizedIndices,
newMask.value()->getResult(0), passThru);
auto newBitCastType =
VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
Value valueToStore =
- rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
- valueToStore = rewriter.create<arith::SelectOp>(
- loc, op.getMask(), op.getValueToStore(), valueToStore);
+ vector::BitCastOp::create(rewriter, loc, newBitCastType, newLoad);
+ valueToStore = arith::SelectOp::create(rewriter, loc, op.getMask(),
+ op.getValueToStore(), valueToStore);
valueToStore =
- rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
+ vector::BitCastOp::create(rewriter, loc, newType, valueToStore);
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
@@ -990,7 +998,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
OpFoldResult linearizedIndices;
memref::LinearizedMemRefInfo linearizedInfo;
@@ -1016,8 +1024,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
numElements, emulatedElemTy, containerElemTy);
if (!foldedIntraVectorOffset) {
- auto resultVector = rewriter.create<arith::ConstantOp>(
- loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ auto resultVector = arith::ConstantOp::create(
+ rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
@@ -1111,7 +1119,7 @@ struct ConvertVectorMaskedLoad final
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
OpFoldResult linearizedIndices;
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
@@ -1142,8 +1150,8 @@ struct ConvertVectorMaskedLoad final
auto newBitcastType =
VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
- auto emptyVector = rewriter.create<arith::ConstantOp>(
- loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+ auto emptyVector = arith::ConstantOp::create(
+ rewriter, loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
@@ -1153,25 +1161,26 @@ struct ConvertVectorMaskedLoad final
*foldedIntraVectorOffset);
}
auto newPassThru =
- rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
+ vector::BitCastOp::create(rewriter, loc, loadType, passthru);
// Generating the new masked load.
- auto newLoad = rewriter.create<vector::MaskedLoadOp>(
- loc, loadType, adaptor.getBase(),
+ auto newLoad = vector::MaskedLoadOp::create(
+ rewriter, loc, loadType, adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
newMask.value()->getResult(0), newPassThru);
// Setting the part that originally was not effectively loaded from memory
// to pass through.
auto bitCast =
- rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
+ vector::BitCastOp::create(rewriter, loc, newBitcastType, newLoad);
Value mask = op.getMask();
auto newSelectMaskType = VectorType::get(
numElements * emulatedPerContainerElem, rewriter.getI1Type());
// TODO: try to fold if op's mask is constant
- auto emptyMask = rewriter.create<arith::ConstantOp>(
- loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+ auto emptyMask =
+ arith::ConstantOp::create(rewriter, loc, newSelectMaskType,
+ rewriter.getZeroAttr(newSelectMaskType));
if (!foldedIntraVectorOffset) {
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
@@ -1182,7 +1191,7 @@ struct ConvertVectorMaskedLoad final
}
Value result =
- rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
+ arith::SelectOp::create(rewriter, loc, mask, bitCast, passthru);
if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
@@ -1272,17 +1281,17 @@ struct ConvertVectorTransferRead final
// thus their values don't matter.
Value padding = adaptor.getPadding();
if (!padding.getType().isInteger()) {
- padding = rewriter.create<arith::BitcastOp>(
- loc,
+ padding = arith::BitcastOp::create(
+ rewriter, loc,
IntegerType::get(rewriter.getContext(),
padding.getType().getIntOrFloatBitWidth()),
padding);
}
auto newPadding =
- rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);
+ arith::ExtUIOp::create(rewriter, loc, containerElemTy, padding);
auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
OpFoldResult linearizedIndices;
memref::LinearizedMemRefInfo linearizedInfo;
@@ -1303,20 +1312,21 @@ struct ConvertVectorTransferRead final
auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
emulatedPerContainerElem);
- auto newRead = rewriter.create<vector::TransferReadOp>(
- loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
+ auto newRead = vector::TransferReadOp::create(
+ rewriter, loc, VectorType::get(numElements, containerElemTy),
+ adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
newPadding);
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc,
+ auto bitCast = vector::BitCastOp::create(
+ rewriter, loc,
VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
newRead);
Value result = bitCast->getResult(0);
if (!foldedIntraVectorOffset) {
- auto zeros = rewriter.create<arith::ConstantOp>(
- loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ auto zeros = arith::ConstantOp::create(
+ rewriter, loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
@@ -1689,32 +1699,33 @@ Value BitCastRewriter::genericRewriteStep(
PatternRewriter &rewriter, Location loc, Value initialValue,
Value runningResult, const BitCastRewriter::Metadata &metadata) {
// Create vector.shuffle from the metadata.
- auto shuffleOp = rewriter.create<vector::ShuffleOp>(
- loc, initialValue, initialValue, metadata.shuffles);
+ auto shuffleOp = vector::ShuffleOp::create(rewriter, loc, initialValue,
+ initialValue, metadata.shuffles);
// Intersect with the mask.
VectorType shuffledVectorType = shuffleOp.getResultVectorType();
- auto constOp = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
- Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+ auto constOp = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+ Value andValue = arith::AndIOp::create(rewriter, loc, shuffleOp, constOp);
// Align right on 0.
- auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
- loc,
+ auto shiftRightConstantOp = arith::ConstantOp::create(
+ rewriter, loc,
DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
Value shiftedRight =
- rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+ arith::ShRUIOp::create(rewriter, loc, andValue, shiftRightConstantOp);
// Shift bits left into their final position.
- auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
- loc,
+ auto shiftLeftConstantOp = arith::ConstantOp::create(
+ rewriter, loc,
DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
Value shiftedLeft =
- rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+ arith::ShLIOp::create(rewriter, loc, shiftedRight, shiftLeftConstantOp);
runningResult =
runningResult
- ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
+ ? arith::OrIOp::create(rewriter, loc, runningResult, shiftedLeft)
: shiftedLeft;
return runningResult;
@@ -1737,7 +1748,7 @@ static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
// Adjust last dimension of the vector, so the total size remains the same.
vecShape.back() = vecShape.back() / numSrcElemsPerByte;
auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
- return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
+ return vector::BitCastOp::create(rewriter, loc, i8VecType, subByteVec);
}
/// Extracts a signed N-bit sequence from each element of a vector of bytes,
@@ -1765,15 +1776,15 @@ static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
"Invalid bitIdx range");
if (bitsToShiftLeft != 0) {
- Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
- shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
+ Value shiftLeftValues = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
+ shl = arith::ShLIOp::create(rewriter, loc, src, shiftLeftValues);
}
int8_t bitsToShiftRight = 8 - numBits;
- Value shiftRightValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
- Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
+ Value shiftRightValues = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+ Value shr = arith::ShRSIOp::create(rewriter, loc, shl, shiftRightValues);
return shr;
}
@@ -1807,17 +1818,17 @@ static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
int8_t bitsToShiftRight = bitIdx;
Value shr = src;
if (bitsToShiftRight != 0) {
- Value shiftRightValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
- shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
+ Value shiftRightValues = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+ shr = arith::ShRUIOp::create(rewriter, loc, src, shiftRightValues);
}
if (bitIdx + numBits == 8) {
return shr;
}
uint8_t lowBitsMask = (1 << numBits) - 1;
- Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(srcType, lowBitsMask));
- return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
+ Value lowBitsMaskValues = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(srcType, lowBitsMask));
+ return arith::AndIOp::create(rewriter, loc, shr, lowBitsMaskValues);
}
using ExtractNBitsFn =
@@ -1840,7 +1851,7 @@ static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
Value high = extFn(rewriter, loc, i8Vector, 4, 4);
// 3. Interleave low and high i8 elements.
- return rewriter.create<vector::InterleaveOp>(loc, low, high);
+ return vector::InterleaveOp::create(rewriter, loc, low, high);
}
/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
@@ -1873,9 +1884,10 @@ static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
// 02 = [0,2,0,2,0,2,0,2],...
// 13 = [1,3,1,3,1,3,1,3],...
// 0213 = [0,1,2,3,...],...
- Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
- Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
- return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+ Value interleave02 = vector::InterleaveOp::create(rewriter, loc, vec0, vec2);
+ Value interleave13 = vector::InterleaveOp::create(rewriter, loc, vec1, vec3);
+ return vector::InterleaveOp::create(rewriter, loc, interleave02,
+ interleave13);
}
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
@@ -1887,29 +1899,29 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
"Expected i8 type");
// 1. De-interleave low and high i8 elements.
- auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);
+ auto deinterleaveOp = vector::DeinterleaveOp::create(rewriter, loc, srcValue);
// 2. Zero out the upper side of each low i8 element.
constexpr int8_t i8LowBitMask = 0x0F;
VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
- Value zeroOutMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
- Value zeroOutLow = rewriter.create<arith::AndIOp>(
- loc, deinterleaveOp.getRes1(), zeroOutMask);
+ Value zeroOutMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
+ Value zeroOutLow = arith::AndIOp::create(
+ rewriter, loc, deinterleaveOp.getRes1(), zeroOutMask);
// 3. Move high i4 values to upper side of the byte.
constexpr int8_t bitsToShift = 4;
- auto shiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
- Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
- shiftValues);
+ auto shiftValues = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+ Value shlHigh = arith::ShLIOp::create(rewriter, loc, deinterleaveOp.getRes2(),
+ shiftValues);
// 4. Merge high and low i4 values.
- auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+ auto mergedHiLowOp = arith::OrIOp::create(rewriter, loc, zeroOutLow, shlHigh);
// 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
- return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+ return vector::BitCastOp::create(rewriter, loc, i4VecType, mergedHiLowOp);
}
namespace {
@@ -2151,7 +2163,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
Location loc = truncOp.getLoc();
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
Value i8TruncVal =
- rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+ arith::TruncIOp::create(rewriter, loc, i8VecType, srcValue);
// Rewrite the i8 -> i4 truncation part.
Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
@@ -2199,10 +2211,10 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
// support is available.
auto srcNativeVecType = srcSubByteVecType.cloneWith(
std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
- Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
- transposeOp.getVector());
- Value newTranspose = rewriter.create<vector::TransposeOp>(
- loc, extOp, transposeOp.getPermutation());
+ Value extOp = arith::ExtSIOp::create(rewriter, loc, srcNativeVecType,
+ transposeOp.getVector());
+ Value newTranspose = vector::TransposeOp::create(
+ rewriter, loc, extOp, transposeOp.getPermutation());
VectorType dstSubByteVecType = transposeOp.getResultVectorType();
rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
newTranspose);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index d834a99076834..72352d72bfe77 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -53,15 +53,15 @@ class DecomposeDifferentRankInsertStridedSlice
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
// on it.
- Value extracted = rewriter.create<ExtractOp>(
- loc, op.getDest(),
- getI64SubArray(op.getOffsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
+ Value extracted =
+ ExtractOp::create(rewriter, loc, op.getDest(),
+ getI64SubArray(op.getOffsets(), /*dropFront=*/0,
+ /*dropBack=*/rankRest));
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
- auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
- loc, op.getValueToStore(), extracted,
+ auto stridedSliceInnerOp = InsertStridedSliceOp::create(
+ rewriter, loc, op.getValueToStore(), extracted,
getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
getI64SubArray(op.getStrides(), /*dropFront=*/0));
@@ -131,8 +131,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
SmallVector<int64_t> offsets(nDest, 0);
for (int64_t i = 0; i < nSrc; ++i)
offsets[i] = i;
- Value scaledSource = rewriter.create<ShuffleOp>(
- loc, op.getValueToStore(), op.getValueToStore(), offsets);
+ Value scaledSource = ShuffleOp::create(
+ rewriter, loc, op.getValueToStore(), op.getValueToStore(), offsets);
// 2. Create a mask where we take the value from scaledSource of dest
// depending on the offset.
@@ -156,21 +156,21 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource =
- rewriter.create<ExtractOp>(loc, op.getValueToStore(), idx);
+ ExtractOp::create(rewriter, loc, op.getValueToStore(), idx);
if (isa<VectorType>(extractedSource.getType())) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
Value extractedDest =
- rewriter.create<ExtractOp>(loc, op.getDest(), off);
+ ExtractOp::create(rewriter, loc, op.getDest(), off);
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
- extractedSource = rewriter.create<InsertStridedSliceOp>(
- loc, extractedSource, extractedDest,
+ extractedSource = InsertStridedSliceOp::create(
+ rewriter, loc, extractedSource, extractedDest,
getI64SubArray(op.getOffsets(), /* dropFront=*/1),
getI64SubArray(op.getStrides(), /* dropFront=*/1));
}
// 4. Insert the extractedSource into the res vector.
- res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
+ res = InsertOp::create(rewriter, loc, extractedSource, res, off);
}
rewriter.replaceOp(op, res);
@@ -250,12 +250,12 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
SmallVector<Value> elements;
elements.reserve(size);
for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
- elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
+ elements.push_back(ExtractOp::create(rewriter, loc, op.getVector(), i));
- Value result = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(op.getType()));
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(op.getType()));
for (int64_t i = 0; i < size; ++i)
- result = rewriter.create<InsertOp>(loc, elements[i], result, i);
+ result = InsertOp::create(rewriter, loc, elements[i], result, i);
rewriter.replaceOp(op, result);
return success();
@@ -301,17 +301,17 @@ class DecomposeNDExtractStridedSlice
return failure();
// Extract/insert on a lower ranked extract strided slice op.
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elemType, rewriter.getZeroAttr(elemType));
- Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value res = SplatOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
- Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
- Value extracted = rewriter.create<ExtractStridedSliceOp>(
- loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
+ Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
+ Value extracted = ExtractStridedSliceOp::create(
+ rewriter, loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
getI64SubArray(op.getSizes(), /* dropFront=*/1),
getI64SubArray(op.getStrides(), /* dropFront=*/1));
- res = rewriter.create<InsertOp>(loc, extracted, res, idx);
+ res = InsertOp::create(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, res);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index fe17b3c0b2cfc..491b448e9e1e9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -658,7 +658,7 @@ struct LinearizeVectorCreateMask final
// The result of the comparison is then multiplied with
// the second operand of create_mask to get the 1D mask.
auto firstOperand = adaptor.getOperands().front();
- auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+ auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
auto isNonZero = rewriter.createOrFold<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
auto isNonZeroIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
@@ -668,7 +668,7 @@ struct LinearizeVectorCreateMask final
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
auto newMask =
- rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+ mlir::vector::CreateMaskOp::create(rewriter, loc, dstTy, maskSize);
rewriter.replaceOp(createMaskOp, newMask);
return success();
}
@@ -710,8 +710,9 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
auto linearTy = typeConverter->convertType<VectorType>(vecTy);
- auto newLoad = rewriter.create<vector::LoadOp>(
- loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto newLoad =
+ vector::LoadOp::create(rewriter, loadOp.getLoc(), linearTy,
+ adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOp(loadOp, newLoad.getResult());
return success();
}
@@ -832,7 +833,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
return nullptr;
- return builder.create<vector::ShapeCastOp>(loc, type, value);
+ return vector::ShapeCastOp::create(builder, loc, type, value);
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
index a7403250a069b..8a181a429e41c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -82,8 +82,8 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
// Replace createMaskOp with an all-true constant. This should result in the
// mask being removed in most cases (as xfer ops + vector.mask have folds to
// remove all-true masks).
- auto allTrue = rewriter.create<vector::ConstantMaskOp>(
- createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
+ auto allTrue = vector::ConstantMaskOp::create(
+ rewriter, createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
rewriter.replaceAllUsesWith(createMaskOp, allTrue);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c20a1b355996c..2676d254c9b64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -286,8 +286,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
if (resultType.canonicalizeStridedLayout() ==
inputType.canonicalizeStridedLayout())
return input;
- return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
- sizes, strides);
+ return memref::SubViewOp::create(rewriter, loc, resultType, input, offsets,
+ sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
@@ -395,13 +395,13 @@ class TransferReadDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
- Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
- transferReadOp.getPadding(), maskOp,
+ Operation *newTransferReadOp = vector::TransferReadOp::create(
+ rewriter, loc, reducedVectorType, reducedShapeSource, zeros,
+ identityMap, transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));
if (maskingOp) {
@@ -477,15 +477,15 @@ class TransferWriteDropUnitDimsPattern
}
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
- Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
- Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
- loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
- maskOp, rewriter.getBoolArrayAttr(inBounds));
+ Operation *newXferWrite = vector::TransferWriteOp::create(
+ rewriter, loc, Type(), shapeCastSrc, reducedShapeSource, zeros,
+ identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -520,7 +520,7 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
collapsedIndices.push_back(i);
reassociation.push_back(collapsedIndices);
- return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
+ return memref::CollapseShapeOp::create(rewriter, loc, input, reassociation);
}
/// Returns the new indices that collapses the inner dimensions starting from
@@ -559,7 +559,7 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
// one would get the following offset:
// %offset = %arg0 * 43
OpFoldResult collapsedOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+ arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
auto collapsedStrides = computeSuffixProduct(
ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
@@ -573,8 +573,8 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
if (auto value = dyn_cast<Value>(collapsedOffset)) {
indicesAfterCollapsing.push_back(value);
} else {
- indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, *getConstantIntValue(collapsedOffset)));
+ indicesAfterCollapsing.push_back(arith::ConstantIndexOp::create(
+ rewriter, loc, *getConstantIntValue(collapsedOffset)));
}
return indicesAfterCollapsing;
@@ -659,8 +659,8 @@ class FlattenContiguousRowMajorTransferReadPattern
// 3. Create new vector.transfer_read that reads from the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
- vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
- loc, flatVectorType, collapsedSource, collapsedIndices,
+ vector::TransferReadOp flatRead = vector::TransferReadOp::create(
+ rewriter, loc, flatVectorType, collapsedSource, collapsedIndices,
transferReadOp.getPadding(), collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
@@ -757,10 +757,10 @@ class FlattenContiguousRowMajorTransferWritePattern
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
- rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
- vector::TransferWriteOp flatWrite =
- rewriter.create<vector::TransferWriteOp>(
- loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
+ vector::ShapeCastOp::create(rewriter, loc, flatVectorType, vector);
+ vector::TransferWriteOp flatWrite = vector::TransferWriteOp::create(
+ rewriter, loc, flatVector, collapsedSource, collapsedIndices,
+ collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
// 4. Replace the old transfer_write with the new one writing the
@@ -846,8 +846,8 @@ class RewriteScalarExtractOfTransferRead
if (auto value = dyn_cast<Value>(composedIdx)) {
newIndices[idx] = value;
} else {
- newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
- extractOp.getLoc(), *getConstantIntValue(composedIdx));
+ newIndices[idx] = arith::ConstantIndexOp::create(
+ rewriter, extractOp.getLoc(), *getConstantIntValue(composedIdx));
}
}
if (isa<MemRefType>(xferOp.getBase().getType())) {
@@ -883,8 +883,8 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
// Only float and integer element types are supported.
- Value scalar =
- rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
+ Value scalar = vector::ExtractOp::create(rewriter, xferOp.getLoc(),
+ xferOp.getVector());
// Construct a scalar store.
if (isa<MemRefType>(xferOp.getBase().getType())) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index eee090d495c17..05b00744beea2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -57,12 +57,12 @@ static Value createInBoundsCond(RewriterBase &b,
if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
return;
Value cond =
- b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
- getValueOrCreateConstantIndexOp(b, loc, sum),
- getValueOrCreateConstantIndexOp(b, loc, dimSz));
+ arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
+ getValueOrCreateConstantIndexOp(b, loc, sum),
+ getValueOrCreateConstantIndexOp(b, loc, dimSz));
// Conjunction over all dims for which we are in-bounds.
if (inBoundsCond)
- inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
+ inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
else
inBoundsCond = cond;
});
@@ -170,11 +170,12 @@ static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
sourceType = MemRefType::get(
sourceType.getShape(), sourceType.getElementType(),
sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
- res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
+ res =
+ memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType, res);
}
if (sourceType == compatibleMemRefType)
return res;
- return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
+ return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
}
/// Operates under a scoped context to build the intersection between the
@@ -196,16 +197,17 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
Value dimMemRef =
- b.create<memref::DimOp>(xferOp.getLoc(), xferOp.getBase(), indicesIdx);
- Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
+ memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx);
+ Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
Value index = xferOp.getIndices()[indicesIdx];
AffineExpr i, j, k;
bindDims(xferOp.getContext(), i, j, k);
SmallVector<AffineMap, 4> maps =
AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
// affine_min(%dimMemRef - %index, %dimAlloc)
- Value affineMin = b.create<affine::AffineMinOp>(
- loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
+ Value affineMin =
+ affine::AffineMinOp::create(b, loc, index.getType(), maps[0],
+ ValueRange{dimMemRef, index, dimAlloc});
sizes.push_back(affineMin);
});
@@ -213,10 +215,10 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
- auto copySrc = b.create<memref::SubViewOp>(
- loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
- auto copyDest = b.create<memref::SubViewOp>(
- loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
+ auto copySrc = memref::SubViewOp::create(
+ b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
+ auto copyDest = memref::SubViewOp::create(
+ b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
return std::make_pair(copySrc, copyDest);
}
@@ -244,32 +246,32 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
TypeRange returnTypes, Value inBoundsCond,
MemRefType compatibleMemRefType, Value alloc) {
Location loc = xferOp.getLoc();
- Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(b, loc, 0);
Value memref = xferOp.getBase();
- return b.create<scf::IfOp>(
- loc, inBoundsCond,
+ return scf::IfOp::create(
+ b, loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
llvm::append_range(viewAndIndices, xferOp.getIndices());
- b.create<scf::YieldOp>(loc, viewAndIndices);
+ scf::YieldOp::create(b, loc, viewAndIndices);
},
[&](OpBuilder &b, Location loc) {
- b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
- ValueRange{alloc});
+ linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
+ ValueRange{alloc});
// Take partial subview of memref which guarantees no dimension
// overflows.
IRRewriter rewriter(b);
std::pair<Value, Value> copyArgs = createSubViewIntersection(
rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
alloc);
- b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
+ memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
Value casted =
castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
- b.create<scf::YieldOp>(loc, viewAndIndices);
+ scf::YieldOp::create(b, loc, viewAndIndices);
});
}
@@ -297,30 +299,30 @@ static scf::IfOp createFullPartialVectorTransferRead(
Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
Location loc = xferOp.getLoc();
scf::IfOp fullPartialIfOp;
- Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(b, loc, 0);
Value memref = xferOp.getBase();
- return b.create<scf::IfOp>(
- loc, inBoundsCond,
+ return scf::IfOp::create(
+ b, loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
llvm::append_range(viewAndIndices, xferOp.getIndices());
- b.create<scf::YieldOp>(loc, viewAndIndices);
+ scf::YieldOp::create(b, loc, viewAndIndices);
},
[&](OpBuilder &b, Location loc) {
Operation *newXfer = b.clone(*xferOp.getOperation());
Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
- b.create<memref::StoreOp>(
- loc, vector,
- b.create<vector::TypeCastOp>(
- loc, MemRefType::get({}, vector.getType()), alloc));
+ memref::StoreOp::create(
+ b, loc, vector,
+ vector::TypeCastOp::create(
+ b, loc, MemRefType::get({}, vector.getType()), alloc));
Value casted =
castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
- b.create<scf::YieldOp>(loc, viewAndIndices);
+ scf::YieldOp::create(b, loc, viewAndIndices);
});
}
@@ -344,7 +346,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
TypeRange returnTypes, Value inBoundsCond,
MemRefType compatibleMemRefType, Value alloc) {
Location loc = xferOp.getLoc();
- Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value zero = arith::ConstantIndexOp::create(b, loc, 0);
Value memref = xferOp.getBase();
return b
.create<scf::IfOp>(
@@ -354,7 +356,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
castToCompatibleMemRefType(b, memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
llvm::append_range(viewAndIndices, xferOp.getIndices());
- b.create<scf::YieldOp>(loc, viewAndIndices);
+ scf::YieldOp::create(b, loc, viewAndIndices);
},
[&](OpBuilder &b, Location loc) {
Value casted =
@@ -362,7 +364,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.getTransferRank(), zero);
- b.create<scf::YieldOp>(loc, viewAndIndices);
+ scf::YieldOp::create(b, loc, viewAndIndices);
})
->getResults();
}
@@ -384,15 +386,15 @@ static void createFullPartialLinalgCopy(RewriterBase &b,
vector::TransferWriteOp xferOp,
Value inBoundsCond, Value alloc) {
Location loc = xferOp.getLoc();
- auto notInBounds = b.create<arith::XOrIOp>(
- loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
- b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+ auto notInBounds = arith::XOrIOp::create(
+ b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
+ scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
IRRewriter rewriter(b);
std::pair<Value, Value> copyArgs = createSubViewIntersection(
rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
alloc);
- b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
- b.create<scf::YieldOp>(loc, ValueRange{});
+ memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
+ scf::YieldOp::create(b, loc, ValueRange{});
});
}
@@ -413,18 +415,18 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b,
Value inBoundsCond,
Value alloc) {
Location loc = xferOp.getLoc();
- auto notInBounds = b.create<arith::XOrIOp>(
- loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
- b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+ auto notInBounds = arith::XOrIOp::create(
+ b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
+ scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
IRMapping mapping;
- Value load = b.create<memref::LoadOp>(
- loc,
- b.create<vector::TypeCastOp>(
- loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
+ Value load = memref::LoadOp::create(
+ b, loc,
+ vector::TypeCastOp::create(
+ b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
ValueRange());
mapping.map(xferOp.getVector(), load);
b.clone(*xferOp.getOperation(), mapping);
- b.create<scf::YieldOp>(loc, ValueRange{});
+ scf::YieldOp::create(b, loc, ValueRange{});
});
}
@@ -554,9 +556,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
b.setInsertionPointToStart(&scope->getRegion(0).front());
auto shape = xferOp.getVectorType().getShape();
Type elementType = xferOp.getVectorType().getElementType();
- alloc = b.create<memref::AllocaOp>(scope->getLoc(),
- MemRefType::get(shape, elementType),
- ValueRange{}, b.getI64IntegerAttr(32));
+ alloc = memref::AllocaOp::create(b, scope->getLoc(),
+ MemRefType::get(shape, elementType),
+ ValueRange{}, b.getI64IntegerAttr(32));
}
MemRefType compatibleMemRefType =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index fe2707629d82e..73ca327bb49c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -381,8 +381,8 @@ FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
return failure();
- Operation *newOp = rewriter.create<vector::ContractionOp>(
- contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
+ Operation *newOp = vector::ContractionOp::create(
+ rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
// Handle the mask.
@@ -534,8 +534,8 @@ struct ReorderElementwiseOpsOnTranspose final
// This is a constant. Create a reverse transpose op for it.
auto vectorType =
srcType.clone(cast<VectorType>(operand.getType()).getElementType());
- srcValues.push_back(rewriter.create<vector::TransposeOp>(
- operand.getLoc(), vectorType, operand, invOrder));
+ srcValues.push_back(vector::TransposeOp::create(
+ rewriter, operand.getLoc(), vectorType, operand, invOrder));
}
}
@@ -608,20 +608,20 @@ struct BubbleDownVectorBitCastForExtract
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
Location loc = extractOp.getLoc();
- Value packedValue = rewriter.create<vector::ExtractOp>(
- loc, castOp.getSource(), index / expandRatio);
+ Value packedValue = vector::ExtractOp::create(
+ rewriter, loc, castOp.getSource(), index / expandRatio);
Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, packedVecType, rewriter.getZeroAttr(packedVecType));
- packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
- /*position=*/0);
+ Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
+ rewriter.getZeroAttr(packedVecType));
+ packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
+ /*position=*/0);
// Cast it to a vector with the desired scalar's type.
// E.g. f32 -> vector<2xf16>
VectorType packedType =
VectorType::get({expandRatio}, castDstType.getElementType());
Value castedValue =
- rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
+ vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
// Finally extract the desired scalar.
rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
@@ -700,9 +700,9 @@ struct BubbleDownBitCastForStridedSliceExtract
VectorType newExtractType =
VectorType::get(dims, castSrcType.getElementType());
- auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
- extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
- newSizes, extractOp.getStrides());
+ auto newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
+ newOffsets, newSizes, extractOp.getStrides());
rewriter.replaceOpWithNewOp<vector::BitCastOp>(
extractOp, extractOp.getType(), newExtractOp);
@@ -761,8 +761,9 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
- auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
+ auto newCastSrcOp =
+ vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
+ insertOp.getValueToStore());
SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
dstDims.back() =
@@ -771,8 +772,8 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
VectorType::get(dstDims, castDstType.getElementType());
// Bitcast the destination.
- auto newCastDstOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
+ auto newCastDstOp = vector::BitCastOp::create(
+ rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
// Generate new insert.
rewriter.replaceOpWithNewOp<vector::InsertOp>(
@@ -852,8 +853,9 @@ struct BubbleUpBitCastForStridedSliceInsert
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
- auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
+ auto newCastSrcOp =
+ vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
+ insertOp.getValueToStore());
SmallVector<int64_t> dstDims =
llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
@@ -861,8 +863,8 @@ struct BubbleUpBitCastForStridedSliceInsert
VectorType newCastDstType =
VectorType::get(dstDims, castDstType.getElementType());
- auto newCastDstOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
+ auto newCastDstOp = vector::BitCastOp::create(
+ rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
@@ -936,9 +938,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
Type elemType = castDstType.getElementType();
assert(elemType.isSignlessIntOrIndexOrFloat());
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elemType, rewriter.getZeroAttr(elemType));
- Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
+ Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value res = SplatOp::create(rewriter, loc, castDstType, zero);
SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
@@ -947,13 +949,13 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
castDstType.getElementType());
for (int i = 0, e = shrinkRatio; i < e; ++i) {
- Value extracted = rewriter.create<ExtractStridedSliceOp>(
- loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
- sliceShape, strides);
+ Value extracted = ExtractStridedSliceOp::create(
+ rewriter, loc, bitcastOp.getSource(),
+ ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
Value bitcast =
- rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
- res = rewriter.create<InsertStridedSliceOp>(
- loc, bitcast, res,
+ BitCastOp::create(rewriter, loc, newCastDstType, extracted);
+ res = InsertStridedSliceOp::create(
+ rewriter, loc, bitcast, res,
ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
}
rewriter.replaceOp(bitcastOp, res);
@@ -1103,7 +1105,7 @@ class ExtractOpFromElementwise final
Location loc = eltwise->getLoc();
SmallVector<OpFoldResult> pos = op.getMixedPosition();
for (Value arg : eltwise->getOperands()) {
- Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
+ Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
mapping.map(arg, newArg);
}
@@ -1292,19 +1294,19 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
indicesAttr = rewriter.getI64VectorAttr(
llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
}
- Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
+ Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
- indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
+ Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+ indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
- rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
- return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
- bounds);
+ vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
+ indices, bounds);
}
template <typename ConcreteOp>
@@ -1335,15 +1337,15 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
Value off = xferOp.getIndices()[lastIndex];
Value dim =
vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
- Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
- Value mask = rewriter.create<vector::CreateMaskOp>(
- loc,
+ Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
+ Value mask = vector::CreateMaskOp::create(
+ rewriter, loc,
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
vtp.getScalableDims()),
b);
if (xferOp.getMask()) {
// Intersect the in-bounds with the mask specified as an op parameter.
- mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
+ mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
}
rewriter.modifyOpInPlace(xferOp, [&]() {
@@ -1548,12 +1550,13 @@ class DropInnerMostUnitDimsTransferRead
strides);
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
- Value rankedReducedView = rewriter.create<memref::SubViewOp>(
- loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides);
+ Value rankedReducedView =
+ memref::SubViewOp::create(rewriter, loc, resultMemrefType,
+ readOp.getBase(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
- Value result = rewriter.create<vector::TransferReadOp>(
- loc, resultTargetVecType, rankedReducedView,
+ Value result = vector::TransferReadOp::create(
+ rewriter, loc, resultTargetVecType, rankedReducedView,
readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
readOp.getPadding(),
// TODO: support mask.
@@ -1639,8 +1642,9 @@ class DropInnerMostUnitDimsTransferWrite
ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
- Value rankedReducedView = rewriter.create<memref::SubViewOp>(
- loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides);
+ Value rankedReducedView =
+ memref::SubViewOp::create(rewriter, loc, resultMemrefType,
+ writeOp.getBase(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
@@ -1708,21 +1712,21 @@ struct CanonicalizeContractMatmulToMMT final
auto createTranspose = [&rewriter, loc](Value mat) -> Value {
if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
Value trans =
- rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
+ vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
VectorType newType =
cast<VectorType>(trans.getType())
.clone(cast<VectorType>(mat.getType()).getElementType());
- return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
+ return arith::ExtSIOp::create(rewriter, loc, newType, trans);
}
if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
Value trans =
- rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
+ vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
VectorType newType =
VectorType::get(cast<VectorType>(trans.getType()).getShape(),
cast<VectorType>(mat.getType()).getElementType());
- return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
+ return arith::ExtUIOp::create(rewriter, loc, newType, trans);
}
- return rewriter.create<vector::TransposeOp>(loc, mat, perm);
+ return vector::TransposeOp::create(rewriter, loc, mat, perm);
};
if (maps == infer({{m, k}, {k, n}, {m, n}})) {
@@ -1836,8 +1840,8 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
vAdd = rewriter.createOrFold<arith::AddIOp>(
loc, parentReduction.getVector(), op.getVector());
} else {
- vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
- op.getVector());
+ vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
+ op.getVector());
}
rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
parentReduction.getAcc());
@@ -1925,7 +1929,7 @@ struct DropUnitDimFromElementwiseOps final
if (newVType == opVectorType)
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
- auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+ auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
newOperands.push_back(opSC);
}
@@ -2004,11 +2008,11 @@ struct DropUnitDimsFromTransposeOp final
Location loc = op.getLoc();
// Drop the unit dims via shape_cast.
- auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
- loc, sourceTypeWithoutUnitDims, op.getVector());
+ auto dropDimsShapeCast = vector::ShapeCastOp::create(
+ rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
// Create the new transpose.
auto transposeWithoutUnitDims =
- rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
+ vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
// Restore the unit dims via shape cast.
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
op, op.getResultVectorType(), transposeWithoutUnitDims);
@@ -2059,7 +2063,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
// Create a new ForOp with that iter operand replaced.
auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
- return b.create<vector::ShapeCastOp>(loc, type, source);
+ return vector::ShapeCastOp::create(b, loc, type, source);
};
Value replacement =
@@ -2111,8 +2115,8 @@ struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
return failure();
- auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
- vAdd.getRhs());
+ auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
+ addLhs.getLhs(), vAdd.getRhs());
rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
op.getAcc());
return success();
@@ -2154,8 +2158,8 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
Location loc = op.getLoc();
SmallVector<Value> extracted(numElems, nullptr);
for (auto [idx, extractedElem] : llvm::enumerate(extracted))
- extractedElem = rewriter.create<vector::ExtractOp>(
- loc, op.getVector(), static_cast<int64_t>(idx));
+ extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
+ static_cast<int64_t>(idx));
Value res = extracted.front();
for (auto extractedElem : llvm::drop_begin(extracted))
@@ -2234,8 +2238,8 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
return failure();
- return rewriter.create<vector::OuterProductOp>(
- mulOp->getLoc(), resType, broadcastedLhs.getSource(),
+ return vector::OuterProductOp::create(
+ rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..fceba65fa3e3a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -49,7 +49,7 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
getAffineConstantExpr(elementOffsets[dim.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
slicedIndices[pos] =
- builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
+ affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
}
return slicedIndices;
}
@@ -68,9 +68,9 @@ static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
auto start = indices.size() - offsets.size();
for (auto [i, offset] : llvm::enumerate(offsets)) {
if (offset != 0) {
- indices[start + i] = rewriter.create<arith::AddIOp>(
- loc, originalIndices[start + i],
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
+ indices[start + i] = arith::AddIOp::create(
+ rewriter, loc, originalIndices[start + i],
+ arith::ConstantIndexOp::create(rewriter, loc, offset));
}
}
return indices;
@@ -172,8 +172,9 @@ struct UnrollTransferReadPattern
ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
// Prepare the result vector;
- Value result = rewriter.create<arith::ConstantOp>(
- loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
+ Value result =
+ arith::ConstantOp::create(rewriter, loc, sourceVectorType,
+ rewriter.getZeroAttr(sourceVectorType));
auto targetType =
VectorType::get(*targetShape, sourceVectorType.getElementType());
SmallVector<Value> originalIndices(readOp.getIndices().begin(),
@@ -185,8 +186,8 @@ struct UnrollTransferReadPattern
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
readOp.getPermutationMap(), loc, rewriter);
- auto slicedRead = rewriter.create<vector::TransferReadOp>(
- loc, targetType, readOp.getBase(), indices,
+ auto slicedRead = vector::TransferReadOp::create(
+ rewriter, loc, targetType, readOp.getBase(), indices,
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
@@ -236,9 +237,10 @@ struct UnrollTransferWritePattern
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
writeOp.getPermutationMap(), loc, rewriter);
- Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
- loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(),
- indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
+ Operation *slicedWrite = vector::TransferWriteOp::create(
+ rewriter, loc, slicedVector,
+ resultTensor ? resultTensor : writeOp.getBase(), indices,
+ writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
// For the tensor case update the destination for the next transfer write.
if (!slicedWrite->getResults().empty())
resultTensor = slicedWrite->getResult(0);
@@ -348,8 +350,8 @@ struct UnrollContractionPattern
accCache[dstOffets] = newOp->getResult(0);
}
// Assemble back the accumulator into a single vector.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstVecType, rewriter.getZeroAttr(dstVecType));
+ Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
+ rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
@@ -427,8 +429,8 @@ struct UnrollMultiReductionPattern
accCache[destOffset] = result;
}
// Assemble back the accumulator into a single vector.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, reductionOp.getDestType(),
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, reductionOp.getDestType(),
rewriter.getZeroAttr(reductionOp.getDestType()));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
@@ -468,8 +470,8 @@ struct UnrollElementwisePattern : public RewritePattern {
op, "expected input vector rank to match target shape rank");
Location loc = op->getLoc();
// Prepare the result vector.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstVecType, rewriter.getZeroAttr(dstVecType));
+ Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
+ rewriter.getZeroAttr(dstVecType));
SmallVector<int64_t> strides(targetShape->size(), 1);
VectorType newVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
@@ -567,8 +569,9 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
ArrayRef<int64_t> originalSize = originalVectorType.getShape();
// Prepare the result vector;
- Value result = rewriter.create<arith::ConstantOp>(
- loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
+ Value result =
+ arith::ConstantOp::create(rewriter, loc, originalVectorType,
+ rewriter.getZeroAttr(originalVectorType));
ArrayRef<int64_t> permutation = transposeOp.getPermutation();
// Unroll the computation.
@@ -618,8 +621,9 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
// Prepare the result vector;
- Value result = rewriter.create<arith::ConstantOp>(
- loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
+ Value result =
+ arith::ConstantOp::create(rewriter, loc, sourceVectorType,
+ rewriter.getZeroAttr(sourceVectorType));
auto targetType =
VectorType::get(*targetShape, sourceVectorType.getElementType());
@@ -638,8 +642,8 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
- auto slicedGather = rewriter.create<vector::GatherOp>(
- loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+ auto slicedGather = vector::GatherOp::create(
+ rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
indexSubVec, maskSubVec, passThruSubVec);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
@@ -671,8 +675,8 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
ArrayRef<int64_t> originalShape = vecType.getShape();
SmallVector<int64_t> strides(targetShape->size(), 1);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, vecType, rewriter.getZeroAttr(vecType));
+ Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+ rewriter.getZeroAttr(vecType));
SmallVector<int64_t> loopOrder =
getUnrollOrder(originalShape.size(), loadOp, options);
@@ -684,8 +688,8 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
SmallVector<Value> indices =
sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
- Value slicedLoad = rewriter.create<vector::LoadOp>(
- loc, targetVecType, loadOp.getBase(), indices);
+ Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
+ loadOp.getBase(), indices);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedLoad, result, offsets, strides);
}
@@ -727,7 +731,7 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, vector, offsets, *targetShape, strides);
- rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+ vector::StoreOp::create(rewriter, loc, slice, base, indices);
}
rewriter.eraseOp(storeOp);
return success();
@@ -755,8 +759,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType = broadcastOp.getResultVectorType();
VectorType targetType =
resType.cloneWith(*targetShape, resType.getElementType());
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = arith::ConstantOp::create(rewriter, loc, resType,
+ rewriter.getZeroAttr(resType));
SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
SmallVector<int64_t> strides(originalShape.size(), 1);
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 9b055853fc8b0..c045063e8194f 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -333,7 +333,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
SmallVector<bool> inBoundsVal(readRank, true);
if (useInBoundsInsteadOfMasking) {
@@ -343,8 +343,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
ShapedType::isStatic(sourceShape[i]);
}
- auto transferReadOp = builder.create<vector::TransferReadOp>(
- loc,
+ auto transferReadOp = vector::TransferReadOp::create(
+ builder, loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero),
@@ -361,7 +361,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto maskType =
VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
Value mask =
- builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
->getResult(0);
}
More information about the Mlir-commits
mailing list