[Mlir-commits] [mlir] [mlir][Vector] Remove usage of `vector.insertelement/extractelement` from Vector (PR #144413)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 16 12:19:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

It removes instances of `vector.extractelement` and `vector.insertelement` from the Vector dialect layer.

---

The patch is 29.77 KiB and has been truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/144413.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+2-2) 
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+5-8) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+5-46) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+25-94) 
- (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+7-7) 
- (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+3-3) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+24-24) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..ec0f856cb3f5a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -233,8 +233,8 @@ void populateBreakDownVectorReductionPatterns(
 ///
 /// [DecomposeNDExtractStridedSlice]
 /// ================================
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..45059f19a95c4 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -760,8 +760,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
 
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
-      // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // non-constant value.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -824,8 +823,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1567,7 +1565,7 @@ struct Strategy1d<TransferReadOp> {
         /*inBoundsCase=*/
         [&](OpBuilder &b, Location loc) {
           Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1595,8 +1593,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
         });
     b.create<scf::YieldOp>(loc);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..90970ae53defc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1255,27 +1255,6 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
-/// Pattern to convert vector.extractelement to vector.extract.
-struct WarpOpExtractElement : public WarpDistributionPattern {
-  using Base::Base;
-  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
-    if (!operand)
-      return failure();
-    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    SmallVector<OpFoldResult> indices;
-    if (auto pos = extractOp.getPosition()) {
-      indices.push_back(pos);
-    }
-    rewriter.setInsertionPoint(extractOp);
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-        extractOp, extractOp.getVector(), indices);
-    return success();
-  }
-};
-
 /// Pattern to move out vector.insert with a scalar input.
 /// Only supports 1-D and 0-D destinations for now.
 struct WarpOpInsertScalar : public WarpDistributionPattern {
@@ -1483,26 +1462,6 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
-struct WarpOpInsertElement : public WarpDistributionPattern {
-  using Base::Base;
-  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
-    if (!operand)
-      return failure();
-    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
-    SmallVector<OpFoldResult> indices;
-    if (auto pos = insertOp.getPosition()) {
-      indices.push_back(pos);
-    }
-    rewriter.setInsertionPoint(insertOp);
-    rewriter.replaceOpWithNewOp<vector::InsertOp>(
-        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
-    return success();
-  }
-};
-
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't
 /// change the order of execution. This creates a new scf.for region after the
@@ -1761,11 +1720,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 384717aeca665..62e7f7cc61f6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -767,23 +767,26 @@ class FlattenContiguousRowMajorTransferWritePattern
   unsigned targetVectorBitwidth;
 };
 
-/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
-/// to `memref.load` patterns. The `match` method is shared for both
-/// `vector.extract` and `vector.extract_element`.
-template <class VectorExtractOp>
-class RewriteScalarExtractOfTransferReadBase
-    : public OpRewritePattern<VectorExtractOp> {
-  using Base = OpRewritePattern<VectorExtractOp>;
-
+/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
+///
+/// All the users of the transfer op must be `vector.extract` ops. If
+/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
+/// users. Otherwise, rewrite only if the extract op is the single user of the
+/// transfer op. Rewriting a single vector load with multiple scalar loads may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractOp> {
 public:
-  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
-                                         PatternBenefit benefit,
-                                         bool allowMultipleUses)
-      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
-
-  LogicalResult match(VectorExtractOp extractOp) const {
-    auto xferOp =
-        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
+  RewriteScalarExtractOfTransferRead(MLIRContext *context,
+                                     PatternBenefit benefit,
+                                     bool allowMultipleUses)
+      : OpRewritePattern(context, benefit),
+        allowMultipleUses(allowMultipleUses) {}
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Match phase.
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
     // Check that we are extracting a scalar and not a sub-vector.
@@ -795,8 +798,7 @@ class RewriteScalarExtractOfTransferReadBase
     // If multiple uses are allowed, check if all the xfer uses are extract ops.
     if (allowMultipleUses &&
         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
-          return isa<vector::ExtractOp, vector::ExtractElementOp>(
-              use.getOwner());
+          return isa<vector::ExtractOp>(use.getOwner());
         }))
       return failure();
     // Mask not supported.
@@ -808,81 +810,8 @@ class RewriteScalarExtractOfTransferReadBase
     // Cannot rewrite if the indices may be out of bounds.
     if (xferOp.hasOutOfBoundsDim())
       return failure();
-    return success();
-  }
-
-private:
-  bool allowMultipleUses;
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractElementOfTransferRead
-    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
-  using RewriteScalarExtractOfTransferReadBase::
-      RewriteScalarExtractOfTransferReadBase;
-
-  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(match(extractOp)))
-      return failure();
-
-    // Construct scalar load.
-    auto loc = extractOp.getLoc();
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
-    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
-                                  xferOp.getIndices().end());
-    if (extractOp.getPosition()) {
-      AffineExpr sym0, sym1;
-      bindSymbols(extractOp.getContext(), sym0, sym1);
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, sym0 + sym1,
-          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
-      if (auto value = dyn_cast<Value>(ofr)) {
-        newIndices[newIndices.size() - 1] = value;
-      } else {
-        newIndices[newIndices.size() - 1] =
-            rewriter.create<arith::ConstantIndexOp>(loc,
-                                                    *getConstantIntValue(ofr));
-      }
-    }
-    if (isa<MemRefType>(xferOp.getBase().getType())) {
-      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
-                                                  newIndices);
-    } else {
-      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-          extractOp, xferOp.getBase(), newIndices);
-    }
-
-    return success();
-  }
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractOfTransferRead
-    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
-  using RewriteScalarExtractOfTransferReadBase::
-      RewriteScalarExtractOfTransferReadBase;
-
-  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(match(extractOp)))
-      return failure();
 
-    // Construct scalar load.
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    // Rewrite phase: construct scalar load.
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
@@ -923,6 +852,9 @@ class RewriteScalarExtractOfTransferRead
 
     return success();
   }
+
+private:
+  bool allowMultipleUses;
 };
 
 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -979,8 +911,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     bool allowMultipleUses) {
-  patterns.add<RewriteScalarExtractElementOfTransferRead,
-               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                    benefit, allowMultipleUses);
   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 5a6da3a06387a..33177736eb5fe 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -37,7 +37,7 @@ func.func @materialize_read_1d() {
       // Both accesses in the load must be clipped otherwise %i1 + 2 and %i1 + 3 will go out of bounds.
       // CHECK: scf.if
       // CHECK-NEXT: memref.load
-      // CHECK-NEXT: vector.insertelement
+      // CHECK-NEXT: vector.insert
       // CHECK-NEXT: scf.yield
       // CHECK-NEXT: else
       // CHECK-NEXT: scf.yield
@@ -103,7 +103,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
   // CHECK:                       %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
   // CHECK:                       scf.if {{.*}} -> (vector<3xf32>) {
   // CHECK-NEXT:                    %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
-  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
+  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insert %[[SCAL]], %{{.*}} [%[[I6]]] : f32 into vector<3xf32>
   // CHECK-NEXT:                    scf.yield
   // CHECK-NEXT:                  } else {
   // CHECK-NEXT:                    scf.yield
@@ -540,9 +540,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
 // CHECK:           scf.for %[[IDX:.*]] = %[[C_0]] to %[[UB]] step %[[STEP]] {
-// CHECK:             %[[MASK_VAL:.*]] = vector.extractelement %[[MASK_VEC]][%[[IDX]] : index] : vector<[16]xi1>
+// CHECK:             %[[MASK_VAL:.*]] = vector.extract %[[MASK_VEC]][%[[IDX]]] : i1 from vector<[16]xi1>
 // CHECK:             scf.if %[[MASK_VAL]] {
-// CHECK:               %[[VAL_TO_STORE:.*]] = vector.extractelement %{{.*}}[%[[IDX]] : index] : vector<[16]xf32>
+// CHECK:               %[[VAL_TO_STORE:.*]] = vector.extract %{{.*}}[%[[IDX]]] : f32 from vector<[16]xf32>
 // CHECK:               memref.store %[[VAL_TO_STORE]], %[[ARG_0]][%[[IDX]]] : memref<?xf32, strided<[?], offset: ?>>
 // CHECK:             } else {
 // CHECK:             }
@@ -561,7 +561,7 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[IDX]] : index] : vector<1xf32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
 // CHECK:             vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {
@@ -591,7 +591,7 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
 // CHECK:             scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
 // CHECK:               %[[OUTER_INDEX:.*]] = arith.muli %[[I]], %[[C2]] : index
 // CHECK:               %[[FLAT_INDEX:.*]] = arith.addi %[[J]], %[[OUTER_INDEX]] : index
-// CHECK:               %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]] : index] : vector<4xf32>
+// CHECK:               %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[FLAT_INDEX]]] : f32 from vector<4xf32>
 // CHECK:               vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:               %[[IS_NOT_LAST_J:.*]] = arith.cmpi ult, %[[J]], %[[C1]] : index
 // CHECK:               scf.if %[[IS_NOT_LAST_J]] {
@@ -625,7 +625,7 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
 // CHECK:           %[[LAST_INDEX:.*]] = arith.subi %[[UPPER_BOUND]], %[[C1]] : index
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[UPPER_BOUND]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extractelement %[[VEC]]{{\[}}%[[IDX]] : index] : vector<[4]xi32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[VEC]][%[[IDX]]] : i32 from vector<[4]xi32>
 // CHECK:             vector.print %[[EL]] : i32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[LAST_INDEX]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 7a1d6b3a8344a..7fec1c6ba5642 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -8,7 +8,7 @@
 func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
   %cst = arith.constant 0.0 : f32
   %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
-  %1 = vector.extractelement %0[] : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
   return %1 : f32
 }
 
@@ -24,7 +24,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
   %cst = arith.constant 0.0 : f32
   %c0 = arith.constant 0 : index
   %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
-  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+  %1 = vector.extract %0[%idx2] : f32 from vector<5xf32>
   return %1 : f32
 }
 
@@ -37,7 +37,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
 func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
   %cst = arith.constant 0.0 : f32
   %0 = vector.transfer_read %t[%idx, %idx, %idx], %...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144413


More information about the Mlir-commits mailing list