[Mlir-commits] [mlir] [mlir][vector] Standardise `valueToStore` Naming Across Vector Ops (NFC) (PR #134206)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Apr 2 23:51:19 PDT 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/134206
This change standardises the naming convention for the argument
representing the value to store in various vector operations.
Specifically, it ensures that all vector ops storing a value — whether
into memory, a tensor, or another vector — use `valueToStore` for the
corresponding argument name.
Updated operations:
* `vector.transfer_write`, `vector.insert`, `vector.scalable_insert`,
`vector.insert_strided_slice`.
For reference, here are operations that currently use `valueToStore`:
* `vector.store`, `vector.scatter`, `vector.compressstore`,
`vector.maskedstore`.
This change is non-functional (NFC) and does not affect the
functionality of these operations.
Implements #131602
>From 3f6ef5237286c1107569e0f658d92e2f4d243042 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 1 Apr 2025 17:49:09 +0100
Subject: [PATCH] [mlir][vector] Standardise `valueToStore` Naming Across
Vector Ops (NFC)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This change standardises the naming convention for the argument
representing the value to store in various vector operations.
Specifically, it ensures that all vector ops storing a value — whether
into memory, a tensor, or another vector — use `valueToStore` for the
corresponding argument name.
Updated operations:
* `vector.transfer_write`, `vector.insert`, `vector.scalable_insert`,
`vector.insert_strided_slice`.
For reference, here are operations that currently use `valueToStore`:
* `vector.store`, `vector.scatter`, `vector.compressstore`,
`vector.maskedstore`.
This change is non-functional (NFC) and does not affect the
functionality of these operations.
Implements #131602
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 43 +++++++------
.../mlir/Interfaces/VectorInterfaces.td | 14 ++--
.../VectorToArmSME/VectorToArmSME.cpp | 2 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 7 +-
.../Conversion/VectorToSCF/VectorToSCF.cpp | 2 +-
.../VectorToSPIRV/VectorToSPIRV.cpp | 11 ++--
.../ArmSME/Transforms/VectorLegalization.cpp | 4 +-
.../Dialect/Linalg/Transforms/Hoisting.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 4 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 64 +++++++++++++------
.../Transforms/SubsetOpInterfaceImpl.cpp | 4 +-
.../Vector/Transforms/VectorDistribute.cpp | 20 +++---
.../Transforms/VectorDropLeadUnitDim.cpp | 8 +--
...sertExtractStridedSliceRewritePatterns.cpp | 10 +--
.../Vector/Transforms/VectorLinearize.cpp | 6 +-
.../Vector/Transforms/VectorTransforms.cpp | 6 +-
16 files changed, 120 insertions(+), 87 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..7fc56b1aa4e7e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -907,7 +907,7 @@ def Vector_InsertOp :
}];
let arguments = (ins
- AnyType:$source,
+ AnyType:$valueToStore,
AnyVectorOfAnyRank:$dest,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
@@ -916,15 +916,15 @@ def Vector_InsertOp :
let builders = [
// Builder to insert a scalar/rank-0 vector into a rank-0 vector.
- OpBuilder<(ins "Value":$source, "Value":$dest)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];
let extraClassDeclaration = extraPoisonClassDeclaration # [{
- Type getSourceType() { return getSource().getType(); }
+ Type getValueToStoreType() { return getValueToStore().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
@@ -946,8 +946,8 @@ def Vector_InsertOp :
}];
let assemblyFormat = [{
- $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
- attr-dict `:` type($source) `into` type($dest)
+ $valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+ attr-dict `:` type($valueToStore) `into` type($dest)
}];
let hasCanonicalizer = 1;
@@ -957,13 +957,13 @@ def Vector_InsertOp :
def Vector_ScalableInsertOp :
Vector_Op<"scalable.insert", [Pure,
- AllElementTypesMatch<["source", "dest"]>,
+ AllElementTypesMatch<["valueToStore", "dest"]>,
AllTypesMatch<["dest", "res"]>,
PredOpTrait<"position is a multiple of the source length.",
CPred<
"(getPos() % getSourceVectorType().getNumElements()) == 0"
>>]>,
- Arguments<(ins VectorOfRank<[1]>:$source,
+ Arguments<(ins VectorOfRank<[1]>:$valueToStore,
ScalableVectorOfRank<[1]>:$dest,
I64Attr:$pos)>,
Results<(outs ScalableVectorOfRank<[1]>:$res)> {
@@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
}];
let assemblyFormat = [{
- $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+ $valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
}];
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getSource().getType());
+ return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank:$res)> {
let summary = "strided_slice operation";
let description = [{
- Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
+ Takes a k-D valueToStore vector, an n-D destination vector (n >= k), n-sized
`offsets` integer array attribute, a k-sized `strides` integer array attribute
- and inserts the k-D source vector as a strided subvector at the proper offset
+ and inserts the k-D valueToStore vector as a strided subvector at the proper offset
into the n-D destination vector.
At the moment strides must contain only 1s.
Returns an n-D vector that is a copy of the n-D destination vector in which
- the last k-D dimensions contain the k-D source vector elements strided at
+ the last k-D dimensions contain the k-D valueToStore vector elements strided at
the proper location as specified by the offsets.
Example:
@@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
}];
let assemblyFormat = [{
- $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+ $valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
}];
let builders = [
- OpBuilder<(ins "Value":$source, "Value":$dest,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
];
let extraClassDeclaration = [{
+ // TODO: Rename
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getSource().getType());
+ return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector,
+ Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index be939bad14b7b..8ea9d925b3790 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getVector",
/*args=*/(ins)
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the type of the vector that this operation operates on.
+ }],
+ /*retTy=*/"::mlir::VectorType",
+ /*methodName=*/"getVectorType",
+ /*args=*/(ins)
+ >,
InterfaceMethod<
/*desc=*/[{
Return the indices that specify the starting offsets into the source
@@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getIndices",
/*args=*/(ins)
>,
+
InterfaceMethod<
/*desc=*/[{
Return the permutation map that describes the mapping of vector
@@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
return $_op.getPermutationMap().getNumResults();
}
- /// Return the type of the vector that this operation operates on.
- ::mlir::VectorType getVectorType() {
- return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
- }
-
/// Return "true" if at least one of the vector dimensions is a broadcasted
/// dimension.
bool hasBroadcastDim() {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4be0fffe8b728..58b85bc0ea6ac 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
auto loc = insertOp.getLoc();
auto position = insertOp.getMixedPosition();
- Value source = insertOp.getSource();
+ Value source = insertOp.getValueToStore();
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f7375b8d13..847e7e2beebe9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
// We are going to mutate this 1D vector until it is either the final
// result (in the non-aggregate case) or the value that needs to be
// inserted into the aggregate result.
- Value sourceAggregate = adaptor.getSource();
+ Value sourceAggregate = adaptor.getValueToStore();
if (insertIntoInnermostDim) {
// Scalar-into-1D-vector case, so we know we will have to create a
// InsertElementOp. The question is into what destination.
@@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
}
// Insert the scalar into the 1D vector.
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
- loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
+ loc, sourceAggregate.getType(), sourceAggregate,
+ adaptor.getValueToStore(),
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
}
@@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
- insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
+ insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
return success();
}
};
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95db831185590..b9b598c02b4a2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
rewriter.modifyOpInPlace(xferOp, [&]() {
- xferOp.getVectorMutable().assign(loadedVec);
+ xferOp.getValueToStoreMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
});
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index bca77ba68fbd1..de2af69eba9ec 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (isa<VectorType>(insertOp.getSourceType()))
+ if (isa<VectorType>(insertOp.getValueToStoreType()))
return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
return rewriter.notifyMatchFailure(insertOp,
"unsupported dest vector type");
// Special case for inserting scalar values into size-1 vectors.
- if (insertOp.getSourceType().isIntOrFloat() &&
+ if (insertOp.getValueToStoreType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
+ rewriter.replaceOp(insertOp, adaptor.getValueToStore());
return success();
}
@@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
insertOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
vector::InsertOp::kPoisonIndex,
insertOp.getDestVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+ insertOp, insertOp.getDest(), adaptor.getValueToStore(),
+ sanitizedIndex);
}
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index dec3dca988ae9..62a148d2b7e62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition
auto loc = writeOp.getLoc();
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
- auto inputSMETiles = adaptor.getVector();
+ auto inputSMETiles = adaptor.getValueToStore();
Value destTensorOrMemref = writeOp.getSource();
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
@@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
rewriter.setInsertionPointToStart(storeLoop.getBody());
// For each sub-tile of the multi-tile `vectorType`.
- auto inputSMETiles = adaptor.getVector();
+ auto inputSMETiles = adaptor.getValueToStore();
auto tileSliceIndex = storeLoop.getInductionVar();
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index acfd9683f01f4..20e4e3cee7ed4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
if (failed(maybeNewLoop))
return WalkResult::interrupt();
- transferWrite.getVectorMutable().assign(
+ transferWrite.getValueToStoreMutable().assign(
maybeNewLoop->getOperation()->getResults().back());
changed = true;
// Need to interrupt and restart because erasing the loop messes up
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8c8b1b85ef5a3..5afe378463d13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
rewriter.create<vector::TransferWriteOp>(
xferOp.getLoc(), vector, out, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getMask(),
- rewriter.getBoolArrayAttr(
- SmallVector<bool>(vector.getType().getRank(), false)));
+ rewriter.getBoolArrayAttr(SmallVector<bool>(
+ cast<VectorType>(vector.getType()).getRank(), false)));
rewriter.eraseOp(copyOp);
rewriter.eraseOp(xferOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..98d98f067de14 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1555,7 +1555,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
return failure();
// Case 2.a. early-exit fold.
- res = nextInsertOp.getSource();
+ res = nextInsertOp.getValueToStore();
// Case 2.b. if internal transposition is present, canFold will be false.
return success(canFold());
}
@@ -1579,7 +1579,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
extractPosition.begin() + insertedPos.size());
extractedRank = extractPosition.size() - sentinels.size();
// Case 3.a. early-exit fold (break and delegate to post-while path).
- res = nextInsertOp.getSource();
+ res = nextInsertOp.getValueToStore();
// Case 3.b. if internal transposition is present, canFold will be false.
return success();
}
@@ -1936,7 +1936,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertRankDiff))
return Value();
}
- extractOp.getVectorMutable().assign(insertOp.getSource());
+ extractOp.getVectorMutable().assign(insertOp.getValueToStore());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(offsetDiffs);
@@ -2958,7 +2958,7 @@ LogicalResult InsertOp::verify() {
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected position attribute of rank no greater than dest vector rank");
- auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
+ auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
@@ -2994,12 +2994,13 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
- auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
+ auto srcVecType =
+ llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
srcVecType.getNumElements())
return failure();
rewriter.replaceOpWithNewOp<BroadcastOp>(
- insertOp, insertOp.getDestVectorType(), insertOp.getSource());
+ insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
return success();
}
};
@@ -3011,7 +3012,7 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
- auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+ auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
if (!srcSplat || !dstSplat)
@@ -3100,17 +3101,17 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
// (type mismatch).
- if (getNumIndices() == 0 && getSourceType() == getType())
- return getSource();
- SmallVector<Value> operands = {getSource(), getDest()};
+ if (getNumIndices() == 0 && getValueToStoreType() == getType())
+ return getValueToStore();
+ SmallVector<Value> operands = {getValueToStore(), getDest()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
- if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
- adaptor.getDest(),
- vectorSizeFoldThreshold)) {
+ if (auto res = foldDenseElementsAttrDestInsertOp(
+ *this, adaptor.getValueToStore(), adaptor.getDest(),
+ vectorSizeFoldThreshold)) {
return res;
}
@@ -3291,7 +3292,7 @@ class FoldInsertStridedSliceSplat final
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto srcSplatOp =
- insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+ insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
auto destSplatOp =
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
@@ -3316,7 +3317,7 @@ class FoldInsertStridedSliceOfExtract final
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto extractStridedSliceOp =
- insertStridedSliceOp.getSource()
+ insertStridedSliceOp.getValueToStore()
.getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
@@ -3365,7 +3366,7 @@ class InsertStridedSliceConstantFolder final
!destVector.hasOneUse())
return failure();
- TypedValue<VectorType> sourceValue = op.getSource();
+ TypedValue<VectorType> sourceValue = op.getValueToStore();
Attribute sourceCst;
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
return failure();
@@ -3425,7 +3426,7 @@ void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
if (getSourceVectorType() == getDestVectorType())
- return getSource();
+ return getValueToStore();
return {};
}
@@ -3691,7 +3692,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
}
// The extract element chunk is a subset of the insert element.
if (!disjoint && !patialoverlap) {
- op.setOperand(insertOp.getSource());
+ op.setOperand(insertOp.getValueToStore());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
@@ -4349,6 +4350,13 @@ Type TransferReadOp::getExpectedMaskType() {
return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
+//===----------------------------------------------------------------------===//
+// TransferReadOp: VectorTransferOpInterface methods.
+//===----------------------------------------------------------------------===//
+VectorType TransferReadOp::getVectorType() {
+ return cast<VectorType>(getVector().getType());
+}
+
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
@@ -4739,7 +4747,9 @@ LogicalResult TransferWriteOp::verify() {
[&](Twine t) { return emitOpError(t); });
}
-// MaskableOpInterface methods.
+//===----------------------------------------------------------------------===//
+// TransferWriteOp: MaskableOpInterface methods.
+//===----------------------------------------------------------------------===//
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes.
@@ -4747,6 +4757,17 @@ Type TransferWriteOp::getExpectedMaskType() {
return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
+//===----------------------------------------------------------------------===//
+// TransferWriteOp: VectorTransferOpInterface methods.
+//===----------------------------------------------------------------------===//
+Value TransferWriteOp::getVector() { return getOperand(0); }
+VectorType TransferWriteOp::getVectorType() {
+ return cast<VectorType>(getValueToStore().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// TransferWriteOp: fold methods.
+//===----------------------------------------------------------------------===//
/// Fold:
/// ```
/// %t1 = ...
@@ -4863,6 +4884,9 @@ LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
return memref::foldMemRefCast(*this);
}
+//===----------------------------------------------------------------------===//
+// TransferWriteOp: other methods.
+//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
@@ -4871,7 +4895,7 @@ void TransferWriteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (llvm::isa<MemRefType>(getShapedType()))
effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
SideEffects::DefaultResource::get());
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
index b450d5b78a466..7fae5460776d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
@@ -45,11 +45,11 @@ struct TransferWriteOpSubsetInsertionOpInterface
: public SubsetInsertionOpInterface::ExternalModel<
TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
OpOperand &getSourceOperand(Operation *op) const {
- return cast<vector::TransferWriteOp>(op).getVectorMutable();
+ return cast<vector::TransferWriteOp>(op).getValueToStoreMutable();
}
OpOperand &getDestinationOperand(Operation *op) const {
return cast<vector::TransferWriteOp>(op).getSourceMutable();
}
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..19f408ad1b570 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -496,7 +496,8 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
rewriter.setInsertionPointToStart(&body);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
- newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+ newWriteOp.getValueToStoreMutable().assign(
+ newWarpOp.getResult(newRetIndices[0]));
rewriter.eraseOp(writeOp);
rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
return success();
@@ -559,7 +560,8 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
- newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+ newWriteOp.getValueToStoreMutable().assign(
+ newWarpOp.getResult(newRetIndices[0]));
if (maybeMaskType)
newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
return newWriteOp;
@@ -1299,9 +1301,9 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
// Yield destination vector, source scalar and position from warp op.
SmallVector<Value> additionalResults{insertOp.getDest(),
- insertOp.getSource()};
- SmallVector<Type> additionalResultTypes{distrType,
- insertOp.getSource().getType()};
+ insertOp.getValueToStore()};
+ SmallVector<Type> additionalResultTypes{
+ distrType, insertOp.getValueToStore().getType()};
additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
additionalResultTypes.append(
SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
@@ -1393,8 +1395,8 @@ struct WarpOpInsert : public WarpDistributionPattern {
// out of the warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
- {insertOp.getSourceType(), insertOp.getDestVectorType()},
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
@@ -1422,7 +1424,7 @@ struct WarpOpInsert : public WarpDistributionPattern {
assert(distrDestDim != -1 && "could not find distributed dimension");
// Compute the distributed source vector type.
- VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
+ VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
// E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
// Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
@@ -1439,7 +1441,7 @@ struct WarpOpInsert : public WarpDistributionPattern {
// Yield source and dest vectors from warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
{distrSrcType, distrDestType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index b53aa997c9014..fda3baf3aa390 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -122,7 +122,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
Location loc = insertOp.getLoc();
Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.getSource(), splatZero(srcDropCount));
+ loc, insertOp.getValueToStore(), splatZero(srcDropCount));
Value newDstVector = rewriter.create<vector::ExtractOp>(
loc, insertOp.getDest(), splatZero(dstDropCount));
@@ -148,7 +148,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
LogicalResult matchAndRewrite(vector::InsertOp insertOp,
PatternRewriter &rewriter) const override {
- Type oldSrcType = insertOp.getSourceType();
+ Type oldSrcType = insertOp.getValueToStoreType();
Type newSrcType = oldSrcType;
int64_t oldSrcRank = 0, newSrcRank = 0;
if (auto type = dyn_cast<VectorType>(oldSrcType)) {
@@ -168,10 +168,10 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
// Trim leading one dimensions from both operands.
Location loc = insertOp.getLoc();
- Value newSrcVector = insertOp.getSource();
+ Value newSrcVector = insertOp.getValueToStore();
if (oldSrcRank != 0) {
newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.getSource(), splatZero(srcDropCount));
+ loc, insertOp.getValueToStore(), splatZero(srcDropCount));
}
Value newDstVector = rewriter.create<vector::ExtractOp>(
loc, insertOp.getDest(), splatZero(dstDropCount));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 82a985c9e5824..d834a99076834 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -61,7 +61,7 @@ class DecomposeDifferentRankInsertStridedSlice
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
- loc, op.getSource(), extracted,
+ loc, op.getValueToStore(), extracted,
getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
getI64SubArray(op.getStrides(), /*dropFront=*/0));
@@ -111,7 +111,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
return failure();
if (srcType == dstType) {
- rewriter.replaceOp(op, op.getSource());
+ rewriter.replaceOp(op, op.getValueToStore());
return success();
}
@@ -131,8 +131,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
SmallVector<int64_t> offsets(nDest, 0);
for (int64_t i = 0; i < nSrc; ++i)
offsets[i] = i;
- Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
- op.getSource(), offsets);
+ Value scaledSource = rewriter.create<ShuffleOp>(
+ loc, op.getValueToStore(), op.getValueToStore(), offsets);
// 2. Create a mask where we take the value from scaledSource of dest
// depending on the offset.
@@ -156,7 +156,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource =
- rewriter.create<ExtractOp>(loc, op.getSource(), idx);
+ rewriter.create<ExtractOp>(loc, op.getValueToStore(), idx);
if (isa<VectorType>(extractedSource.getType())) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 9dccc005322eb..a009aa03aaf64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -439,7 +439,7 @@ struct LinearizeVectorInsert final
return rewriter.notifyMatchFailure(insertOp,
"scalable vectors are not supported.");
- if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
+ if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
targetVectorBitWidth))
return rewriter.notifyMatchFailure(
insertOp, "Can't flatten since targetBitWidth < OpSize");
@@ -448,7 +448,7 @@ struct LinearizeVectorInsert final
if (insertOp.hasDynamicPosition())
return rewriter.notifyMatchFailure(insertOp,
"dynamic position is not supported.");
- auto srcTy = insertOp.getSourceType();
+ auto srcTy = insertOp.getValueToStoreType();
auto srcAsVec = dyn_cast<VectorType>(srcTy);
uint64_t srcSize = 0;
if (srcAsVec) {
@@ -484,7 +484,7 @@ struct LinearizeVectorInsert final
// [offset+srcNumElements, end)
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
+ insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b6fac80d871e6..d50d5fe96f49a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -748,7 +748,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
return failure();
// Only vector sources are supported for now.
- auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+ auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
if (!insertSrcType)
return failure();
@@ -759,7 +759,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+ bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
dstDims.back() =
@@ -850,7 +850,7 @@ struct BubbleUpBitCastForStridedSliceInsert
VectorType::get(srcDims, castDstType.getElementType());
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
- bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+ bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
SmallVector<int64_t> dstDims =
llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
More information about the Mlir-commits mailing list