[Mlir-commits] [mlir] [mlir][vector] Use `source` as the source argument name (PR #158258)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 12 02:57:45 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
This patch updates the following ops to use `source` (instead of `vector`)
as the name for their source argument:
* `vector.extract`
* `vector.scalable.extract`
* `vector.extract_strided_slice`
This makes the argument name consistent with the builders for these ops,
which already use `source` rather than `vector`. It also
addresses part of:
* https://github.com/llvm/llvm-project/issues/131602
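Specifically, it ensures that we use `source` for the operand that is read
from and `dest` for the operand that is written to (as opposed to `vector`
and `dest`). To ease migration for downstream code, the old `getVector()`
accessors are kept as `[[deprecated]]` wrappers that forward to `getSource()`.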
---
Patch is 34.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158258.diff
18 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+24-7)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+2-2)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-1)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+1-1)
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+1-1)
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+4-4)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp (+1-1)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+2-2)
- (modified) mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+34-34)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+4-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+4-4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+4-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 65ba7e0ad549f..f52075886326b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -675,7 +675,7 @@ def Vector_ExtractOp :
}];
let arguments = (ins
- AnyVectorOfAnyRank:$vector,
+ AnyVectorOfAnyRank:$source,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
@@ -692,7 +692,7 @@ def Vector_ExtractOp :
let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
+ return ::llvm::cast<VectorType>(getSource().getType());
}
/// Return a vector with all the static and dynamic position indices.
@@ -709,12 +709,17 @@ def Vector_ExtractOp :
bool hasDynamicPosition() {
return !getDynamicPosition().empty();
}
+
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
let assemblyFormat = [{
- $vector ``
+ $source ``
custom<DynamicIndexList>($dynamic_position, $static_position)
- attr-dict `:` type($result) `from` type($vector)
+ attr-dict `:` type($result) `from` type($source)
}];
let hasCanonicalizer = 1;
@@ -972,6 +977,10 @@ def Vector_ScalableInsertOp :
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
}
@@ -1023,6 +1032,10 @@ def Vector_ScalableExtractOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
}
@@ -1174,7 +1187,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank)> {
let summary = "extract_strided_slice operation";
@@ -1209,7 +1222,7 @@ def Vector_ExtractStridedSliceOp :
];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
+ return ::llvm::cast<VectorType>(getSource().getType());
}
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
@@ -1217,11 +1230,15 @@ def Vector_ExtractStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
+ /// Wrapper for getSource, which replaced getVector.
+ [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+ return getSource();
+ }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+ let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
}
// TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 9efa34a9a3acc..4e1da39c29260 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
auto loc = extractOp.getLoc();
auto position = extractOp.getMixedPosition();
- Value sourceVector = extractOp.getVector();
+ Value sourceVector = extractOp.getSource();
// Extract entire vector. Should be handled by folder, but just to be safe.
if (position.empty()) {
@@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1d1904f717335..e1a22ec6d799b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
// Find the vector.transer_read whose result vector is being sliced.
- auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+ auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return rewriter.notifyMatchFailure(op, "no transfer read");
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9fbac4925dc1d..e7266740894b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1131,7 +1131,7 @@ class VectorExtractOpConversion
positionVec.push_back(rewriter.getZeroAttr(idxType));
}
- Value extracted = adaptor.getVector();
+ Value extracted = adaptor.getSource();
if (extractsAggregate) {
ArrayRef<OpFoldResult> position(positionVec);
if (extractsScalar) {
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 508f4e25326eb..c45c45e4712f3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
/// Return the vector from which newly generated ExtracOps will extract.
Value getDataVector(TransferWriteOp xferOp) const {
if (auto extractOp = getExtractOp(xferOp))
- return extractOp.getVector();
+ return extractOp.getSource();
return xferOp.getVector();
}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c861935b4bc18..1c311d0312aaa 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
if (!dstType)
return failure();
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
+ if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+ rewriter.replaceOp(extractOp, adaptor.getSource());
return success();
}
@@ -201,7 +201,7 @@ struct VectorExtractOpConvert final
extractOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
+ extractOp, dstType, adaptor.getSource(),
rewriter.getI32ArrayAttr(id.value()));
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
@@ -209,7 +209,7 @@ struct VectorExtractOpConvert final
vector::ExtractOp::kPoisonIndex,
extractOp.getSourceVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+ extractOp, dstType, adaptor.getSource(), sanitizedIndex);
}
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9bf026563c255..9196d2ef79592 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
return rewriter.notifyMatchFailure(
extractOp, "extracted type is not a 1-D scalable vector type");
- auto *extendOp = extractOp.getVector().getDefiningOp();
+ auto *extendOp = extractOp.getSource().getDefiningOp();
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
extendOp))
return rewriter.notifyMatchFailure(extractOp,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 1c0eced43dc00..576c92b375ff3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
PatternRewriter &rewriter) const override {
auto loc = extractOp.getLoc();
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
extractOp, "extract not from vector.create_mask op");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 36434cf2d2ae2..e1dc40d6d37d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();
// Check that the vector to extract from is a BlockArgument.
- auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+ auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
if (!blockArg)
return WalkResult::advance();
@@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
return WalkResult::advance();
rewriter.modifyOpInPlace(broadcast, [&] {
- extractOp.getVectorMutable().assign(initArg->get());
+ extractOp.getSourceMutable().assign(initArg->get());
});
loop.moveOutOfLoop(extractOp);
rewriter.moveOpAfter(broadcast, loop);
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index b392ffeb13de6..050bbac2293e9 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
if (auto extractOp =
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
if (auto maskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
return TransferMask{maskOp,
SmallVector<int64_t>(extractOp.getStaticPosition())};
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..8d6e263934fb4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1309,7 +1309,7 @@ LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
- auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
+ auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
@@ -1379,7 +1379,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
- if (!extractOp.getVector().getDefiningOp<ExtractOp>())
+ if (!extractOp.getSource().getDefiningOp<ExtractOp>())
return failure();
// TODO: Canonicalization for dynamic position not implemented yet.
@@ -1390,7 +1390,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
ExtractOp currentOp = extractOp;
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
- while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
+ while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
// TODO: Canonicalization for dynamic position not implemented yet.
if (currentOp.hasDynamicPosition())
@@ -1398,7 +1398,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
- extractOp.setOperand(0, currentOp.getVector());
+ extractOp.setOperand(0, currentOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
std::reverse(globalPosition.begin(), globalPosition.end());
@@ -1584,7 +1584,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
return Value();
// If we can't fold (either internal transposition, or nothing to fold), bail.
- bool nothingToFold = (source == extractOp.getVector());
+ bool nothingToFold = (source == extractOp.getSource());
if (nothingToFold || !canFold())
return Value();
@@ -1592,7 +1592,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(
ArrayRef(extractPosition).take_front(extractedRank));
- extractOp.getVectorMutable().assign(source);
+ extractOp.getSourceMutable().assign(source);
return extractOp.getResult();
}
@@ -1602,7 +1602,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
if (extractOp.hasDynamicPosition())
return Value();
- Value valueToExtractFrom = extractOp.getVector();
+ Value valueToExtractFrom = extractOp.getSource();
updateStateForNextIteration(valueToExtractFrom);
while (nextInsertOp || nextTransposeOp) {
// Case 1. If we hit a transpose, just compose the map and iterate.
@@ -1693,7 +1693,7 @@ static bool isBroadcastLike(Operation *op) {
/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getSource().getDefiningOp();
if (!defOp || !isBroadcastLike(defOp))
return Value();
@@ -1762,7 +1762,7 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) {
if (extractOp.hasDynamicPosition())
return Value();
- auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+ auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
if (!shuffleOp)
return Value();
@@ -1793,7 +1793,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
if (extractOp.hasDynamicPosition())
return Value();
- auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+ auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
@@ -1859,7 +1859,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
return Value();
auto extractStridedSliceOp =
- extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
+ extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
return Value();
@@ -1896,7 +1896,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
assert(extractedPos.size() >= sliceOffsets.size());
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
- extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
+ extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
@@ -1914,7 +1914,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
llvm::isa<VectorType>(extractOp.getType())
? llvm::cast<VectorType>(extractOp.getType()).getRank()
: 0;
- auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+ auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
if (!insertOp)
return Value();
@@ -1966,7 +1966,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertRankDiff))
return Value();
}
- extractOp.getVectorMutable().assign(insertOp.getValueToStore());
+ extractOp.getSourceMutable().assign(insertOp.getValueToStore());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp.setStaticPosition(offsetDiffs);
@@ -1991,7 +1991,7 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return {};
// Look for extract(from_elements).
- auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+ auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
if (!fromElementsOp)
return {};
@@ -2142,20 +2142,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
// mismatch).
- if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
- return getVector();
- if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
+ return getSource();
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
return res;
// Fold `arith.constant` indices into the `vector.extract` operation.
// Do not stop here as this fold may enable subsequent folds that require
// constant indices.
- SmallVector<Value> operands = {getVector()};
+ SmallVector<Value> operands = {getSource()};
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
- if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+ if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
@@ -2187,7 +2187,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getSource().getDefiningOp();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!defOp || !isBroadcastLike(defOp) || !outType)
return failure();
@@ -2210,7 +2210,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
auto createMaskOp =
- extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return failure();
@@ -2271,7 +2271,7 @@ class ExtractOpFromCr...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/158258
More information about the Mlir-commits
mailing list