[Mlir-commits] [mlir] [mlir][Vector] Remove uses of vector.extractelement/vector.insertelement (PR #113827)

Kunwar Grover llvmlistbot at llvm.org
Sun Oct 27 11:17:24 PDT 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/113827

This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.

>From c56479dbb9e746057c58fb640e6504152c8990bc Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 27 Oct 2024 18:14:07 +0000
Subject: [PATCH 1/2] [mlir][Vector] Fix vector.insert folder for scalar to 0-d
 inserts

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  8 ++++----
 mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..03d2409f42c524 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2951,11 +2951,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertOpConstantFolder>(context);
 }
 
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getNumIndices() == 0)
+  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
+  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+  // (type mismatch).
+  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
     return getSource();
   return {};
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6d6bc199e601c0..580daa2a13d15e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2745,6 +2745,18 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: func @insert_into_0d_regression(
+//  CHECK-SAME:     %[[v:.*]]: vector<f32>)
+//       CHECK:   %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
+//       CHECK:   return %[[extract]]
+func.func @insert_into_0d_regression(%v: vector<f32>) -> vector<f32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.insert %cst, %v [] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
 // CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>

>From 1ca7bbb28ef36d987d721c5e359d0351c48342e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 27 Oct 2024 18:14:30 +0000
Subject: [PATCH 2/2] [mlir][Vector] Remove uses of
 vector.extractelement/vector.insertelement

---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  9 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 19 ++--
 .../Linalg/Transforms/Vectorization.cpp       | 10 +--
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 15 +++-
 .../Transforms/LowerVectorBroadcast.cpp       |  3 +-
 .../Transforms/LowerVectorMultiReduction.cpp  |  5 +-
 .../Transforms/LowerVectorShapeCast.cpp       | 20 +----
 .../Vector/Transforms/VectorDistribute.cpp    | 12 ++-
 ...sertExtractStridedSliceRewritePatterns.cpp | 18 +---
 .../ArithToAMDGPU/16-bit-floats.mlir          | 24 ++---
 .../VectorToLLVM/vector-to-llvm.mlir          |  3 +-
 .../Conversion/VectorToSCF/vector-to-scf.mlir | 14 +--
 .../Linalg/vectorization-scalable.mlir        |  2 +-
 .../Linalg/vectorization-with-patterns.mlir   | 10 +--
 .../vectorize-tensor-extract-masked.mlir      | 12 +--
 .../Linalg/vectorize-tensor-extract.mlir      | 55 ++++++------
 mlir/test/Dialect/Vector/canonicalize.mlir    |  6 +-
 .../vector-multi-reduction-lowering.mlir      | 89 +++++++------------
 .../vector-multi-reduction-pass-lowering.mlir |  6 +-
 ...vector-shape-cast-lowering-transforms.mlir |  4 +-
 20 files changed, 138 insertions(+), 198 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractElementOp>(
-        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
     Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
 
     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
     }
 
     thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
 /// %lastIndex = arith.subi %length, %c1 : index
 /// vector.print punctuation <open>
 /// scf.for %i = %c0 to %length step %c1 {
-///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+///   %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
 ///   vector.print %el : i32 punctuation <no_punctuation>
 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 ///   scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // vector.extract for 1D vectors).
+      // TODO: vector.extract supports N-D non-constant indices now.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
         [&](OpBuilder &b, Location loc) {
           Value val =
               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
         });
     b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
 /// This pattern generates IR as follows:
 ///
 /// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate an InsertOp or ExtractOp,
 ///    depending on OpTy.
 ///
 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   %t = vector.extract %vec[i] : f32 from vector<9xf32>
 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                         resultType.getScalableDims().back()),
         idx);
     transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                     SmallVector<int64_t>{});
 
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
       loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+    readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                   SmallVector<int64_t>{});
     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 03d2409f42c524..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     Value result;
     if (vectorType.getRank() == 0) {
       if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+        mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+                                          SmallVector<int64_t>{});
     } else {
       if (mask)
         mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     if (extractResultRank < broadcastSrcRank)
       return failure();
 
-    // Special case if broadcast src is a 0D vector.
+    // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+    // broadcastSrcRank <= extractResultRank for this pattern. If so, the input
+    // to the broadcast will be a vector<f32> or f32, but the result will be a
+    // f32, because of vector.extract 0-d semantics. Therefore, we instead
+    // just replace the extract with a vector.extract of the broadcast source.
     if (extractResultRank == 0) {
       assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+                                                     SmallVector<int64_t>{});
       return success();
     }
+
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     if (srcRank <= 1 && dstRank == 1) {
       Value ext;
       if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+                                                 SmallVector<int64_t>{});
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
         reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
       }
 
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, reductionOp->getResult(0), result,
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
+      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+                                                 result, i);
     }
 
     rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
       }
 
       Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (extractOp.getNumIndices() == 0)
       return failure();
 
-    // Rewrite vector.extract with 1d source to vector.extractelement.
+    // 1-D source: re-create the vector.extract at its static position.
     if (extractSrcType.getRank() == 1) {
       if (extractOp.hasDynamicPosition())
         // TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(extractOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = extractOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
       return success();
     }
 
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(insertOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = insertOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                        Value into, int64_t offset) {
-  auto vectorType = cast<VectorType>(into.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<InsertOp>(loc, from, into, offset);
-  return rewriter.create<vector::InsertElementOp>(
-      loc, vectorType, from, into,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<InsertOp>(loc, from, into, offset);
 }
 
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                         int64_t offset) {
-  auto vectorType = cast<VectorType>(vector.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<ExtractOp>(loc, vector, offset);
-  return rewriter.create<vector::ExtractElementOp>(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<ExtractOp>(loc, vector, offset);
 }
 
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
 };
 
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 class DecomposeNDExtractStridedSlice
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
 func.func @scalar_trunc(%v: f32) -> f16{
   // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
   // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
-  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
   // CHECK: return %[[extract]] : f16
   %w = arith.truncf %v : f32 to f16
   return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
 // CHECK-LABEL: @vector_trunc
 // CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
 func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
   // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: return %[[ret]]
   %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
 // CHECK-LABEL:  @vector_trunc_long
 // CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
 func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
   // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
-  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+  // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
   // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
   // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
-  // CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
+  // CHECK: %[[elem4:.*]] = vector.extract %[[value]][4]
+  // CHECK: %[[elem5:.*]] = vector.extract %[[value]][5]
   // CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
   // CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
-  // CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem6:.*]] = vector.extract %[[value]]
+  // CHECK: %[[elem7:.*]] = vector.extract %[[value]]
   // CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
   // CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem8:.*]] = vector.extract %[[value]]
   // CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
   // CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
   // CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index eb6da71b063273..0d29e848f57861 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -233,10 +233,9 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
 // CHECK-LABEL: @broadcast_vec2d_from_vec0d(
 // CHECK-SAME:  %[[A:.*]]: vector<f32>)
 //       CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
+//       CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to f32
 //       CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
 //       CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-//       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
-//       CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
 //       CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
 //       CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 //       CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 5a6da3a06387a5..acd62c993919ec 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -37,7 +37,7 @@ func.func @materialize_read_1d() {
       // Both accesses in the load must be clipped otherwise %i1 + 2 and %i1 + 3 will go out of bounds.
       // CHECK: scf.if
       // CHECK-NEXT: memref.load
-      // CHECK-NEXT: vector.insertelement
+      // CHECK-NEXT: vector.insert
       // CHECK-NEXT: scf.yield
       // CHECK-NEXT: else
       // CHECK-NEXT: scf.yield
@@ -103,7 +103,7 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
   // CHECK:                       %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
   // CHECK:                       scf.if {{.*}} -> (vector<3xf32>) {
   // CHECK-NEXT:                    %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
-  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
+  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insert %[[SCAL]], %{{.*}} [%[[I6]]] : f32 into vector<3xf32>
   // CHECK-NEXT:                    scf.yield
   // CHECK-NEXT:                  } else {
   // CHECK-NEXT:                    scf.yield
@@ -540,9 +540,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
 // CHECK:           scf.for %[[IDX:.*]] = %[[C_0]] to %[[UB]] step %[[STEP]] {
-// CHECK:             %[[MASK_VAL:.*]] = vector.extractelement %[[MASK_VEC]][%[[IDX]] : index] : vector<[16]xi1>
+// CHECK:             %[[MASK_VAL:.*]] = vector.extract %[[MASK_VEC]][%[[IDX]]] : i1 from vector<[16]xi1>
 // CHECK:             scf.if %[[MASK_VAL]] {
-// CHECK:               %[[VAL_TO_STORE:.*]] = vector.extractelement %{{.*}}[%[[IDX]] : index] : vector<[16]xf32>
+// CHECK:               %[[VAL_TO_STORE:.*]] = vector.extract %{{.*}}[%[[IDX]]] : f32 from vector<[16]xf32>
 // CHECK:               memref.store %[[VAL_TO_STORE]], %[[ARG_0]][%[[IDX]]] : memref<?xf32, strided<[?], offset: ?>>
 // CHECK:             } else {
 // CHECK:             }
@@ -561,7 +561,7 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[IDX]] : index] : vector<1xf32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[FLAT_VEC]]{{\[}}%[[IDX]]] : f32 from vector<1xf32>
 // CHECK:             vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {
@@ -591,7 +591,7 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
 // CHECK:             scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
 // CHECK:               %[[OUTER_INDEX:.*]] = arith.muli %[[I]], %[[C2]] : index
 // CHECK:               %[[FLAT_INDEX:.*]] = arith.addi %[[J]], %[[OUTER_INDEX]] : index
-// CHECK:               %[[EL:.*]] = vector.extractelement %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]] : index] : vector<4xf32>
+// CHECK:               %[[EL:.*]] = vector.extract %[[FLAT_VEC]]{{\[}}%[[FLAT_INDEX]]] : f32 from vector<4xf32>
 // CHECK:               vector.print %[[EL]] : f32 punctuation <no_punctuation>
 // CHECK:               %[[IS_NOT_LAST_J:.*]] = arith.cmpi ult, %[[J]], %[[C1]] : index
 // CHECK:               scf.if %[[IS_NOT_LAST_J]] {
@@ -625,7 +625,7 @@ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
 // CHECK:           %[[LAST_INDEX:.*]] = arith.subi %[[UPPER_BOUND]], %[[C1]] : index
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[UPPER_BOUND]] step %[[C1]] {
-// CHECK:             %[[EL:.*]] = vector.extractelement %[[VEC]]{{\[}}%[[IDX]] : index] : vector<[4]xi32>
+// CHECK:             %[[EL:.*]] = vector.extract %[[VEC]]{{\[}}%[[IDX]]] : i32 from vector<[4]xi32>
 // CHECK:             vector.print %[[EL]] : i32 punctuation <no_punctuation>
 // CHECK:             %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[LAST_INDEX]] : index
 // CHECK:             scf.if %[[IS_NOT_LAST]] {
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index c3a30e3ee209e8..c5d47946e6fc50 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -210,7 +210,7 @@ func.func @vectorize_dynamic_reduction_scalable_1d(%arg0: tensor<?xf32>,
 // CHECK:          %[[VEC_RD_0:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
 // CHECK:          %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:          %[[VEC_RD_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[C0_F32]] : tensor<f32>, vector<f32>
-// CHECK:          %[[ACC_f32:.*]] = vector.extractelement %[[VEC_RD_1]][] : vector<f32>
+// CHECK:          %[[ACC_f32:.*]] = vector.extract %[[VEC_RD_1]][] : f32 from vector<f32>
 // CHECK:          %[[REDUCE:.*]] = vector.mask %[[MASK]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[ACC_f32]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
 // CHECK:          %[[VEC_f32:.*]] = vector.broadcast %[[REDUCE]] : f32 to vector<f32>
 // CHECK:          %{{.*}} = vector.transfer_write %[[VEC_f32]], %[[ARG_1]][] : vector<f32>, tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 189507d97d6dc2..fa1c8ffc6d3559 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -414,7 +414,7 @@ module attributes {transform.with_named_sequence} {
 func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
   //  CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
   //       CHECK:   %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
-  //       CHECK:   %[[val:.*]] = vector.extractelement %[[V]][] : vector<f32>
+  //       CHECK:   %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
   //       CHECK:   %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
   //       CHECK:   vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
   memref.copy %A, %B :  memref<f32> to memref<f32>
@@ -1436,7 +1436,6 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-LABEL: func @reduce_1d(
 //   CHECK-SAME:   %[[A:.*]]: tensor<32xf32>
 func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
-  //  CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
   //  CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
   //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %f0 = arith.constant 0.000000e+00 : f32
@@ -1447,8 +1446,7 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
   //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
-  //      CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
-  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
+  //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0]
   // CHECK-SAME:   : vector<32xf32> to f32
   //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
   //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
@@ -1775,9 +1773,9 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func @zero_dim_tensor
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extractelement
+//       CHECK:     vector.extract
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extractelement
+//       CHECK:     vector.extract
 //       CHECK:     arith.addf {{.*}} : f32
 //       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
 //       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 31a754d9343682..74d23fb5b1e3e1 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -37,11 +37,10 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
 // CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
 // CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
-// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
 
 /// Extract the starting point from the index vector
-// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+// CHECK:           %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<4xindex>
 
 // Final read and write
 // CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
@@ -98,11 +97,10 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
 // CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
 // CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
-// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
 
 /// Extract the starting point from the index vector
-// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+// CHECK:           %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<[4]xindex>
 
 // Final read and write
 // CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
@@ -159,11 +157,10 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
 // CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
 // CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
-// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
 
 /// Extract the starting point from the index vector
-// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+// CHECK:           %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<4xindex>
 
 // Final read and write
 // CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
@@ -218,11 +215,10 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
 // CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
 // CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
 // CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
-// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
 
 /// Extract the starting point from the index vector
-// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+// CHECK:           %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<[4]xindex>
 
 // Final read and write
 // CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e611a8e22ee23f..c02405f29bcf7b 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -125,15 +125,17 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
-// CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK:   %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK:   %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK:   %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+
+// CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:  %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
+// CHECK-DAG:  %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+
+// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
+// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
+// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
+
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 // Same as example above, but reading into a column tensor.
@@ -203,20 +205,18 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
 // CHECK-SAME:      %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
 // CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 79 : index
 // CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
 // CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
 // CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
 // CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
-// CHECK:           %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+
+// CHECK:           %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
+
+// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_21]] : tensor<1x4xf32>
 // CHECK:         }
@@ -451,7 +451,7 @@ func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<true> : vector<5xi1>
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<0.000000e+00> : vector<5xf32>
 // CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<5xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%{{.*}}], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
 // CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : vector<5xi32> to vector<5xindex>
 // CHECK:           %[[VAL_11:.*]] = arith.maxsi %[[VAL_10]], %[[VAL_4]] : vector<5xindex>
 // CHECK:           %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_5]] : vector<5xindex>
@@ -491,13 +491,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
 // CHECK-SAME:                                                                        %[[VAL_1:.*]]: index,
 // CHECK-SAME:                                                                        %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 79 : index
 // CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
 // CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK:           %[[VAL_10:.*]] = vector.extractelement %[[VAL_9]]{{\[}}%[[VAL_4]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
 // CHECK:           %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_12]] : tensor<1x4xf32>
@@ -538,10 +537,11 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_tensor_extract(
 // CHECK-SAME:      %[[INPUT_1:.*]]: tensor<1x20xi32>,
 // CHECK-SAME:      %[[INPUT_2:.*]]: tensor<257x24xf32>,
+// CHECK-SAME:      %[[INPUT_3:.*]]: index, %[[INPUT_4:.*]]: index, %[[INPUT_5:.*]]: index,
 // CHECK:           %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
+// CHECK:           %[[SCALAR:.*]] = arith.addi %[[INPUT_3]], %[[INPUT_5]] : index
 // First `vector.transfer_read` from the generic Op - loop invariant scalar load.
-// CHECK:           vector.transfer_read %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] 
+// CHECK:           vector.transfer_read %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[SCALAR]]] 
 // CHECK-SAME:      tensor<1x20xi32>, vector<i32>
 // The following `tensor.extract` from the generic Op s a contiguous load (all Ops used
 // for address calculation also satisfy the required conditions).
@@ -667,13 +667,15 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
 // CHECK-SAME:                                                                 %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                                 %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<16> : vector<1x4xindex>
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_6:.*]] = vector.shape_cast %[[VAL_2]] : vector<1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_7:.*]] = vector.extractelement %[[VAL_6]]{{\[}}%[[VAL_3]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_4]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+
+// CHECK-DAG:       %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK-DAG:       %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
+// CHECK-DAG:       %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
+// CHECK-DAG:       %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_9]] : tensor<1x4xf32>
 // CHECK:         }
@@ -842,9 +844,8 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
 // CHECK:           %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
 // CHECK:           %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
 // CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_19:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex>
 // CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
 // CHECK:           %[[VAL_24:.*]] = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 580daa2a13d15e..51c411a76a260f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -665,7 +665,7 @@ func.func @fold_extract_broadcast(%a : f32) -> f32 {
 
 // CHECK-LABEL: fold_extract_broadcast_0dvec
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
-//       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
 //       CHECK:   return %[[B]] : f32
 func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
@@ -2442,7 +2442,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
 
 // CHECK-LABEL: func.func @fold_0d_vector_reduction
 func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
-  // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
   // CHECK-NEXT: return %[[RES]] : f32
   %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
   return %0 : f32
@@ -2629,7 +2629,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   %3 = vector.extract %2[] : f32 from vector<f32>
 
   // Broadcast 0D to 3D and extract scalar.
-  // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+  // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
   %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
   %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
 
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 6e93923608cbf2..915154d00778c1 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -7,16 +7,14 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 // CHECK-LABEL: func @vector_multi_reduction
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
 //   CHECK-DAG:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
-//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
 //       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
 //       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0]
 //       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
 //       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][1]
 //       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
 func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
@@ -27,9 +25,7 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
 //       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
 //       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
-//       CHECK:   %[[INSERTED:.*]] = vector.insertelement %[[REDUCED]], {{.*}} : vector<1xf32>
-//       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : f32 from vector<1xf32>
-//       CHECK:   return %[[RES]]
+//       CHECK:   return %[[REDUCED]]
 
 func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
@@ -38,37 +34,31 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
 // CHECK-LABEL: func @vector_reduction_inner
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>, %[[ACC:.*]]: vector<2x3xi32>
 //   CHECK-DAG:       %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
-//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
-//   CHECK-DAG:       %[[C3:.+]] = arith.constant 3 : index
-//   CHECK-DAG:       %[[C4:.+]] = arith.constant 4 : index
-//   CHECK-DAG:       %[[C5:.+]] = arith.constant 5 : index
 //       CHECK:       %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
 //       CHECK:       %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<20xi32> from vector<6x20xi32>
 //       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : i32 from vector<2x3xi32>
 //       CHECK:       %[[V0R:.+]] = vector.reduction <add>, %[[V0]], %[[ACC0]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insert %[[V0R]], %[[FLAT_RESULT_VEC_0]] [0] : i32 into vector<6xi32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<20xi32> from vector<6x20xi32>
 //       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : i32 from vector<2x3xi32>
 //       CHECK:       %[[V1R:.+]] = vector.reduction <add>, %[[V1]], %[[ACC1]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insert %[[V1R]], %[[FLAT_RESULT_VEC_1]] [1] : i32 into vector<6xi32>
 //       CHECK:       %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<20xi32> from vector<6x20xi32>
 //       CHECK:       %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : i32 from vector<2x3xi32>
 //       CHECK:       %[[V2R:.+]] = vector.reduction <add>, %[[V2]], %[[ACC2]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insert %[[V2R]], %[[FLAT_RESULT_VEC_2]] [2] : i32 into vector<6xi32>
 //       CHECK:       %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<20xi32> from vector<6x20xi32>
 //       CHECK:       %[[ACC3:.+]] = vector.extract %[[ACC]][1, 0] : i32 from vector<2x3xi32>
 //       CHECK:       %[[V3R:.+]] = vector.reduction <add>, %[[V3]], %[[ACC3]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insert %[[V3R]], %[[FLAT_RESULT_VEC_3]] [3] : i32 into vector<6xi32>
 //       CHECK:       %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<20xi32> from vector<6x20xi32>
 //       CHECK:       %[[ACC4:.+]] = vector.extract %[[ACC]][1, 1] : i32 from vector<2x3xi32>
 //       CHECK:       %[[V4R:.+]] = vector.reduction <add>, %[[V4]], %[[ACC4]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insert %[[V4R]], %[[FLAT_RESULT_VEC_4]] [4] : i32 into vector<6xi32>
 //       CHECK:       %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<20xi32> from vector<6x20xi32>
 //       CHECK:       %[[ACC5:.+]] = vector.extract %[[ACC]][1, 2] : i32 from vector<2x3xi32>
 //       CHECK:       %[[V5R:.+]] = vector.reduction <add>, %[[V5]], %[[ACC5]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insert %[[V5R]], %[[FLAT_RESULT_VEC_5]] [5] : i32 into vector<6xi32>
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:       return %[[RESULT]]
 
@@ -91,47 +81,39 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vecto
 // CHECK-LABEL: func @vector_multi_reduction_ordering
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<3x2x4xf32>, %[[ACC:.*]]: vector<2x4xf32>)
 //   CHECK-DAG:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
-//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
-//   CHECK-DAG:       %[[C3:.+]] = arith.constant 3 : index
-//   CHECK-DAG:       %[[C4:.+]] = arith.constant 4 : index
-//   CHECK-DAG:       %[[C5:.+]] = arith.constant 5 : index
-//   CHECK-DAG:       %[[C6:.+]] = arith.constant 6 : index
-//   CHECK-DAG:       %[[C7:.+]] = arith.constant 7 : index
 //       CHECK:       %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
 //       CHECK:       %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
 //       CHECK:       %[[ACC0:.+]] = vector.extract %[[ACC]][0, 0] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<8xf32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
 //       CHECK:       %[[ACC1:.+]] = vector.extract %[[ACC]][0, 1] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<8xf32>
 //       CHECK:       %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
 //       CHECK:       %[[ACC2:.+]] = vector.extract %[[ACC]][0, 2] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV2:.+]] = vector.reduction <mul>, %[[V2]], %[[ACC2]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insert %[[RV2:.+]], %[[RESULT_VEC_2]] [2] : f32 into vector<8xf32>
 //       CHECK:       %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
 //       CHECK:       %[[ACC3:.+]] = vector.extract %[[ACC]][0, 3] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV3:.+]] = vector.reduction <mul>, %[[V3]], %[[ACC3]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insert %[[RV3:.+]], %[[RESULT_VEC_3]] [3] : f32 into vector<8xf32>
 //       CHECK:       %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
 //       CHECK:       %[[ACC4:.+]] = vector.extract %[[ACC]][1, 0] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV4:.+]] = vector.reduction <mul>, %[[V4]], %[[ACC4]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insert %[[RV4:.+]], %[[RESULT_VEC_4]] [4] : f32 into vector<8xf32>
 //       CHECK:       %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
 //       CHECK:       %[[ACC5:.+]] = vector.extract %[[ACC]][1, 1] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV5:.+]] = vector.reduction <mul>, %[[V5]], %[[ACC5]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insert %[[RV5:.+]], %[[RESULT_VEC_5]] [5] : f32 into vector<8xf32>
 //       CHECK:       %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
 //       CHECK:       %[[ACC6:.+]] = vector.extract %[[ACC]][1, 2] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV6:.+]] = vector.reduction <mul>, %[[V6]], %[[ACC6]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insert %[[RV6:.+]], %[[RESULT_VEC_6]] [6] : f32 into vector<8xf32>
 //       CHECK:       %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
 //       CHECK:       %[[ACC7:.+]] = vector.extract %[[ACC]][1, 3] : f32 from vector<2x4xf32>
 //       CHECK:       %[[RV7:.+]] = vector.reduction <mul>, %[[V7]], %[[ACC7]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insert %[[RV7:.+]], %[[RESULT_VEC_7]] [7] : f32 into vector<8xf32>
 //       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 //       CHECK:       return %[[RESHAPED_VEC]]
 
@@ -163,19 +145,19 @@ func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf
 
 // CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<8xi1> from vector<4x8xi1>
 // CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK:           %[[VAL_18:.*]] = vector.insertelement
+// CHECK:           %[[VAL_18:.*]] = vector.insert
 
 // CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<8xi1> from vector<4x8xi1>
 // CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK:           %[[VAL_23:.*]] = vector.insertelement
+// CHECK:           %[[VAL_23:.*]] = vector.insert
 
 // CHECK:           %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<8xi1> from vector<4x8xi1>
 // CHECK:           %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK:           %[[VAL_28:.*]] = vector.insertelement
+// CHECK:           %[[VAL_28:.*]] = vector.insert
 
 // CHECK:           %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<8xi1> from vector<4x8xi1>
 // CHECK:           %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
-// CHECK:           %[[VAL_33:.*]] = vector.insertelement
+// CHECK:           %[[VAL_33:.*]] = vector.insert
 
 func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
@@ -226,19 +208,19 @@ func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1
 
 // CHECK:           %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<4xi1> from vector<8x16x4xi1>
 // CHECK:           %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
-// CHECK:           %[[VAL_145:.*]] = vector.insertelement %[[VAL_144]]
+// CHECK:           %[[VAL_145:.*]] = vector.insert %[[VAL_144]]
 
 // CHECK:           %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<4xi1> from vector<8x16x4xi1>
 // CHECK:           %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
-// CHECK:           %[[VAL_150:.*]] = vector.insertelement %[[VAL_149]]
+// CHECK:           %[[VAL_150:.*]] = vector.insert %[[VAL_149]]
 
 // CHECK:           %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<4xi1> from vector<8x16x4xi1>
 // CHECK:           %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
-// CHECK:           %[[VAL_155:.*]] = vector.insertelement %[[VAL_154]]
+// CHECK:           %[[VAL_155:.*]] = vector.insert %[[VAL_154]]
 
 // CHECK:           %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<4xi1> from vector<8x16x4xi1>
 // CHECK:           %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
-// CHECK:           %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
+// CHECK:           %[[VAL_160:.*]] = vector.insert %[[VAL_159]]
 
 func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
@@ -257,26 +239,23 @@ func.func private @vector_multi_reduction_non_scalable_dim(%A : vector<8x[4]x2xf
 // CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
 // CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_34:.*]] = arith.constant 31 : index
 
 // CHECK:           %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2xf32> from vector<8x[4]x2xf32>
 // CHECK:           %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : f32 from vector<8x[4]xf32>
 // CHECK:           %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
-// CHECK:           %[[VAL_38:.*]] = vector.insertelement %[[VAL_37]], %[[VAL_2]]{{\[}}%[[VAL_3]] : index] : vector<[32]xf32>
+// CHECK:           %[[VAL_38:.*]] = vector.insert %[[VAL_37]], %[[VAL_2]] [0] : f32 into vector<[32]xf32>
 
 // CHECK:           %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2xf32> from vector<8x[4]x2xf32>
 // CHECK:           %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : f32 from vector<8x[4]xf32>
 // CHECK:           %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
-// CHECK:           %[[VAL_42:.*]] = vector.insertelement %[[VAL_41]], %[[VAL_38]]{{\[}}%[[VAL_4]] : index] : vector<[32]xf32>
+// CHECK:           %[[VAL_42:.*]] = vector.insert %[[VAL_41]], %[[VAL_38]] [1] : f32 into vector<[32]xf32>
 
 // (...)
 
 // CHECK:           %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<2xf32> from vector<8x[4]x2xf32>
 // CHECK:           %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : f32 from vector<8x[4]xf32>
 // CHECK:           %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
-// CHECK:           %[[VAL_162:.*]] = vector.insertelement %[[VAL_161]], %{{.*}}{{\[}}%[[VAL_34]] : index] : vector<[32]xf32>
+// CHECK:           %[[VAL_162:.*]] = vector.insert %[[VAL_161]], %{{.*}} [31] : f32 into vector<[32]xf32>
 
 // CHECK:           %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
 // CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>
@@ -291,12 +270,8 @@ func.func @vector_multi_reduction_scalable_dim_1d(%A: vector<[4]xf32>, %B: f32,
 // CHECK-SAME:                                      %[[ARG_0:.*]]: vector<[4]xf32>,
 // CHECK-SAME:                                      %[[ARG_1:.*]]: f32,
 // CHECK-SAME:                                      %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
-// CHECK-DAG:      %[[VAL_0:.*]] = arith.constant 0 : index
-// CHECK-DAG:      %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
 // CHECK:          %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
-// CHECK:          %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
-// CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
-// CHECK:          return %[[VAL_4]] : f32
+// CHECK:          return %[[VAL_2]] : f32
 
 func.func @vector_multi_reduction_scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
     %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
@@ -307,19 +282,17 @@ func.func @vector_multi_reduction_scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vec
 // CHECK-SAME:                                      %[[ARG_0:.*]]: vector<2x[4]xf32>,
 // CHECK-SAME:                                      %[[ARG_1:.*]]: vector<2xf32>,
 // CHECK-SAME:                                      %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
-// CHECK-DAG:      %[[C1_idx:.*]] = arith.constant 1 : index
-// CHECK-DAG:      %[[C0_idx:.*]] = arith.constant 0 : index
 // CHECK-DAG:      %[[C0_2xf32:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
 // CHECK:          %[[ARG0_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
 // CHECK:          %[[ARG1_0:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
 // CHECK:          %[[ARG2_0:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
 // CHECK:          %[[REDUCE_0:.*]] = vector.mask %[[ARG2_0]] { vector.reduction <add>, %[[ARG0_0]], %[[ARG1_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
-// CHECK:          %[[INSERT_0:.*]] = vector.insertelement %[[REDUCE_0]], %[[C0_2xf32]][%[[C0_idx]] : index] : vector<2xf32>
+// CHECK:          %[[INSERT_0:.*]] = vector.insert %[[REDUCE_0]], %[[C0_2xf32]] [0] : f32 into vector<2xf32>
 // CHECK:          %[[ARG0_1:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
 // CHECK:          %[[ARG1_1:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
 // CHECK:          %[[ARG2_1:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
 // CHECK:          %[[REDUCE_1:.*]] = vector.mask %[[ARG2_1]] { vector.reduction <add>, %[[ARG0_1]], %[[ARG1_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
-// CHECK:          %[[INSERT_1:.*]] = vector.insertelement %[[REDUCE_1]], %[[INSERT_0]][%[[C1_idx]] : index] : vector<2xf32>
+// CHECK:          %[[INSERT_1:.*]] = vector.insert %[[REDUCE_1]], %[[INSERT_0]] [1] : f32 into vector<2xf32>
 // CHECK:          return %[[INSERT_1]] : vector<2xf32>
 
 module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index 4cb6fba9b691a6..68621ffaac3d20 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -9,16 +9,14 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //           ALL-LABEL: func @vector_multi_reduction
 //            ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: vector<2xf32>)
 // INNER-REDUCTION-DAG: %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
-// INNER-REDUCTION-DAG: %[[C0:.+]] = arith.constant 0 : index
-// INNER-REDUCTION-DAG: %[[C1:.+]] = arith.constant 1 : index
 //     INNER-REDUCTION: %[[V0:.+]] = vector.extract %[[INPUT]][0]
 //     INNER-REDUCTION: %[[ACC0:.+]] = vector.extract %[[ACC]][0]
 //     INNER-REDUCTION: %[[RV0:.+]] = vector.reduction <mul>, %[[V0]], %[[ACC0]] : vector<4xf32> into f32
-//     INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
+//     INNER-REDUCTION: %[[RESULT_VEC_1:.+]] = vector.insert %[[RV0:.+]], %[[RESULT_VEC_0]] [0] : f32 into vector<2xf32>
 //     INNER-REDUCTION: %[[V1:.+]] = vector.extract %[[INPUT]][1]
 //     INNER-REDUCTION: %[[ACC1:.+]] = vector.extract %[[ACC]][1]
 //     INNER-REDUCTION: %[[RV1:.+]] = vector.reduction <mul>, %[[V1]], %[[ACC1]] : vector<4xf32> into f32
-//     INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
+//     INNER-REDUCTION: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
 //     INNER-REDUCTION: return %[[RESULT_VEC]]
 
 //      INNER-PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index f2f1211fd70eed..7cdbf25c428bd2 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -126,7 +126,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 // CHECK-LABEL:   func.func @shape_cast_0d1d(
 // CHECK-SAME:                               %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-// CHECK:           %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
+// CHECK:           %[[VAL_2:.*]] = vector.extract %[[VAL_0]][] : f32 from vector<f32>
 // CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
 // CHECK:           return %[[VAL_3]] : vector<1xf32>
 // CHECK:         }
@@ -140,7 +140,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 // CHECK-SAME:                               %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
 // CHECK:           %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
-// CHECK:           %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
+// CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [] : f32 into vector<f32>
 // CHECK:           return %[[VAL_3]] : vector<f32>
 // CHECK:         }
 



More information about the Mlir-commits mailing list