[Mlir-commits] [mlir] [mlir][vector] Use `source` as the source argument name (PR #158258)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Sep 12 03:55:27 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/158258

>From 36fe297dd3d5537bedb9d737e356dac2d98fa726 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 12 Sep 2025 09:44:53 +0000
Subject: [PATCH] [mlir][vector] Use `source` as the source argument name

This patch updates the following ops to use `source` (instead of `vector`)
as the name for their source argument:
  * `vector.extract`
  * `vector.scalable.extract`
  * `vector.extract_strided_slice`

This change ensures naming consistency with the "builders" for these Ops
that already use the name `source` rather than `vector`. It also
addresses part of:
  * https://github.com/llvm/llvm-project/issues/131602

Specifically, it ensures that we use `source` and `dest` for read and
write operations, respectively (as opposed to `vector` and `dest`).
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 27 ++++++--
 .../VectorToArmSME/VectorToArmSME.cpp         |  4 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  2 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  2 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  2 +-
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  8 +--
 .../ArmSME/Transforms/OuterProductFusion.cpp  |  2 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  |  2 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  4 +-
 .../NVGPU/Transforms/CreateAsyncGroups.cpp    |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 68 +++++++++----------
 .../Vector/Transforms/VectorDistribute.cpp    |  8 +--
 .../Transforms/VectorDropLeadUnitDim.cpp      |  2 +-
 .../Transforms/VectorEmulateNarrowType.cpp    |  2 +-
 ...sertExtractStridedSliceRewritePatterns.cpp |  8 +--
 .../Vector/Transforms/VectorLinearize.cpp     |  8 +--
 .../Transforms/VectorTransferOpTransforms.cpp |  2 +-
 .../Vector/Transforms/VectorTransforms.cpp    |  8 +--
 18 files changed, 87 insertions(+), 74 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 65ba7e0ad549f..f6b151f6ed644 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -675,7 +675,7 @@ def Vector_ExtractOp :
   }];
 
   let arguments = (ins
-    AnyVectorOfAnyRank:$vector,
+    AnyVectorOfAnyRank:$source,
     Variadic<Index>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
@@ -692,7 +692,7 @@ def Vector_ExtractOp :
 
   let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getVector().getType());
+      return ::llvm::cast<VectorType>(getSource().getType());
     }
 
     /// Return a vector with all the static and dynamic position indices.
@@ -709,12 +709,17 @@ def Vector_ExtractOp :
     bool hasDynamicPosition() {
       return !getDynamicPosition().empty();
     }
+
+    /// Wrapper for getSource, which replaced getVector.
+    [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+      return getSource();
+    }
   }];
 
   let assemblyFormat = [{
-    $vector ``
+    $source ``
     custom<DynamicIndexList>($dynamic_position, $static_position)
-    attr-dict `:` type($result) `from` type($vector)
+    attr-dict `:` type($result) `from` type($source)
   }];
 
   let hasCanonicalizer = 1;
@@ -1023,6 +1028,10 @@ def Vector_ScalableExtractOp :
     VectorType getResultVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
+    /// Wrapper for getSource, which replaced getVector.
+    [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+      return getSource();
+    }
   }];
 }
 
@@ -1174,7 +1183,7 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
+    Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
                I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
     Results<(outs AnyVectorOfNonZeroRank)> {
   let summary = "extract_strided_slice operation";
@@ -1209,7 +1218,7 @@ def Vector_ExtractStridedSliceOp :
   ];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getVector().getType());
+      return ::llvm::cast<VectorType>(getSource().getType());
     }
     void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
@@ -1217,11 +1226,15 @@ def Vector_ExtractStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
+    /// Wrapper for getSource, which replaced getVector.
+    [[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
+      return getSource();
+    }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
 }
 
 // TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 9efa34a9a3acc..4e1da39c29260 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
     auto loc = extractOp.getLoc();
     auto position = extractOp.getMixedPosition();
 
-    Value sourceVector = extractOp.getVector();
+    Value sourceVector = extractOp.getSource();
 
     // Extract entire vector. Should be handled by folder, but just to be safe.
     if (position.empty()) {
@@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
       return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
 
     auto createMaskOp =
-        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
     if (!createMaskOp)
       return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
 
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1d1904f717335..e1a22ec6d799b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
 
   // Find the vector.transer_read whose result vector is being sliced.
-  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+  auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
     return rewriter.notifyMatchFailure(op, "no transfer read");
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9fbac4925dc1d..e7266740894b1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1131,7 +1131,7 @@ class VectorExtractOpConversion
       positionVec.push_back(rewriter.getZeroAttr(idxType));
     }
 
-    Value extracted = adaptor.getVector();
+    Value extracted = adaptor.getSource();
     if (extractsAggregate) {
       ArrayRef<OpFoldResult> position(positionVec);
       if (extractsScalar) {
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 508f4e25326eb..c45c45e4712f3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
   /// Return the vector from which newly generated ExtracOps will extract.
   Value getDataVector(TransferWriteOp xferOp) const {
     if (auto extractOp = getExtractOp(xferOp))
-      return extractOp.getVector();
+      return extractOp.getSource();
     return xferOp.getVector();
   }
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c861935b4bc18..1c311d0312aaa 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
     if (!dstType)
       return failure();
 
-    if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
-      rewriter.replaceOp(extractOp, adaptor.getVector());
+    if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
+      rewriter.replaceOp(extractOp, adaptor.getSource());
       return success();
     }
 
@@ -201,7 +201,7 @@ struct VectorExtractOpConvert final
             extractOp,
             "Static use of poison index handled elsewhere (folded to poison)");
       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-          extractOp, dstType, adaptor.getVector(),
+          extractOp, dstType, adaptor.getSource(),
           rewriter.getI32ArrayAttr(id.value()));
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
@@ -209,7 +209,7 @@ struct VectorExtractOpConvert final
           vector::ExtractOp::kPoisonIndex,
           extractOp.getSourceVectorType().getNumElements());
       rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-          extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+          extractOp, dstType, adaptor.getSource(), sanitizedIndex);
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9bf026563c255..9196d2ef79592 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
       return rewriter.notifyMatchFailure(
           extractOp, "extracted type is not a 1-D scalable vector type");
 
-    auto *extendOp = extractOp.getVector().getDefiningOp();
+    auto *extendOp = extractOp.getSource().getDefiningOp();
     if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
             extendOp))
       return rewriter.notifyMatchFailure(extractOp,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 1c0eced43dc00..576c92b375ff3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
                                 PatternRewriter &rewriter) const override {
     auto loc = extractOp.getLoc();
     auto createMaskOp =
-        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
     if (!createMaskOp)
       return rewriter.notifyMatchFailure(
           extractOp, "extract not from vector.create_mask op");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 36434cf2d2ae2..e1dc40d6d37d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
         return WalkResult::advance();
 
       // Check that the vector to extract from is a BlockArgument.
-      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
       if (!blockArg)
         return WalkResult::advance();
 
@@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
           return WalkResult::advance();
 
       rewriter.modifyOpInPlace(broadcast, [&] {
-        extractOp.getVectorMutable().assign(initArg->get());
+        extractOp.getSourceMutable().assign(initArg->get());
       });
       loop.moveOutOfLoop(extractOp);
       rewriter.moveOpAfter(broadcast, loop);
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index b392ffeb13de6..050bbac2293e9 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
   if (auto extractOp =
           transferRead.getMask().getDefiningOp<vector::ExtractOp>())
     if (auto maskOp =
-            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+            extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
       return TransferMask{maskOp,
                           SmallVector<int64_t>(extractOp.getStaticPosition())};
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 85e485c28c74e..8d6e263934fb4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1309,7 +1309,7 @@ LogicalResult
 ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                             ExtractOp::Adaptor adaptor,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
+  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
   if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
       vectorType.getRank()) {
     inferredReturnTypes.push_back(vectorType.getElementType());
@@ -1379,7 +1379,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
-  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
+  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
     return failure();
 
   // TODO: Canonicalization for dynamic position not implemented yet.
@@ -1390,7 +1390,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   ExtractOp currentOp = extractOp;
   ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
   globalPosition.append(extrPos.rbegin(), extrPos.rend());
-  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
+  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
     // TODO: Canonicalization for dynamic position not implemented yet.
     if (currentOp.hasDynamicPosition())
@@ -1398,7 +1398,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
     ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
     globalPosition.append(extrPos.rbegin(), extrPos.rend());
   }
-  extractOp.setOperand(0, currentOp.getVector());
+  extractOp.setOperand(0, currentOp.getSource());
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   std::reverse(globalPosition.begin(), globalPosition.end());
@@ -1584,7 +1584,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
     return Value();
 
   // If we can't fold (either internal transposition, or nothing to fold), bail.
-  bool nothingToFold = (source == extractOp.getVector());
+  bool nothingToFold = (source == extractOp.getSource());
   if (nothingToFold || !canFold())
     return Value();
 
@@ -1592,7 +1592,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
   OpBuilder b(extractOp.getContext());
   extractOp.setStaticPosition(
       ArrayRef(extractPosition).take_front(extractedRank));
-  extractOp.getVectorMutable().assign(source);
+  extractOp.getSourceMutable().assign(source);
   return extractOp.getResult();
 }
 
@@ -1602,7 +1602,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
   if (extractOp.hasDynamicPosition())
     return Value();
 
-  Value valueToExtractFrom = extractOp.getVector();
+  Value valueToExtractFrom = extractOp.getSource();
   updateStateForNextIteration(valueToExtractFrom);
   while (nextInsertOp || nextTransposeOp) {
     // Case 1. If we hit a transpose, just compose the map and iterate.
@@ -1693,7 +1693,7 @@ static bool isBroadcastLike(Operation *op) {
 /// `extract` shape.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
 
-  Operation *defOp = extractOp.getVector().getDefiningOp();
+  Operation *defOp = extractOp.getSource().getDefiningOp();
   if (!defOp || !isBroadcastLike(defOp))
     return Value();
 
@@ -1762,7 +1762,7 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) {
   if (extractOp.hasDynamicPosition())
     return Value();
 
-  auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
+  auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
   if (!shuffleOp)
     return Value();
 
@@ -1793,7 +1793,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   if (extractOp.hasDynamicPosition())
     return Value();
 
-  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
   if (!shapeCastOp)
     return Value();
 
@@ -1859,7 +1859,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     return Value();
 
   auto extractStridedSliceOp =
-      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
+      extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
   if (!extractStridedSliceOp)
     return Value();
 
@@ -1896,7 +1896,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
   assert(extractedPos.size() >= sliceOffsets.size());
   for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
     extractedPos[i] = extractedPos[i] + sliceOffsets[i];
-  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
+  extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
 
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
@@ -1914,7 +1914,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
       llvm::isa<VectorType>(extractOp.getType())
           ? llvm::cast<VectorType>(extractOp.getType()).getRank()
           : 0;
-  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+  auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
   if (!insertOp)
     return Value();
 
@@ -1966,7 +1966,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                                                     insertRankDiff))
           return Value();
       }
-      extractOp.getVectorMutable().assign(insertOp.getValueToStore());
+      extractOp.getSourceMutable().assign(insertOp.getValueToStore());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(extractOp.getContext());
       extractOp.setStaticPosition(offsetDiffs);
@@ -1991,7 +1991,7 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
     return {};
 
   // Look for extract(from_elements).
-  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
   if (!fromElementsOp)
     return {};
 
@@ -2142,20 +2142,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
   // mismatch).
-  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
-    return getVector();
-  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
+    return getSource();
+  if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
     return res;
   // Fold `arith.constant` indices into the `vector.extract` operation.
   // Do not stop here as this fold may enable subsequent folds that require
   // constant indices.
-  SmallVector<Value> operands = {getVector()};
+  SmallVector<Value> operands = {getSource()};
   auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
 
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
-  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
     return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -2187,7 +2187,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
 
-    Operation *defOp = extractOp.getVector().getDefiningOp();
+    Operation *defOp = extractOp.getSource().getDefiningOp();
     VectorType outType = dyn_cast<VectorType>(extractOp.getType());
     if (!defOp || !isBroadcastLike(defOp) || !outType)
       return failure();
@@ -2210,7 +2210,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
     auto createMaskOp =
-        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
     if (!createMaskOp)
       return failure();
 
@@ -2271,7 +2271,7 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
                                                   PatternRewriter &rewriter) {
-  auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
+  auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
   if (!castOp)
     return failure();
 
@@ -2306,7 +2306,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
     return failure();
 
   // Look for extracts from a from_elements op.
-  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+  auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
   if (!fromElementsOp)
     return failure();
   VectorType inputType = fromElementsOp.getType();
@@ -2558,8 +2558,8 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
       // Check condition (i) by checking that all elements have the same source
       // as the first element.
       if (insertIndex == 0) {
-        source = extractOp.getVector();
-      } else if (extractOp.getVector() != source) {
+        source = extractOp.getSource();
+      } else if (extractOp.getSource() != source) {
         return rewriter.notifyMatchFailure(fromElements,
                                            "element from different vector");
       }
@@ -4095,7 +4095,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
   ArrayAttr extractOffsets = op.getOffsets();
   ArrayAttr extractStrides = op.getStrides();
   ArrayAttr extractSizes = op.getSizes();
-  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
+  auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
   while (insertOp) {
     if (op.getSourceVectorType().getRank() !=
         insertOp.getSourceVectorType().getRank())
@@ -4199,17 +4199,17 @@ foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
 
 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (getSourceVectorType() == getResult().getType())
-    return getVector();
+    return getSource();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
 
   // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
   if (auto splat =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
     DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
 
   // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
-  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
+  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
 }
 
 void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
@@ -4241,7 +4241,7 @@ class StridedSliceCreateMaskFolder final
     // Return if 'extractStridedSliceOp' operand is not defined by a
     // CreateMaskOp.
     auto createMaskOp =
-        extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
+        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
     if (!createMaskOp)
       return failure();
     // Return if 'extractStridedSliceOp' has non-unit strides.
@@ -4298,7 +4298,7 @@ class StridedSliceConstantMaskFolder final
                                 PatternRewriter &rewriter) const override {
     // Return if 'extractStridedSliceOp' operand is not defined by a
     // ConstantMaskOp.
-    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
     auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
     if (!constantMaskOp)
       return failure();
@@ -4351,7 +4351,7 @@ class StridedSliceBroadcast final
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
-    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
+    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
     if (!broadcast)
       return failure();
     auto srcVecType =
@@ -4403,7 +4403,7 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
 
-    Value splat = getScalarSplatSource(op.getVector());
+    Value splat = getScalarSplatSource(op.getSource());
     if (!splat)
       return failure();
     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 995a2595e5fbb..e95338f7d18be 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1341,7 +1341,7 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
         VectorType::get(newDistributedShape, distributedType.getElementType());
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
@@ -1395,7 +1395,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
       // the 1d case).
       SmallVector<size_t> newRetIndices;
       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, {extractOp.getVector()},
+          rewriter, warpOp, {extractOp.getSource()},
           {extractOp.getSourceVectorType()}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
       Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1424,7 +1424,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
         VectorType::get(newDistributedShape, distributedType.getElementType());
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1478,7 +1478,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
       distributedVecType = extractSrcType;
     }
     // Yield source vector and position (if present) from warp op.
-    SmallVector<Value> additionalResults{extractOp.getVector()};
+    SmallVector<Value> additionalResults{extractOp.getSource()};
     SmallVector<Type> additionalResultTypes{distributedVecType};
     additionalResults.append(
         SmallVector<Value>(extractOp.getDynamicPosition()));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 9889d7f221fe6..cab12894487e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -78,7 +78,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     Location loc = extractOp.getLoc();
 
     Value newSrcVector = vector::ExtractOp::create(
-        rewriter, loc, extractOp.getVector(), splatZero(dropCount));
+        rewriter, loc, extractOp.getSource(), splatZero(dropCount));
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f78e579d6c099..852eff2d2b909 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -94,7 +94,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
          !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
              maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
-      maskOp = extractOp.getVector().getDefiningOp();
+      maskOp = extractOp.getSource().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index cbb9d4bbf0b1f..f6d6555f4c6e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -213,8 +213,8 @@ class Convert1DExtractStridedSliceIntoShuffle
     for (int64_t off = offset, e = offset + size * stride; off < e;
          off += stride)
       offsets.push_back(off);
-    rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
-                                           op.getVector(), offsets);
+    rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getSource(),
+                                           op.getSource(), offsets);
     return success();
   }
 };
@@ -250,7 +250,7 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
     SmallVector<Value> elements;
     elements.reserve(size);
     for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
-      elements.push_back(ExtractOp::create(rewriter, loc, op.getVector(), i));
+      elements.push_back(ExtractOp::create(rewriter, loc, op.getSource(), i));
 
     Value result = arith::ConstantOp::create(
         rewriter, loc, rewriter.getZeroAttr(op.getType()));
@@ -306,7 +306,7 @@ class DecomposeNDExtractStridedSlice
     Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
          off += stride, ++idx) {
-      Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
+      Value one = ExtractOp::create(rewriter, loc, op.getSource(), off);
       Value extracted = ExtractStridedSliceOp::create(
           rewriter, loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
           getI64SubArray(op.getSizes(), /* dropFront=*/1),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 12acf4b3f07f5..82bac8c499028 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -252,7 +252,7 @@ struct LinearizeVectorExtractStridedSlice final
     SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
         outputShape, inputShape, offsets.value());
 
-    Value srcVector = adaptor.getVector();
+    Value srcVector = adaptor.getSource();
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
         extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
     return success();
@@ -438,8 +438,8 @@ struct LinearizeVectorExtract final
       return rewriter.notifyMatchFailure(extractOp,
                                          "dynamic position is not supported.");
 
-    llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
-    int64_t size = extractOp.getVector().getType().getNumElements();
+    llvm::ArrayRef<int64_t> shape = extractOp.getSource().getType().getShape();
+    int64_t size = extractOp.getSource().getType().getNumElements();
 
     // Compute linearized offset.
     int64_t linearizedOffset = 0;
@@ -449,7 +449,7 @@ struct LinearizeVectorExtract final
       linearizedOffset += offsets[i] * size;
     }
 
-    Value srcVector = adaptor.getVector();
+    Value srcVector = adaptor.getSource();
     if (!isa<VectorType>(extractOp.getType())) {
       // Scalar case: generate a 1-D extract.
       Value result = rewriter.createOrFold<vector::ExtractOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1874a3477a16c..c364a8b54167c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -1007,7 +1007,7 @@ class RewriteScalarExtractOfTransferRead
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
     // Match phase.
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
     // Check that we are extracting a scalar and not a sub-vector.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dbb5eb398fae6..866f789ec6a39 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -576,7 +576,7 @@ struct BubbleDownVectorBitCastForExtract
     if (extractOp.getSourceVectorType().getRank() != 1)
       return failure();
 
-    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
+    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
     if (!castOp)
       return failure();
 
@@ -647,7 +647,7 @@ struct BubbleDownBitCastForStridedSliceExtract
 
   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
+    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
     if (!castOp)
       return failure();
 
@@ -1135,7 +1135,7 @@ class ExtractOpFromElementwise final
 
   LogicalResult matchAndRewrite(vector::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
-    Operation *eltwise = op.getVector().getDefiningOp();
+    Operation *eltwise = op.getSource().getDefiningOp();
 
     // TODO: vector::FMAOp is not an ElemetwiseMappable even if it claims to be,
     // as it doesn't support scalars.
@@ -1210,7 +1210,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
 
   LogicalResult matchAndRewrite(vector::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
-    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
     if (!loadOp)
       return rewriter.notifyMatchFailure(op, "expected a load op");
 



More information about the Mlir-commits mailing list