[Mlir-commits] [mlir] [mlir][vector] Standardise `valueToStore` Naming Across Vector Ops (NFC) (PR #134206)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Apr 3 04:51:25 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/134206

>From 8fe3975950d58656c0fe7bd20ab051bf8d08cd0c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 1 Apr 2025 17:49:09 +0100
Subject: [PATCH] [mlir][vector] Standardise `valueToStore` Naming Across
 Vector Ops (NFC)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This change standardises the naming convention for the argument
representing the value to store in various vector operations.
Specifically, it ensures that all vector ops that store a value — whether
into memory, a tensor, or another vector — use `valueToStore` as the name
of the corresponding operand.

Updated operations:
* `vector.transfer_write`, `vector.insert`, `vector.scalable.insert`,
  `vector.insert_strided_slice`.

For reference, here are operations that currently use `valueToStore`:
* `vector.store`, `vector.scatter`, `vector.compressstore`,
  `vector.maskedstore`.

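This change is NFC: the semantics of these operations are unchanged.
Note, however, that the generated C++ accessors are renamed accordingly
(e.g. `InsertOp::getSource()` becomes `InsertOp::getValueToStore()` and
`TransferWriteOp::getVectorMutable()` becomes
`TransferWriteOp::getValueToStoreMutable()`), so downstream code using the
old accessors needs to be updated. Because `vector.transfer_write` no
longer has an operand named `vector`, `getVectorType()` is now a
`VectorTransferOpInterface` method implemented by each transfer op.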

Implements #131602
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 43 +++++++------
 .../mlir/Interfaces/VectorInterfaces.td       | 14 ++--
 .../VectorToArmSME/VectorToArmSME.cpp         |  2 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  7 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  2 +-
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 11 ++--
 .../ArmSME/Transforms/VectorLegalization.cpp  |  4 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  2 +-
 .../Linalg/Transforms/Vectorization.cpp       |  4 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 64 +++++++++++++------
 .../Transforms/SubsetOpInterfaceImpl.cpp      |  2 +-
 .../Vector/Transforms/VectorDistribute.cpp    | 20 +++---
 .../Transforms/VectorDropLeadUnitDim.cpp      |  8 +--
 ...sertExtractStridedSliceRewritePatterns.cpp | 10 +--
 .../Vector/Transforms/VectorLinearize.cpp     |  6 +-
 .../Vector/Transforms/VectorTransforms.cpp    |  6 +-
 16 files changed, 119 insertions(+), 86 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..7fc56b1aa4e7e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -907,7 +907,7 @@ def Vector_InsertOp :
   }];
 
   let arguments = (ins
-    AnyType:$source,
+    AnyType:$valueToStore,
     AnyVectorOfAnyRank:$dest,
     Variadic<Index>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
@@ -916,15 +916,15 @@ def Vector_InsertOp :
 
   let builders = [
     // Builder to insert a scalar/rank-0 vector into a rank-0 vector.
-    OpBuilder<(ins "Value":$source, "Value":$dest)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
   let extraClassDeclaration = extraPoisonClassDeclaration # [{
-    Type getSourceType() { return getSource().getType(); }
+    Type getValueToStoreType() { return getValueToStore().getType(); }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
     }
@@ -946,8 +946,8 @@ def Vector_InsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
-    attr-dict `:` type($source) `into` type($dest)
+    $valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+    attr-dict `:` type($valueToStore) `into` type($dest)
   }];
 
   let hasCanonicalizer = 1;
@@ -957,13 +957,13 @@ def Vector_InsertOp :
 
 def Vector_ScalableInsertOp :
   Vector_Op<"scalable.insert", [Pure,
-       AllElementTypesMatch<["source", "dest"]>,
+       AllElementTypesMatch<["valueToStore", "dest"]>,
        AllTypesMatch<["dest", "res"]>,
        PredOpTrait<"position is a multiple of the source length.",
         CPred<
           "(getPos() % getSourceVectorType().getNumElements()) == 0"
         >>]>,
-     Arguments<(ins VectorOfRank<[1]>:$source,
+     Arguments<(ins VectorOfRank<[1]>:$valueToStore,
                     ScalableVectorOfRank<[1]>:$dest,
                     I64Attr:$pos)>,
      Results<(outs ScalableVectorOfRank<[1]>:$res)> {
@@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+    $valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
   }];
 
   let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getSource().getType());
+      return ::llvm::cast<VectorType>(getValueToStore().getType());
     }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
+    Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
                I64ArrayAttr:$strides)>,
     Results<(outs AnyVectorOfNonZeroRank:$res)> {
   let summary = "strided_slice operation";
   let description = [{
-    Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
+    Takes a k-D `valueToStore` vector, an n-D destination vector (n >= k), n-sized
     `offsets` integer array attribute, a k-sized `strides` integer array attribute
-    and inserts the k-D source vector as a strided subvector at the proper offset
+    and inserts the k-D `valueToStore` vector as a strided subvector at the proper offset
     into the n-D destination vector.

     At the moment strides must contain only 1s.

     Returns an n-D vector that is a copy of the n-D destination vector in which
-    the last k-D dimensions contain the k-D source vector elements strided at
+    the last k-D dimensions contain the k-D `valueToStore` vector elements strided at
     the proper location as specified by the offsets.
 
     Example:
@@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "Value":$dest,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
       "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
   ];
   let extraClassDeclaration = [{
+    // TODO: Rename this accessor to match the renamed `valueToStore` operand.
     VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getSource().getType());
+      return ::llvm::cast<VectorType>(getValueToStore().getType());
     }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector,
+    Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
                    AnyShaped:$source,
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index be939bad14b7b..8ea9d925b3790 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodName=*/"getVector",
       /*args=*/(ins)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type of the vector that this operation operates on.
+      }],
+      /*retTy=*/"::mlir::VectorType",
+      /*methodName=*/"getVectorType",
+      /*args=*/(ins)
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the indices that specify the starting offsets into the source
@@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodName=*/"getIndices",
       /*args=*/(ins)
     >,
+
     InterfaceMethod<
       /*desc=*/[{
         Return the permutation map that describes the mapping of vector
@@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       return $_op.getPermutationMap().getNumResults();
     }
 
-    /// Return the type of the vector that this operation operates on.
-    ::mlir::VectorType getVectorType() {
-      return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
-    }
-
     /// Return "true" if at least one of the vector dimensions is a broadcasted
     /// dimension.
     bool hasBroadcastDim() {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4be0fffe8b728..58b85bc0ea6ac 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
     auto loc = insertOp.getLoc();
     auto position = insertOp.getMixedPosition();
 
-    Value source = insertOp.getSource();
+    Value source = insertOp.getValueToStore();
 
     // Overwrite entire vector with value. Should be handled by folder, but
     // just to be safe.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f7375b8d13..847e7e2beebe9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
     // We are going to mutate this 1D vector until it is either the final
     // result (in the non-aggregate case) or the value that needs to be
     // inserted into the aggregate result.
-    Value sourceAggregate = adaptor.getSource();
+    Value sourceAggregate = adaptor.getValueToStore();
     if (insertIntoInnermostDim) {
       // Scalar-into-1D-vector case, so we know we will have to create a
       // InsertElementOp. The question is into what destination.
@@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
       }
       // Insert the scalar into the 1D vector.
       sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
-          loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
+          loc, sourceAggregate.getType(), sourceAggregate,
+          adaptor.getValueToStore(),
           getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
     }
 
@@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
-        insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
+        insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95db831185590..b9b598c02b4a2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
                                      buffers.dataBuffer);
     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
     rewriter.modifyOpInPlace(xferOp, [&]() {
-      xferOp.getVectorMutable().assign(loadedVec);
+      xferOp.getValueToStoreMutable().assign(loadedVec);
       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
     });
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index bca77ba68fbd1..de2af69eba9ec 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (isa<VectorType>(insertOp.getSourceType()))
+    if (isa<VectorType>(insertOp.getValueToStoreType()))
       return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
     if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
       return rewriter.notifyMatchFailure(insertOp,
                                          "unsupported dest vector type");
 
     // Special case for inserting scalar values into size-1 vectors.
-    if (insertOp.getSourceType().isIntOrFloat() &&
+    if (insertOp.getValueToStoreType().isIntOrFloat() &&
         insertOp.getDestVectorType().getNumElements() == 1) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
+      rewriter.replaceOp(insertOp, adaptor.getValueToStore());
       return success();
     }
 
@@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
             insertOp,
             "Static use of poison index handled elsewhere (folded to poison)");
       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+          insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
           rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
           vector::InsertOp::kPoisonIndex,
           insertOp.getDestVectorType().getNumElements());
       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-          insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+          insertOp, insertOp.getDest(), adaptor.getValueToStore(),
+          sanitizedIndex);
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index dec3dca988ae9..62a148d2b7e62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition
 
     auto loc = writeOp.getLoc();
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
-    auto inputSMETiles = adaptor.getVector();
+    auto inputSMETiles = adaptor.getValueToStore();
 
     Value destTensorOrMemref = writeOp.getSource();
     for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
@@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
     rewriter.setInsertionPointToStart(storeLoop.getBody());
 
     // For each sub-tile of the multi-tile `vectorType`.
-    auto inputSMETiles = adaptor.getVector();
+    auto inputSMETiles = adaptor.getValueToStore();
     auto tileSliceIndex = storeLoop.getInductionVar();
     for (auto [index, smeTile] : llvm::enumerate(
              decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index acfd9683f01f4..20e4e3cee7ed4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       if (failed(maybeNewLoop))
         return WalkResult::interrupt();
 
-      transferWrite.getVectorMutable().assign(
+      transferWrite.getValueToStoreMutable().assign(
           maybeNewLoop->getOperation()->getResults().back());
       changed = true;
       // Need to interrupt and restart because erasing the loop messes up
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8c8b1b85ef5a3..5afe378463d13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
   rewriter.create<vector::TransferWriteOp>(
       xferOp.getLoc(), vector, out, xferOp.getIndices(),
       xferOp.getPermutationMapAttr(), xferOp.getMask(),
-      rewriter.getBoolArrayAttr(
-          SmallVector<bool>(vector.getType().getRank(), false)));
+      rewriter.getBoolArrayAttr(SmallVector<bool>(
+          cast<VectorType>(vector.getType()).getRank(), false)));

   rewriter.eraseOp(copyOp);
   rewriter.eraseOp(xferOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..98d98f067de14 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1555,7 +1555,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
   if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
     return failure();
   // Case 2.a. early-exit fold.
-  res = nextInsertOp.getSource();
+  res = nextInsertOp.getValueToStore();
   // Case 2.b. if internal transposition is present, canFold will be false.
   return success(canFold());
 }
@@ -1579,7 +1579,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
                         extractPosition.begin() + insertedPos.size());
   extractedRank = extractPosition.size() - sentinels.size();
   // Case 3.a. early-exit fold (break and delegate to post-while path).
-  res = nextInsertOp.getSource();
+  res = nextInsertOp.getValueToStore();
   // Case 3.b. if internal transposition is present, canFold will be false.
   return success();
 }
@@ -1936,7 +1936,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                                                     insertRankDiff))
           return Value();
       }
-      extractOp.getVectorMutable().assign(insertOp.getSource());
+      extractOp.getVectorMutable().assign(insertOp.getValueToStore());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(extractOp.getContext());
       extractOp.setStaticPosition(offsetDiffs);
@@ -2958,7 +2958,7 @@ LogicalResult InsertOp::verify() {
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
         "expected position attribute of rank no greater than dest vector rank");
-  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
+  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
   if (srcVectorType &&
       (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
        static_cast<unsigned>(destVectorType.getRank())))
@@ -2994,12 +2994,13 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
 
   LogicalResult matchAndRewrite(InsertOp insertOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
+    auto srcVecType =
+        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
     if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                            srcVecType.getNumElements())
       return failure();
     rewriter.replaceOpWithNewOp<BroadcastOp>(
-        insertOp, insertOp.getDestVectorType(), insertOp.getSource());
+        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
     return success();
   }
 };
@@ -3011,7 +3012,7 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
 
   LogicalResult matchAndRewrite(InsertOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+    auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
     auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
 
     if (!srcSplat || !dstSplat)
@@ -3100,17 +3101,17 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
-  if (getNumIndices() == 0 && getSourceType() == getType())
-    return getSource();
-  SmallVector<Value> operands = {getSource(), getDest()};
+  if (getNumIndices() == 0 && getValueToStoreType() == getType())
+    return getValueToStore();
+  SmallVector<Value> operands = {getValueToStore(), getDest()};
   if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
     return val;
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
-  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
-                                                   adaptor.getDest(),
-                                                   vectorSizeFoldThreshold)) {
+  if (auto res = foldDenseElementsAttrDestInsertOp(
+          *this, adaptor.getValueToStore(), adaptor.getDest(),
+          vectorSizeFoldThreshold)) {
     return res;
   }
 
@@ -3291,7 +3292,7 @@ class FoldInsertStridedSliceSplat final
   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
     auto srcSplatOp =
-        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+        insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
     auto destSplatOp =
         insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
 
@@ -3316,7 +3317,7 @@ class FoldInsertStridedSliceOfExtract final
   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
     auto extractStridedSliceOp =
-        insertStridedSliceOp.getSource()
+        insertStridedSliceOp.getValueToStore()
             .getDefiningOp<vector::ExtractStridedSliceOp>();
 
     if (!extractStridedSliceOp)
@@ -3365,7 +3366,7 @@ class InsertStridedSliceConstantFolder final
         !destVector.hasOneUse())
       return failure();
 
-    TypedValue<VectorType> sourceValue = op.getSource();
+    TypedValue<VectorType> sourceValue = op.getValueToStore();
     Attribute sourceCst;
     if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
       return failure();
@@ -3425,7 +3426,7 @@ void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
 
 OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (getSourceVectorType() == getDestVectorType())
-    return getSource();
+    return getValueToStore();
   return {};
 }
 
@@ -3691,7 +3692,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
     }
     // The extract element chunk is a subset of the insert element.
     if (!disjoint && !patialoverlap) {
-      op.setOperand(insertOp.getSource());
+      op.setOperand(insertOp.getValueToStore());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(op.getContext());
       op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
@@ -4349,6 +4350,13 @@ Type TransferReadOp::getExpectedMaskType() {
   return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
+//===----------------------------------------------------------------------===//
+// TransferReadOp: VectorTransferOpInterface methods.
+//===----------------------------------------------------------------------===//
+VectorType TransferReadOp::getVectorType() {
+  return cast<VectorType>(getVector().getType());
+}
+
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -4739,7 +4747,9 @@ LogicalResult TransferWriteOp::verify() {
                               [&](Twine t) { return emitOpError(t); });
 }
 
-// MaskableOpInterface methods.
+//===----------------------------------------------------------------------===//
+// TransferWriteOp: MaskableOpInterface methods.
+//===----------------------------------------------------------------------===//
 
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes.
@@ -4747,6 +4757,18 @@ Type TransferWriteOp::getExpectedMaskType() {
   return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }

+//===----------------------------------------------------------------------===//
+// TransferWriteOp: VectorTransferOpInterface methods.
+//===----------------------------------------------------------------------===//
+/// For vector.transfer_write, the interface's "vector" is the value to store.
+Value TransferWriteOp::getVector() { return getValueToStore(); }
+VectorType TransferWriteOp::getVectorType() {
+  return cast<VectorType>(getValueToStore().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// TransferWriteOp: fold methods.
+//===----------------------------------------------------------------------===//
 /// Fold:
 /// ```
 ///    %t1 = ...
@@ -4863,6 +4885,9 @@ LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }

+//===----------------------------------------------------------------------===//
+// TransferWriteOp: other methods.
+//===----------------------------------------------------------------------===//
 std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
index b450d5b78a466..e8e178fe75962 100644
--- a/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp
@@ -45,7 +45,7 @@ struct TransferWriteOpSubsetInsertionOpInterface
     : public SubsetInsertionOpInterface::ExternalModel<
           TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return cast<vector::TransferWriteOp>(op).getVectorMutable();
+    return cast<vector::TransferWriteOp>(op).getValueToStoreMutable();
   }
 
   OpOperand &getDestinationOperand(Operation *op) const {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..19f408ad1b570 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -496,7 +496,8 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
     rewriter.setInsertionPointToStart(&body);
     auto newWriteOp =
         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
-    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+    newWriteOp.getValueToStoreMutable().assign(
+        newWarpOp.getResult(newRetIndices[0]));
     rewriter.eraseOp(writeOp);
     rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
     return success();
@@ -559,7 +560,8 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
     auto newWriteOp =
         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
     rewriter.eraseOp(writeOp);
-    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+    newWriteOp.getValueToStoreMutable().assign(
+        newWarpOp.getResult(newRetIndices[0]));
     if (maybeMaskType)
       newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
     return newWriteOp;
@@ -1299,9 +1301,9 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
 
     // Yield destination vector, source scalar and position from warp op.
     SmallVector<Value> additionalResults{insertOp.getDest(),
-                                         insertOp.getSource()};
-    SmallVector<Type> additionalResultTypes{distrType,
-                                            insertOp.getSource().getType()};
+                                         insertOp.getValueToStore()};
+    SmallVector<Type> additionalResultTypes{
+        distrType, insertOp.getValueToStore().getType()};
     additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
     additionalResultTypes.append(
         SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
@@ -1393,8 +1395,8 @@ struct WarpOpInsert : public WarpDistributionPattern {
       // out of the warp op.
       SmallVector<size_t> newRetIndices;
       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
-          {insertOp.getSourceType(), insertOp.getDestVectorType()},
+          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
           newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
       Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
@@ -1422,7 +1424,7 @@ struct WarpOpInsert : public WarpDistributionPattern {
     assert(distrDestDim != -1 && "could not find distributed dimension");
 
     // Compute the distributed source vector type.
-    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
+    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
     SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
     // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
     // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
@@ -1439,7 +1441,7 @@ struct WarpOpInsert : public WarpDistributionPattern {
     // Yield source and dest vectors from warp op.
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
         {distrSrcType, distrDestType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index b53aa997c9014..fda3baf3aa390 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -122,7 +122,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     Location loc = insertOp.getLoc();
 
     Value newSrcVector = rewriter.create<vector::ExtractOp>(
-        loc, insertOp.getSource(), splatZero(srcDropCount));
+        loc, insertOp.getValueToStore(), splatZero(srcDropCount));
     Value newDstVector = rewriter.create<vector::ExtractOp>(
         loc, insertOp.getDest(), splatZero(dstDropCount));
 
@@ -148,7 +148,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
 
   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                 PatternRewriter &rewriter) const override {
-    Type oldSrcType = insertOp.getSourceType();
+    Type oldSrcType = insertOp.getValueToStoreType();
     Type newSrcType = oldSrcType;
     int64_t oldSrcRank = 0, newSrcRank = 0;
     if (auto type = dyn_cast<VectorType>(oldSrcType)) {
@@ -168,10 +168,10 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     // Trim leading one dimensions from both operands.
     Location loc = insertOp.getLoc();
 
-    Value newSrcVector = insertOp.getSource();
+    Value newSrcVector = insertOp.getValueToStore();
     if (oldSrcRank != 0) {
       newSrcVector = rewriter.create<vector::ExtractOp>(
-          loc, insertOp.getSource(), splatZero(srcDropCount));
+          loc, insertOp.getValueToStore(), splatZero(srcDropCount));
     }
     Value newDstVector = rewriter.create<vector::ExtractOp>(
         loc, insertOp.getDest(), splatZero(dstDropCount));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 82a985c9e5824..d834a99076834 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -61,7 +61,7 @@ class DecomposeDifferentRankInsertStridedSlice
     // A different pattern will kick in for InsertStridedSlice with matching
     // ranks.
     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
-        loc, op.getSource(), extracted,
+        loc, op.getValueToStore(), extracted,
         getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
         getI64SubArray(op.getStrides(), /*dropFront=*/0));
 
@@ -111,7 +111,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
       return failure();
 
     if (srcType == dstType) {
-      rewriter.replaceOp(op, op.getSource());
+      rewriter.replaceOp(op, op.getValueToStore());
       return success();
     }
 
@@ -131,8 +131,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
       SmallVector<int64_t> offsets(nDest, 0);
       for (int64_t i = 0; i < nSrc; ++i)
         offsets[i] = i;
-      Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
-                                                      op.getSource(), offsets);
+      Value scaledSource = rewriter.create<ShuffleOp>(
+          loc, op.getValueToStore(), op.getValueToStore(), offsets);
 
       // 2. Create a mask where we take the value from scaledSource of dest
       // depending on the offset.
@@ -156,7 +156,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
          off += stride, ++idx) {
       // 1. extract the proper subvector (or element) from source
       Value extractedSource =
-          rewriter.create<ExtractOp>(loc, op.getSource(), idx);
+          rewriter.create<ExtractOp>(loc, op.getValueToStore(), idx);
       if (isa<VectorType>(extractedSource.getType())) {
         // 2. If we have a vector, extract the proper subvector from destination
         // Otherwise we are at the element level and no need to recurse.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 9dccc005322eb..a009aa03aaf64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -439,7 +439,7 @@ struct LinearizeVectorInsert final
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalable vectors are not supported.");
 
-    if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
+    if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
                                          targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           insertOp, "Can't flatten since targetBitWidth < OpSize");
@@ -448,7 +448,7 @@ struct LinearizeVectorInsert final
     if (insertOp.hasDynamicPosition())
       return rewriter.notifyMatchFailure(insertOp,
                                          "dynamic position is not supported.");
-    auto srcTy = insertOp.getSourceType();
+    auto srcTy = insertOp.getValueToStoreType();
     auto srcAsVec = dyn_cast<VectorType>(srcTy);
     uint64_t srcSize = 0;
     if (srcAsVec) {
@@ -484,7 +484,7 @@ struct LinearizeVectorInsert final
                                            // [offset+srcNumElements, end)
 
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
+        insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
 
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b6fac80d871e6..d50d5fe96f49a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -748,7 +748,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
       return failure();
 
     // Only vector sources are supported for now.
-    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
     if (!insertSrcType)
       return failure();
 
@@ -759,7 +759,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
     VectorType newCastSrcType =
         VectorType::get(srcDims, castDstType.getElementType());
     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
-        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
 
     SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
     dstDims.back() =
@@ -850,7 +850,7 @@ struct BubbleUpBitCastForStridedSliceInsert
         VectorType::get(srcDims, castDstType.getElementType());
 
     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
-        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());
 
     SmallVector<int64_t> dstDims =
         llvm::to_vector<4>(insertOp.getDestVectorType().getShape());



More information about the Mlir-commits mailing list