[Mlir-commits] [mlir] [mlir][vector] Standardise `valueToStore` Naming Across Vector Ops (NFC) (PR #134206)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 2 23:51:51 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
This change standardises the naming convention for the argument
representing the value to store in various vector operations.
Specifically, it ensures that all vector ops that store a value — whether
into memory, a tensor, or another vector — use `valueToStore` as the
name of the corresponding argument.
Updated operations:
* `vector.transfer_write`, `vector.insert`, `vector.scalable.insert`,
and `vector.insert_strided_slice`.
For reference, here are operations that currently use `valueToStore`:
* `vector.store`, `vector.scatter`, `vector.compressstore`, and
`vector.maskedstore`.
This is a non-functional change (NFC): it does not affect the
behavior of these operations.
Implements #131602
---
Patch is 36.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134206.diff
16 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+22-21)
- (modified) mlir/include/mlir/Interfaces/VectorInterfaces.td (+9-5)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+1-1)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+4-3)
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+1-1)
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+6-5)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+2-2)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+2-2)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+44-20)
- (modified) mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp (+2-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+11-9)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+4-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+5-5)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+3-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..7fc56b1aa4e7e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -907,7 +907,7 @@ def Vector_InsertOp :
}];
let arguments = (ins
- AnyType:$source,
+ AnyType:$valueToStore,
AnyVectorOfAnyRank:$dest,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
@@ -916,15 +916,15 @@ def Vector_InsertOp :
let builders = [
// Builder to insert a scalar/rank-0 vector into a rank-0 vector.
- OpBuilder<(ins "Value":$source, "Value":$dest)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
- OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];
let extraClassDeclaration = extraPoisonClassDeclaration # [{
- Type getSourceType() { return getSource().getType(); }
+ Type getValueToStoreType() { return getValueToStore().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
@@ -946,8 +946,8 @@ def Vector_InsertOp :
}];
let assemblyFormat = [{
- $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
- attr-dict `:` type($source) `into` type($dest)
+ $valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+ attr-dict `:` type($valueToStore) `into` type($dest)
}];
let hasCanonicalizer = 1;
@@ -957,13 +957,13 @@ def Vector_InsertOp :
def Vector_ScalableInsertOp :
Vector_Op<"scalable.insert", [Pure,
- AllElementTypesMatch<["source", "dest"]>,
+ AllElementTypesMatch<["valueToStore", "dest"]>,
AllTypesMatch<["dest", "res"]>,
PredOpTrait<"position is a multiple of the source length.",
CPred<
"(getPos() % getSourceVectorType().getNumElements()) == 0"
>>]>,
- Arguments<(ins VectorOfRank<[1]>:$source,
+ Arguments<(ins VectorOfRank<[1]>:$valueToStore,
ScalableVectorOfRank<[1]>:$dest,
I64Attr:$pos)>,
Results<(outs ScalableVectorOfRank<[1]>:$res)> {
@@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
}];
let assemblyFormat = [{
- $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+ $valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
}];
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getSource().getType());
+ return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank:$res)> {
let summary = "strided_slice operation";
let description = [{
- Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
+ Takes a k-D valueToStore vector, an n-D destination vector (n >= k), n-sized
`offsets` integer array attribute, a k-sized `strides` integer array attribute
- and inserts the k-D source vector as a strided subvector at the proper offset
+ and inserts the k-D valueToStore vector as a strided subvector at the proper offset
into the n-D destination vector.
At the moment strides must contain only 1s.
Returns an n-D vector that is a copy of the n-D destination vector in which
- the last k-D dimensions contain the k-D source vector elements strided at
+ the last k-D dimensions contain the k-D valueToStore vector elements strided at
the proper location as specified by the offsets.
Example:
@@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
}];
let assemblyFormat = [{
- $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+ $valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
}];
let builders = [
- OpBuilder<(ins "Value":$source, "Value":$dest,
+ OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
];
let extraClassDeclaration = [{
+ // TODO: Rename
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getSource().getType());
+ return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector,
+ Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index be939bad14b7b..8ea9d925b3790 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getVector",
/*args=*/(ins)
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the type of the vector that this operation operates on.
+ }],
+ /*retTy=*/"::mlir::VectorType",
+ /*methodName=*/"getVectorType",
+ /*args=*/(ins)
+ >,
InterfaceMethod<
/*desc=*/[{
Return the indices that specify the starting offsets into the source
@@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getIndices",
/*args=*/(ins)
>,
+
InterfaceMethod<
/*desc=*/[{
Return the permutation map that describes the mapping of vector
@@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
return $_op.getPermutationMap().getNumResults();
}
- /// Return the type of the vector that this operation operates on.
- ::mlir::VectorType getVectorType() {
- return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
- }
-
/// Return "true" if at least one of the vector dimensions is a broadcasted
/// dimension.
bool hasBroadcastDim() {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4be0fffe8b728..58b85bc0ea6ac 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
auto loc = insertOp.getLoc();
auto position = insertOp.getMixedPosition();
- Value source = insertOp.getSource();
+ Value source = insertOp.getValueToStore();
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f7375b8d13..847e7e2beebe9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
// We are going to mutate this 1D vector until it is either the final
// result (in the non-aggregate case) or the value that needs to be
// inserted into the aggregate result.
- Value sourceAggregate = adaptor.getSource();
+ Value sourceAggregate = adaptor.getValueToStore();
if (insertIntoInnermostDim) {
// Scalar-into-1D-vector case, so we know we will have to create a
// InsertElementOp. The question is into what destination.
@@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
}
// Insert the scalar into the 1D vector.
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
- loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
+ loc, sourceAggregate.getType(), sourceAggregate,
+ adaptor.getValueToStore(),
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
}
@@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
- insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
+ insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
return success();
}
};
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95db831185590..b9b598c02b4a2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
rewriter.modifyOpInPlace(xferOp, [&]() {
- xferOp.getVectorMutable().assign(loadedVec);
+ xferOp.getValueToStoreMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
});
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index bca77ba68fbd1..de2af69eba9ec 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (isa<VectorType>(insertOp.getSourceType()))
+ if (isa<VectorType>(insertOp.getValueToStoreType()))
return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
return rewriter.notifyMatchFailure(insertOp,
"unsupported dest vector type");
// Special case for inserting scalar values into size-1 vectors.
- if (insertOp.getSourceType().isIntOrFloat() &&
+ if (insertOp.getValueToStoreType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
+ rewriter.replaceOp(insertOp, adaptor.getValueToStore());
return success();
}
@@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
insertOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
vector::InsertOp::kPoisonIndex,
insertOp.getDestVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+ insertOp, insertOp.getDest(), adaptor.getValueToStore(),
+ sanitizedIndex);
}
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index dec3dca988ae9..62a148d2b7e62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition
auto loc = writeOp.getLoc();
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
- auto inputSMETiles = adaptor.getVector();
+ auto inputSMETiles = adaptor.getValueToStore();
Value destTensorOrMemref = writeOp.getSource();
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
@@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
rewriter.setInsertionPointToStart(storeLoop.getBody());
// For each sub-tile of the multi-tile `vectorType`.
- auto inputSMETiles = adaptor.getVector();
+ auto inputSMETiles = adaptor.getValueToStore();
auto tileSliceIndex = storeLoop.getInductionVar();
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index acfd9683f01f4..20e4e3cee7ed4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
if (failed(maybeNewLoop))
return WalkResult::interrupt();
- transferWrite.getVectorMutable().assign(
+ transferWrite.getValueToStoreMutable().assign(
maybeNewLoop->getOperation()->getResults().back());
changed = true;
// Need to interrupt and restart because erasing the loop messes up
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8c8b1b85ef5a3..5afe378463d13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
rewriter.create<vector::TransferWriteOp>(
xferOp.getLoc(), vector, out, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getMask(),
- rewriter.getBoolArrayAttr(
- SmallVector<bool>(vector.getType().getRank(), false)));
+ rewriter.getBoolArrayAttr(SmallVector<bool>(
+ dyn_cast<VectorType>(vector.getType()).getRank(), false)));
rewriter.eraseOp(copyOp);
rewriter.eraseOp(xferOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..98d98f067de14 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1555,7 +1555,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
return failure();
// Case 2.a. early-exit fold.
- res = nextInsertOp.getSource();
+ res = nextInsertOp.getValueToStore();
// Case 2.b. if internal transposition is present, canFold will be false.
return success(canFold());
}
@@ -1579,7 +1579,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
extractPosition.begin() + insertedPos.size());
extractedRank = extractPosition.size() - sentinels.size();
// Case 3.a. early-exit fold (break and delegate to post-while path).
- res = nextInsertOp.getSource();
+ res = nextInsertOp.getValueToStore();
// Case 3.b. if internal transposition is present, canFold will be false.
return success();
}
@@ -1936,7 +1936,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertRankDiff))
return Value();
}
- extractOp.getVectorMutable().assign(insertOp.getSource());
+ extractOp.getVectorMutable().assign(insertOp.getValueToStore());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(offsetDiffs);
@@ -2958,7 +2958,7 @@ LogicalResult InsertOp::verify() {
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected position attribute of rank no greater than dest vector rank");
- auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
+ auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
static_cast<unsigned>(destVectorType.getRank())))
@@ -2994,12 +2994,13 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
- auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
+ auto srcVecType =
+ llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
srcVecType.getNumElements())
return failure();
rewriter.replaceOpWithNewOp<BroadcastOp>(
- insertOp, insertOp.getDestVectorType(), insertOp.getSource());
+ insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
return success();
}
};
@@ -3011,7 +3012,7 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
- auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+ auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
if (!srcSplat || !dstSplat)
@@ -3100,17 +3101,17 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
// (type mismatch).
- if (getNumIndices() == 0 && getSourceType() == getType())
- return getSource();
- SmallVector<Value> operands = {getSource(), getDest()};
+ if (getNumIndices() == 0 && getValueToStoreType() == getType())
+ return getValueToStore();
+ SmallVector<Value> operands = {getValueToStore(), getDest()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
- if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
- adaptor.getDest(),
- vectorSizeFoldThreshold)) {
+ if (auto res = foldDenseElementsAttrDestInsertOp(
+ *this, adaptor.getValueToStore(), adaptor.getDest(),
+ vectorSizeFoldThreshold)) {
return res;
}
@@ -3291,7 +3292,7 @@ class FoldInsertStridedSliceSplat final
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto srcSplatOp =
- insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+ insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
auto destSplatOp =
insertStridedSliceOp.getDes...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/134206
More information about the Mlir-commits
mailing list