[Mlir-commits] [mlir] [mlir][vector] Standardise `valueToStore` Naming Across Vector Ops (NFC) (PR #134206)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 2 23:51:51 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This change standardises the naming convention for the argument
representing the value to store in various vector operations.
Specifically, it ensures that all vector ops storing a value — whether
into memory, a tensor, or another vector — use `valueToStore` for the
corresponding argument name.

Updated operations:
* `vector.transfer_write`, `vector.insert`, `vector.scalable.insert`,
  `vector.insert_strided_slice`.

For reference, here are the operations that already use `valueToStore`:
* `vector.store`, `vector.scatter`, `vector.compressstore`,
  `vector.maskedstore`.

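This is a non-functional change (NFC): it does not affect the behaviour
of these operations. Only the operand names and their generated
accessors change (e.g. `getSource()`/`getVector()` become
`getValueToStore()`).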

Implements #131602


---

Patch is 36.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134206.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+22-21) 
- (modified) mlir/include/mlir/Interfaces/VectorInterfaces.td (+9-5) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+1-1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+4-3) 
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+1-1) 
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+6-5) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+44-20) 
- (modified) mlir/lib/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+11-9) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+4-4) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+5-5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 134472cefbf4e..7fc56b1aa4e7e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -907,7 +907,7 @@ def Vector_InsertOp :
   }];
 
   let arguments = (ins
-    AnyType:$source,
+    AnyType:$valueToStore,
     AnyVectorOfAnyRank:$dest,
     Variadic<Index>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
@@ -916,15 +916,15 @@ def Vector_InsertOp :
 
   let builders = [
     // Builder to insert a scalar/rank-0 vector into a rank-0 vector.
-    OpBuilder<(ins "Value":$source, "Value":$dest)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
-    OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
   ];
 
   let extraClassDeclaration = extraPoisonClassDeclaration # [{
-    Type getSourceType() { return getSource().getType(); }
+    Type getValueToStoreType() { return getValueToStore().getType(); }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
     }
@@ -946,8 +946,8 @@ def Vector_InsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
-    attr-dict `:` type($source) `into` type($dest)
+    $valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+    attr-dict `:` type($valueToStore) `into` type($dest)
   }];
 
   let hasCanonicalizer = 1;
@@ -957,13 +957,13 @@ def Vector_InsertOp :
 
 def Vector_ScalableInsertOp :
   Vector_Op<"scalable.insert", [Pure,
-       AllElementTypesMatch<["source", "dest"]>,
+       AllElementTypesMatch<["valueToStore", "dest"]>,
        AllTypesMatch<["dest", "res"]>,
        PredOpTrait<"position is a multiple of the source length.",
         CPred<
           "(getPos() % getSourceVectorType().getNumElements()) == 0"
         >>]>,
-     Arguments<(ins VectorOfRank<[1]>:$source,
+     Arguments<(ins VectorOfRank<[1]>:$valueToStore,
                     ScalableVectorOfRank<[1]>:$dest,
                     I64Attr:$pos)>,
      Results<(outs ScalableVectorOfRank<[1]>:$res)> {
@@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+    $valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
   }];
 
   let extraClassDeclaration = extraPoisonClassDeclaration # [{
     VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getSource().getType());
+      return ::llvm::cast<VectorType>(getValueToStore().getType());
     }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
+    Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
                I64ArrayAttr:$strides)>,
     Results<(outs AnyVectorOfNonZeroRank:$res)> {
   let summary = "strided_slice operation";
   let description = [{
-    Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
+    Takes a k-D valueToStore vector, an n-D destination vector (n >= k), n-sized
     `offsets` integer array attribute, a k-sized `strides` integer array attribute
-    and inserts the k-D source vector as a strided subvector at the proper offset
+    and inserts the k-D valueToStore vector as a strided subvector at the proper offset
     into the n-D destination vector.
 
     At the moment strides must contain only 1s.
 
     Returns an n-D vector that is a copy of the n-D destination vector in which
-    the last k-D dimensions contain the k-D source vector elements strided at
+    the last k-D dimensions contain the k-D valueToStore vector elements strided at
     the proper location as specified by the offsets.
 
     Example:
@@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "Value":$dest,
+    OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
       "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
   ];
   let extraClassDeclaration = [{
+    // TODO: Rename
     VectorType getSourceVectorType() {
-      return ::llvm::cast<VectorType>(getSource().getType());
+      return ::llvm::cast<VectorType>(getValueToStore().getType());
     }
     VectorType getDestVectorType() {
       return ::llvm::cast<VectorType>(getDest().getType());
@@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector,
+    Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
                    AnyShaped:$source,
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index be939bad14b7b..8ea9d925b3790 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodName=*/"getVector",
       /*args=*/(ins)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type of the vector that this operation operates on.
+      }],
+      /*retTy=*/"::mlir::VectorType",
+      /*methodName=*/"getVectorType",
+      /*args=*/(ins)
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the indices that specify the starting offsets into the source
@@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       /*methodName=*/"getIndices",
       /*args=*/(ins)
     >,
+
     InterfaceMethod<
       /*desc=*/[{
         Return the permutation map that describes the mapping of vector
@@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       return $_op.getPermutationMap().getNumResults();
     }
 
-    /// Return the type of the vector that this operation operates on.
-    ::mlir::VectorType getVectorType() {
-      return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
-    }
-
     /// Return "true" if at least one of the vector dimensions is a broadcasted
     /// dimension.
     bool hasBroadcastDim() {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4be0fffe8b728..58b85bc0ea6ac 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
     auto loc = insertOp.getLoc();
     auto position = insertOp.getMixedPosition();
 
-    Value source = insertOp.getSource();
+    Value source = insertOp.getValueToStore();
 
     // Overwrite entire vector with value. Should be handled by folder, but
     // just to be safe.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f7375b8d13..847e7e2beebe9 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
     // We are going to mutate this 1D vector until it is either the final
     // result (in the non-aggregate case) or the value that needs to be
     // inserted into the aggregate result.
-    Value sourceAggregate = adaptor.getSource();
+    Value sourceAggregate = adaptor.getValueToStore();
     if (insertIntoInnermostDim) {
       // Scalar-into-1D-vector case, so we know we will have to create a
       // InsertElementOp. The question is into what destination.
@@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
       }
       // Insert the scalar into the 1D vector.
       sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
-          loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
+          loc, sourceAggregate.getType(), sourceAggregate,
+          adaptor.getValueToStore(),
           getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
     }
 
@@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
-        insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
+        insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95db831185590..b9b598c02b4a2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
                                      buffers.dataBuffer);
     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
     rewriter.modifyOpInPlace(xferOp, [&]() {
-      xferOp.getVectorMutable().assign(loadedVec);
+      xferOp.getValueToStoreMutable().assign(loadedVec);
       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
     });
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index bca77ba68fbd1..de2af69eba9ec 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (isa<VectorType>(insertOp.getSourceType()))
+    if (isa<VectorType>(insertOp.getValueToStoreType()))
       return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
     if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
       return rewriter.notifyMatchFailure(insertOp,
                                          "unsupported dest vector type");
 
     // Special case for inserting scalar values into size-1 vectors.
-    if (insertOp.getSourceType().isIntOrFloat() &&
+    if (insertOp.getValueToStoreType().isIntOrFloat() &&
         insertOp.getDestVectorType().getNumElements() == 1) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
+      rewriter.replaceOp(insertOp, adaptor.getValueToStore());
       return success();
     }
 
@@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
             insertOp,
             "Static use of poison index handled elsewhere (folded to poison)");
       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+          insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
     } else {
       Value sanitizedIndex = sanitizeDynamicIndex(
           rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
           vector::InsertOp::kPoisonIndex,
           insertOp.getDestVectorType().getNumElements());
       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-          insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+          insertOp, insertOp.getDest(), adaptor.getValueToStore(),
+          sanitizedIndex);
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index dec3dca988ae9..62a148d2b7e62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition
 
     auto loc = writeOp.getLoc();
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
-    auto inputSMETiles = adaptor.getVector();
+    auto inputSMETiles = adaptor.getValueToStore();
 
     Value destTensorOrMemref = writeOp.getSource();
     for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
@@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
     rewriter.setInsertionPointToStart(storeLoop.getBody());
 
     // For each sub-tile of the multi-tile `vectorType`.
-    auto inputSMETiles = adaptor.getVector();
+    auto inputSMETiles = adaptor.getValueToStore();
     auto tileSliceIndex = storeLoop.getInductionVar();
     for (auto [index, smeTile] : llvm::enumerate(
              decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index acfd9683f01f4..20e4e3cee7ed4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       if (failed(maybeNewLoop))
         return WalkResult::interrupt();
 
-      transferWrite.getVectorMutable().assign(
+      transferWrite.getValueToStoreMutable().assign(
           maybeNewLoop->getOperation()->getResults().back());
       changed = true;
       // Need to interrupt and restart because erasing the loop messes up
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8c8b1b85ef5a3..5afe378463d13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
   rewriter.create<vector::TransferWriteOp>(
       xferOp.getLoc(), vector, out, xferOp.getIndices(),
       xferOp.getPermutationMapAttr(), xferOp.getMask(),
-      rewriter.getBoolArrayAttr(
-          SmallVector<bool>(vector.getType().getRank(), false)));
+      rewriter.getBoolArrayAttr(SmallVector<bool>(
+          dyn_cast<VectorType>(vector.getType()).getRank(), false)));
 
   rewriter.eraseOp(copyOp);
   rewriter.eraseOp(xferOp);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..98d98f067de14 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1555,7 +1555,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
   if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
     return failure();
   // Case 2.a. early-exit fold.
-  res = nextInsertOp.getSource();
+  res = nextInsertOp.getValueToStore();
   // Case 2.b. if internal transposition is present, canFold will be false.
   return success(canFold());
 }
@@ -1579,7 +1579,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
                         extractPosition.begin() + insertedPos.size());
   extractedRank = extractPosition.size() - sentinels.size();
   // Case 3.a. early-exit fold (break and delegate to post-while path).
-  res = nextInsertOp.getSource();
+  res = nextInsertOp.getValueToStore();
   // Case 3.b. if internal transposition is present, canFold will be false.
   return success();
 }
@@ -1936,7 +1936,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                                                     insertRankDiff))
           return Value();
       }
-      extractOp.getVectorMutable().assign(insertOp.getSource());
+      extractOp.getVectorMutable().assign(insertOp.getValueToStore());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(extractOp.getContext());
       extractOp.setStaticPosition(offsetDiffs);
@@ -2958,7 +2958,7 @@ LogicalResult InsertOp::verify() {
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
         "expected position attribute of rank no greater than dest vector rank");
-  auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
+  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
   if (srcVectorType &&
       (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
        static_cast<unsigned>(destVectorType.getRank())))
@@ -2994,12 +2994,13 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
 
   LogicalResult matchAndRewrite(InsertOp insertOp,
                                 PatternRewriter &rewriter) const override {
-    auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
+    auto srcVecType =
+        llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
     if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
                            srcVecType.getNumElements())
       return failure();
     rewriter.replaceOpWithNewOp<BroadcastOp>(
-        insertOp, insertOp.getDestVectorType(), insertOp.getSource());
+        insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
     return success();
   }
 };
@@ -3011,7 +3012,7 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
 
   LogicalResult matchAndRewrite(InsertOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+    auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
     auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
 
     if (!srcSplat || !dstSplat)
@@ -3100,17 +3101,17 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
-  if (getNumIndices() == 0 && getSourceType() == getType())
-    return getSource();
-  SmallVector<Value> operands = {getSource(), getDest()};
+  if (getNumIndices() == 0 && getValueToStoreType() == getType())
+    return getValueToStore();
+  SmallVector<Value> operands = {getValueToStore(), getDest()};
   if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
     return val;
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
-  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
-                                                   adaptor.getDest(),
-                                                   vectorSizeFoldThreshold)) {
+  if (auto res = foldDenseElementsAttrDestInsertOp(
+          *this, adaptor.getValueToStore(), adaptor.getDest(),
+          vectorSizeFoldThreshold)) {
     return res;
   }
 
@@ -3291,7 +3292,7 @@ class FoldInsertStridedSliceSplat final
   LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
                                 PatternRewriter &rewriter) const override {
     auto srcSplatOp =
-        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+        insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
     auto destSplatOp =
         insertStridedSliceOp.getDes...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/134206


More information about the Mlir-commits mailing list