[Mlir-commits] [mlir] [mlir][Vector] Remove usage of `vector.insertelement/extractelement` from Vector (PR #144413)

Diego Caballero llvmlistbot at llvm.org
Mon Jun 16 12:19:21 PDT 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/144413

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

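It removes the remaining uses of `vector.extractelement` and `vector.insertelement` from the Vector dialect layer (including the VectorToSCF conversion) and deletes the patterns that existed only to handle those ops.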

>From 75ae4a467d6282cd7098df66c0587bccfb885c62 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Sat, 7 Jun 2025 13:33:42 +0000
Subject: [PATCH] [mlir][Vector] Remove usage of
 `vector.insertelement/extractelement` from Vector

This PR is part of the last step to remove `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

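It removes the remaining uses of `vector.extractelement` and
`vector.insertelement` from the Vector dialect layer (including the
VectorToSCF conversion) and deletes the patterns that existed only to
handle those ops.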
---
 .../Vector/Transforms/VectorRewritePatterns.h |   4 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  13 +-
 .../Vector/Transforms/VectorDistribute.cpp    |  51 +-------
 .../Transforms/VectorTransferOpTransforms.cpp | 119 ++++--------------
 .../Conversion/VectorToSCF/vector-to-scf.mlir |  14 +--
 .../scalar-vector-transfer-to-memref.mlir     |   6 +-
 .../Vector/vector-warp-distribute.mlir        |  48 +++----
 7 files changed, 71 insertions(+), 184 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 34a94e6ea7051..ec0f856cb3f5a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -233,8 +233,8 @@ void populateBreakDownVectorReductionPatterns(
 ///
 /// [DecomposeNDExtractStridedSlice]
 /// ================================
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..45059f19a95c4 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -760,8 +760,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
 
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
-      // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // non-constant value.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -824,8 +823,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1567,7 +1565,7 @@ struct Strategy1d<TransferReadOp> {
         /*inBoundsCase=*/
         [&](OpBuilder &b, Location loc) {
           Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1595,8 +1593,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
         });
     b.create<scf::YieldOp>(loc);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..90970ae53defc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1255,27 +1255,6 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
-/// Pattern to convert vector.extractelement to vector.extract.
-struct WarpOpExtractElement : public WarpDistributionPattern {
-  using Base::Base;
-  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
-    if (!operand)
-      return failure();
-    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    SmallVector<OpFoldResult> indices;
-    if (auto pos = extractOp.getPosition()) {
-      indices.push_back(pos);
-    }
-    rewriter.setInsertionPoint(extractOp);
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-        extractOp, extractOp.getVector(), indices);
-    return success();
-  }
-};
-
 /// Pattern to move out vector.insert with a scalar input.
 /// Only supports 1-D and 0-D destinations for now.
 struct WarpOpInsertScalar : public WarpDistributionPattern {
@@ -1483,26 +1462,6 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
-struct WarpOpInsertElement : public WarpDistributionPattern {
-  using Base::Base;
-  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
-    if (!operand)
-      return failure();
-    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
-    SmallVector<OpFoldResult> indices;
-    if (auto pos = insertOp.getPosition()) {
-      indices.push_back(pos);
-    }
-    rewriter.setInsertionPoint(insertOp);
-    rewriter.replaceOpWithNewOp<vector::InsertOp>(
-        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
-    return success();
-  }
-};
-
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't
 /// change the order of execution. This creates a new scf.for region after the
@@ -1761,11 +1720,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 384717aeca665..62e7f7cc61f6c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -767,23 +767,26 @@ class FlattenContiguousRowMajorTransferWritePattern
   unsigned targetVectorBitwidth;
 };
 
-/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
-/// to `memref.load` patterns. The `match` method is shared for both
-/// `vector.extract` and `vector.extract_element`.
-template <class VectorExtractOp>
-class RewriteScalarExtractOfTransferReadBase
-    : public OpRewritePattern<VectorExtractOp> {
-  using Base = OpRewritePattern<VectorExtractOp>;
-
+/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
+///
+/// All the users of the transfer op must be `vector.extract` ops. If
+/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
+/// users. Otherwise, rewrite only if the extract op is the single user of the
+/// transfer op. Rewriting a single vector load with multiple scalar loads may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractOp> {
 public:
-  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
-                                         PatternBenefit benefit,
-                                         bool allowMultipleUses)
-      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
-
-  LogicalResult match(VectorExtractOp extractOp) const {
-    auto xferOp =
-        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
+  RewriteScalarExtractOfTransferRead(MLIRContext *context,
+                                     PatternBenefit benefit,
+                                     bool allowMultipleUses)
+      : OpRewritePattern(context, benefit),
+        allowMultipleUses(allowMultipleUses) {}
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Match phase.
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
     // Check that we are extracting a scalar and not a sub-vector.
@@ -795,8 +798,7 @@ class RewriteScalarExtractOfTransferReadBase
     // If multiple uses are allowed, check if all the xfer uses are extract ops.
     if (allowMultipleUses &&
         !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
-          return isa<vector::ExtractOp, vector::ExtractElementOp>(
-              use.getOwner());
+          return isa<vector::ExtractOp>(use.getOwner());
         }))
       return failure();
     // Mask not supported.
@@ -808,81 +810,8 @@ class RewriteScalarExtractOfTransferReadBase
     // Cannot rewrite if the indices may be out of bounds.
     if (xferOp.hasOutOfBoundsDim())
       return failure();
-    return success();
-  }
-
-private:
-  bool allowMultipleUses;
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractElementOfTransferRead
-    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
-  using RewriteScalarExtractOfTransferReadBase::
-      RewriteScalarExtractOfTransferReadBase;
-
-  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(match(extractOp)))
-      return failure();
-
-    // Construct scalar load.
-    auto loc = extractOp.getLoc();
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
-    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
-                                  xferOp.getIndices().end());
-    if (extractOp.getPosition()) {
-      AffineExpr sym0, sym1;
-      bindSymbols(extractOp.getContext(), sym0, sym1);
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, sym0 + sym1,
-          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
-      if (auto value = dyn_cast<Value>(ofr)) {
-        newIndices[newIndices.size() - 1] = value;
-      } else {
-        newIndices[newIndices.size() - 1] =
-            rewriter.create<arith::ConstantIndexOp>(loc,
-                                                    *getConstantIntValue(ofr));
-      }
-    }
-    if (isa<MemRefType>(xferOp.getBase().getType())) {
-      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
-                                                  newIndices);
-    } else {
-      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-          extractOp, xferOp.getBase(), newIndices);
-    }
-
-    return success();
-  }
-};
-
-/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
-/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
-///
-/// All the users of the transfer op must be either `vector.extractelement` or
-/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
-/// transfer ops with any number of users. Otherwise, rewrite only if the
-/// extract op is the single user of the transfer op. Rewriting a single
-/// vector load with multiple scalar loads may negatively affect performance.
-class RewriteScalarExtractOfTransferRead
-    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
-  using RewriteScalarExtractOfTransferReadBase::
-      RewriteScalarExtractOfTransferReadBase;
-
-  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(match(extractOp)))
-      return failure();
 
-    // Construct scalar load.
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    // Rewrite phase: construct scalar load.
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
@@ -923,6 +852,9 @@ class RewriteScalarExtractOfTransferRead
 
     return success();
   }
+
+private:
+  bool allowMultipleUses;
 };
 
 /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -979,8 +911,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     bool allowMultipleUses) {
-  patterns.add<RewriteScalarExtractElementOfTransferRead,
-               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+  patterns.add<RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                    benefit, allowMultipleUses);
   patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 5a6da3a06387a..33177736eb5fe 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -37,7 +37,7 @@ func.func @materialize_read_1d() {
       // Both accesses in the load must be clipped otherwise %i1 + 2 and %i1 + 3 will go out of bounds.
       // CHECK: scf.if
       // CHECK-NEXT: memref.load
-      // CHECK-NEXT: vector.insertelement
+      // CHECK-NEXT: vector.insert
       // CHECK-NEXT: scf.yield
       // CHECK-NEXT: else
       // CHECK-NEXT: scf.yield
@@ -103,7 +103,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
   // CHECK:                       %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
   // CHECK:                       scf.if {{.*}} -> (vector<3xf32>) {
   // CHECK-NEXT:                    %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
-  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
+  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insert %[[SCAL]], %{{.*}} [%[[I6]]] : f32 into vector<3xf32>
   // CHECK-NEXT:                    scf.yield
   // CHECK-NEXT:                  } else {
   // CHECK-NEXT:                    scf.yield
@@ -540,9 +540,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
 // CHECK:           scf.for %[[IDX:.*]] = %[[C_0]] to %[[UB]] step %[[STEP]] {
-// CHECK:             %[[MASK_VAL:.*]] = vector.extractelement %[[MASK_VEC]][%[[IDX]] : index] : vector<[16]xi1>
+// CHECK:             %[[MASK_VAL:.*]] = vector.extract %[[MASK_VEC]][%[[IDX]]] : i1 from vector<[16]xi1>
 // CHECK:             scf.if %[[MASK_VAL]] {
-// CHECK:               %[[VAL_TO_STORE:.*]] = vector.extractelement %{{.*}}[%[[IDX]] : index] : vector<[16]xf32>
+// CHECK:               %[[VAL_TO_STORE:.*]] = vector.extract %{{.*}}[%[[IDX]]] : f32 from vector<[16]xf32>
 // CHECK:               memref.store %[[VAL_TO_STORE]], %[[ARG_0]][%[[IDX]]] : memref<?xf32, strided<[?], offset: ?>>
 // CHECK:             } else {
 // CHECK:             }
@@ -561,7 +561,7 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[IDX]] : index] : vector<1xf32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
 // CHECK:             vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {
@@ -591,7 +591,7 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
 // CHECK:             scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
 // CHECK:               %[[OUTER_INDEX:.*]] = arith.muli %[[I]], %[[C2]] : index
 // CHECK:               %[[FLAT_INDEX:.*]] = arith.addi %[[J]], %[[OUTER_INDEX]] : index
-// CHECK:               %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]] : index] : vector<4xf32>
+// CHECK:               %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[FLAT_INDEX]]] : f32 from vector<4xf32>
 // CHECK:               vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:               %[[IS_NOT_LAST_J:.*]] = arith.cmpi ult, %[[J]], %[[C1]] : index
 // CHECK:               scf.if %[[IS_NOT_LAST_J]] {
@@ -625,7 +625,7 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
 // CHECK:           %[[LAST_INDEX:.*]] = arith.subi %[[UPPER_BOUND]], %[[C1]] : index
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[UPPER_BOUND]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extractelement %[[VEC]]{{\[}}%[[IDX]] : index] : vector<[4]xi32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[VEC]][%[[IDX]]] : i32 from vector<[4]xi32>
 // CHECK:             vector.print %[[EL]] : i32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[LAST_INDEX]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 7a1d6b3a8344a..7fec1c6ba5642 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -8,7 +8,7 @@
 func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
   %cst = arith.constant 0.0 : f32
   %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
-  %1 = vector.extractelement %0[] : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
   return %1 : f32
 }
 
@@ -24,7 +24,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
   %cst = arith.constant 0.0 : f32
   %c0 = arith.constant 0 : index
   %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
-  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+  %1 = vector.extract %0[%idx2] : f32 from vector<5xf32>
   return %1 : f32
 }
 
@@ -37,7 +37,7 @@ func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) ->
 func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
   %cst = arith.constant 0.0 : f32
   %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
-  %1 = vector.extractelement %0[] : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
   return %1 : f32
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..1dc0d5c7f9e1c 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -663,7 +663,7 @@ func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref
   gpu.warp_execute_on_lane_0(%laneid)[32] {
     %0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32>
     %1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32>
-    %2 = vector.extractelement %1[] : vector<f32>
+    %2 = vector.extract %1[] : f32 from vector<f32>
     %3 = vector.reduction <add>, %0 : vector<32xf32> into f32
     %4 = arith.addf %3, %2 : f32
     %5 = vector.broadcast %4 : f32 to vector<f32>
@@ -868,17 +868,17 @@ func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) {
 
 // -----
 
-// CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
+// CHECK-PROP-LABEL: func.func @vector_extract_0d(
 //       CHECK-PROP:   %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>
 //       CHECK-PROP:     gpu.yield %[[V]] : vector<f32>
 //       CHECK-PROP:   }
 //       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][] : f32 from vector<f32>
 //       CHECK-PROP:   return %[[E]] : f32
-func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
+func.func @vector_extract_0d(%laneid: index) -> (f32) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
     %0 = "some_def"() : () -> (vector<f32>)
-    %1 = vector.extractelement %0[] : vector<f32>
+    %1 = vector.extract %0[] : f32 from vector<f32>
     gpu.yield %1 : f32
   }
   return %r : f32
@@ -886,18 +886,18 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
 
 // -----
 
-// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
+// CHECK-PROP-LABEL: func.func @vector_extract_1element(
 //       CHECK-PROP:   %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<1xf32>
 //       CHECK-PROP:     gpu.yield %[[V]] : vector<1xf32>
 //       CHECK-PROP:   }
 //       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][0] : f32 from vector<1xf32>
 //       CHECK-PROP:   return %[[E]] : f32
-func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
+func.func @vector_extract_1element(%laneid: index) -> (f32) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
     %0 = "some_def"() : () -> (vector<1xf32>)
     %c0 = arith.constant 0 : index
-    %1 = vector.extractelement %0[%c0 : index] : vector<1xf32>
+    %1 = vector.extract %0[%c0] : f32 from vector<1xf32>
     gpu.yield %1 : f32
   }
   return %r : f32
@@ -907,7 +907,7 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
 
 //       CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
 //       CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
-// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(
+// CHECK-PROP-LABEL: func.func @vector_extract_1d(
 //  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
 //   CHECK-PROP-DAG:   %[[C32:.*]] = arith.constant 32 : i32
 //       CHECK-PROP:   %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<3xf32>) {
@@ -920,10 +920,10 @@ func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
 //       CHECK-PROP:   %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32
 //       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32
 //       CHECK-PROP:   return %[[SHUFFLED]]
-func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
+func.func @vector_extract_1d(%laneid: index, %pos: index) -> (f32) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
     %0 = "some_def"() : () -> (vector<96xf32>)
-    %1 = vector.extractelement %0[%pos : index] : vector<96xf32>
+    %1 = vector.extract %0[%pos] : f32 from vector<96xf32>
     gpu.yield %1 : f32
   }
   return %r : f32
@@ -933,16 +933,16 @@ func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
 
 // Index-typed values cannot be shuffled at the moment.
 
-// CHECK-PROP-LABEL: func.func @vector_extractelement_1d_index(
+// CHECK-PROP-LABEL: func.func @vector_extract_1d_index(
 //       CHECK-PROP:   gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (index) {
 //       CHECK-PROP:     "some_def"
 //       CHECK-PROP:     vector.extract
 //       CHECK-PROP:     gpu.yield {{.*}} : index
 //       CHECK-PROP:   }
-func.func @vector_extractelement_1d_index(%laneid: index, %pos: index) -> (index) {
+func.func @vector_extract_1d_index(%laneid: index, %pos: index) -> (index) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (index) {
     %0 = "some_def"() : () -> (vector<96xindex>)
-    %1 = vector.extractelement %0[%pos : index] : vector<96xindex>
+    %1 = vector.extract %0[%pos] : index from vector<96xindex>
     gpu.yield %1 : index
   }
   return %r : index
@@ -1142,7 +1142,7 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
 
 //       CHECK-PROP:   #[[$MAP:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
 //       CHECK-PROP:   #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
-// CHECK-PROP-LABEL: func @vector_insertelement_1d(
+// CHECK-PROP-LABEL: func @vector_insert_1d(
 //  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
 //       CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
 //       CHECK-PROP:   %[[INSERTING_LANE:.*]] = affine.apply #[[$MAP]]()[%[[POS]]]
@@ -1155,11 +1155,11 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
 //       CHECK-PROP:     scf.yield %[[W]]#0
 //       CHECK-PROP:   }
 //       CHECK-PROP:   return %[[R]]
-func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
+func.func @vector_insert_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
     %0 = "some_def"() : () -> (vector<96xf32>)
     %f = "another_def"() : () -> (f32)
-    %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+    %1 = vector.insert %f, %0[%pos] : f32 into vector<96xf32>
     gpu.yield %1 : vector<96xf32>
   }
   return %r : vector<3xf32>
@@ -1167,18 +1167,18 @@ func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32
 
 // -----
 
-// CHECK-PROP-LABEL: func @vector_insertelement_1d_broadcast(
+// CHECK-PROP-LABEL: func @vector_insert_1d_broadcast(
 //  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
 //       CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, f32)
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
 //       CHECK-PROP:     gpu.yield %[[VEC]], %[[VAL]]
 //       CHECK-PROP:   vector.insert %[[W]]#1, %[[W]]#0 [%[[POS]]] : f32 into vector<96xf32>
-func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
+func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
     %0 = "some_def"() : () -> (vector<96xf32>)
     %f = "another_def"() : () -> (f32)
-    %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+    %1 = vector.insert %f, %0[%pos] : f32 into vector<96xf32>
     gpu.yield %1 : vector<96xf32>
   }
   return %r : vector<96xf32>
@@ -1186,17 +1186,17 @@ func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (ve
 
 // -----
 
-// CHECK-PROP-LABEL: func @vector_insertelement_0d(
+// CHECK-PROP-LABEL: func @vector_insert_0d(
 //       CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
 //       CHECK-PROP:     %[[VEC:.*]] = "some_def"
 //       CHECK-PROP:     %[[VAL:.*]] = "another_def"
 //       CHECK-PROP:     gpu.yield %[[VEC]], %[[VAL]]
 //       CHECK-PROP:   vector.insert %[[W]]#1, %[[W]]#0 [] : f32 into vector<f32>
-func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
+func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
     %0 = "some_def"() : () -> (vector<f32>)
     %f = "another_def"() : () -> (f32)
-    %1 = vector.insertelement %f, %0[] : vector<f32>
+    %1 = vector.insert %f, %0[] : f32 into vector<f32>
     gpu.yield %1 : vector<f32>
   }
   return %r : vector<f32>
@@ -1299,7 +1299,7 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
 // -----
 
 // Make sure that all operands of the transfer_read op are properly propagated.
-// The vector.extractelement op cannot be propagated because index-typed
+// The vector.extract op cannot be propagated because index-typed
 // shuffles are not supported at the moment.
 
 // CHECK-PROP: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
@@ -1333,7 +1333,7 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 :  memref<1
     %28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
     %29 = vector.extract %28[0] : vector<64xi32> from vector<1x64xi32>
     %30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
-    %36 = vector.extractelement %30[%c0_i32 : index] : vector<64xindex>
+    %36 = vector.extract %30[%c0_i32] : index from vector<64xindex>
     %37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
     gpu.yield %37 : vector<64xf32>
   }



More information about the Mlir-commits mailing list