[Mlir-commits] [mlir] [mlir][vector] Allow integer indices in vector.extract/insert ops (PR #115808)

Diego Caballero llvmlistbot at llvm.org
Sat Nov 16 14:10:48 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/115808

>From 8192475e41b4d57d361860410e1235a32ba718a1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Sat, 9 Nov 2024 22:11:20 -0800
Subject: [PATCH 1/2] [mlir][vector] Allow integer indices in
 vector.extract/insert ops

`vector.extract` and `vector.insert` can currently take an `i64` constant
or an `index` type value as indices. The `index` type will usually lower to
an `i32` or `i64` type. However, we are often indexing really small vector
dimensions where smaller integers could be used. This PR extends both
ops to accept any integer value as indices. For example:

```
  %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32>
  %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32>
  %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32>
```

This led to some changes to the ops' parser and printer. When at least one value index is
provided, a single index type is printed once at the end of the index list, and all the value
indices must have that type. When only constant indices are provided, no index type is printed.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 23 +++--
 mlir/include/mlir/IR/OpImplementation.h       | 21 +++-
 .../mlir/Interfaces/ViewLikeInterface.h       | 29 +++++-
 mlir/lib/AsmParser/AsmParserImpl.h            | 23 ++++-
 mlir/lib/AsmParser/Parser.cpp                 | 11 ++-
 mlir/lib/AsmParser/Parser.h                   |  9 +-
 .../VectorToArmSME/VectorToArmSME.cpp         | 10 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  8 +-
 mlir/lib/Interfaces/ViewLikeInterface.cpp     | 46 +++++++--
 .../VectorToArmSME/unsupported.mlir           | 10 +-
 .../VectorToArmSME/vector-to-arm-sme.mlir     | 98 +++++++++----------
 .../VectorToLLVM/vector-to-llvm.mlir          | 16 +--
 .../Conversion/VectorToSCF/vector-to-scf.mlir |  8 +-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 12 +--
 .../Dialect/ArmSME/outer-product-fusion.mlir  |  4 +-
 .../Dialect/ArmSME/vector-legalization.mlir   | 26 ++---
 mlir/test/Dialect/Linalg/hoisting.mlir        |  4 +-
 .../Dialect/Linalg/transform-ops-invalid.mlir |  2 +-
 mlir/test/Dialect/Vector/canonicalize.mlir    | 16 +--
 mlir/test/Dialect/Vector/invalid.mlir         | 65 ++++++++++++
 mlir/test/Dialect/Vector/ops.mlir             | 51 +++++++---
 .../vector-emulate-narrow-type-unaligned.mlir | 24 ++---
 22 files changed, 353 insertions(+), 163 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..dad08305b2a645 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -695,14 +695,14 @@ def Vector_ExtractOp :
     %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
     %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
     %3 = vector.extract %1[]: vector<f32> from vector<f32>
-    %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
-    %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
+    %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32>
+    %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32>
     ```
   }];
 
   let arguments = (ins
     AnyVectorOfAnyRank:$vector,
-    Variadic<Index>:$dynamic_position,
+    Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
   let results = (outs AnyType:$result);
@@ -737,7 +737,8 @@ def Vector_ExtractOp :
 
   let assemblyFormat = [{
     $vector ``
-    custom<DynamicIndexList>($dynamic_position, $static_position)
+    custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+                                     type($dynamic_position))
     attr-dict `:` type($result) `from` type($vector)
   }];
 
@@ -883,15 +884,15 @@ def Vector_InsertOp :
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
     %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
     %8 = vector.insert %6, %7[] : f32 into vector<f32>
-    %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
-    %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
+    %11 = vector.insert %9, %10[%a, %b, %c : index] : vector<f32> into vector<4x8x16xf32>
+    %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32>
     ```
   }];
 
   let arguments = (ins
     AnyType:$source,
     AnyVectorOfAnyRank:$dest,
-    Variadic<Index>:$dynamic_position,
+    Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
   let results = (outs AnyVectorOfAnyRank:$result);
@@ -926,7 +927,9 @@ def Vector_InsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+    $source `,` $dest
+    custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+                                     type($dynamic_position))
     attr-dict `:` type($source) `into` type($dest)
   }];
 
@@ -1344,7 +1347,7 @@ def Vector_TransferReadOp :
           %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
           // Update the temporary gathered slice with the individual element
           %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
-          %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+          %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32>
           memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
     }}}
     // At this point we gathered the elements from the original
@@ -1367,7 +1370,7 @@ def Vector_TransferReadOp :
         %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
         %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
         // Here we only store to the first element in dimension one
-        %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+        %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32>
         memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
     }}
     // At this point we gathered the elements from the original
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a7222794f320b2..699dd1da863b6f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -794,16 +794,26 @@ class AsmParser {
   };
 
   /// Parse a list of comma-separated items with an optional delimiter.  If a
-  /// delimiter is provided, then an empty list is allowed.  If not, then at
+  /// delimiter is provided, then an empty list is allowed. If not, then at
   /// least one element will be parsed.
   ///
+  /// `parseSuffixFn` is an optional function to parse any suffix that can be
+  /// appended to the comma separated list within the delimiter.
+  ///
   /// contextMessage is an optional message appended to "expected '('" sorts of
   /// diagnostics when parsing the delimeters.
-  virtual ParseResult
+  virtual ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef()) = 0;
+  ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef()) = 0;
-
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn,
+                                   /*parseSuffixFn=*/std::nullopt,
+                                   contextMessage);
+  }
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
   ParseResult
@@ -1319,6 +1329,9 @@ class AsmParser {
   virtual ParseResult
   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
+  /// Parse an optional colon followed by a type.
+  virtual ParseResult parseOptionalColonType(Type &result) = 0;
+
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
     return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 3dcbd2f1af1936..1971c25a8f20b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
 /// is non-empty, it is expected to contain as many elements as `values`
 /// indicating their types. This allows idiomatic printing of mixed value and
-/// integer attributes in a list. E.g.
-/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
+/// same and only one type is printed at the end of the list. E.g.,
+/// `[0, %arg2, 3, %arg42, 2 : i8]`.
 ///
 /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
 /// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                   OperandRange values,
                                   ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
   return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
                                delimiter);
 }
+inline void printSameTypeDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
+                               delimiter, /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline ParseResult
 parseDynamicIndexList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
   return parseDynamicIndexList(parser, values, integers, scalableVals,
                                &valueTypes, delimiter);
 }
+inline ParseResult parseSameTypeDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               &valueTypes, delimiter,
+                               /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..4d5b93ec09d175 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
-  ParseResult parseCommaSeparatedList(Delimiter delimiter,
-                                      function_ref<ParseResult()> parseElt,
-                                      StringRef contextMessage) override {
-    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElt,
+      std::optional<function_ref<ParseResult()>> parseSuffix,
+      StringRef contextMessage) override {
+    return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
+                                          contextMessage);
   }
 
+  using BaseT::parseCommaSeparatedList;
+
   //===--------------------------------------------------------------------===//
   // Keyword Parsing
   //===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
     return parser.parseTypeListNoParens(result);
   }
 
+  /// Parse an optional colon followed by a type.
+  ParseResult parseOptionalColonType(Type &result) override {
+    SmallVector<Type, 1> types;
+    ParseResult parseResult = parseOptionalColonTypeList(types);
+    if (llvm::succeeded(parseResult) && types.size() > 1)
+      return emitError(getCurrentLocation(), "expected single type");
+    if (!types.empty())
+      result = types[0];
+    return parseResult;
+  }
+
   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                  bool allowDynamic,
                                  bool withTrailingX) override {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..6476910f71eb7f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
 /// Parse a list of comma-separated items with an optional delimiter.  If a
 /// delimiter is provided, then an empty list is allowed.  If not, then at
 /// least one element will be parsed.
-ParseResult
-Parser::parseCommaSeparatedList(Delimiter delimiter,
-                                function_ref<ParseResult()> parseElementFn,
-                                StringRef contextMessage) {
+ParseResult Parser::parseCommaSeparatedList(
+    Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+    std::optional<function_ref<ParseResult()>> parseSuffixFn,
+    StringRef contextMessage) {
   switch (delimiter) {
   case Delimiter::None:
     break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
       return failure();
   }
 
+  if (parseSuffixFn && (*parseSuffixFn)())
+    return failure();
+
   switch (delimiter) {
   case Delimiter::None:
     return success();
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..1ebca05bbcb2ef 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -46,10 +46,17 @@ class Parser {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef());
   ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef());
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
+                                   contextMessage);
+  }
 
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 55965d9c2a531d..c5c3353bf0477f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
 ///
 /// Example:
 /// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col : index] : i32 from
+/// vector<[4]x[4]xi32>
 /// ```
 /// Becomes:
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
 /// ```
 struct VectorExtractToArmSMELowering
     : public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
+/// %new_slice = vector.insert %el, %slice[%col : index]
+///            : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
 ///               : vector<[4]xi32> into vector<[4]x[4]xi32>
 /// ```
 struct VectorInsertToArmSMELowering
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..b623a86c53ee71 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
 /// %vscale = vector.vscale
 /// %c4_vscale = arith.muli %vscale, %c4 : index
 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
-///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
-///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
-///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
-///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
 ///   %slice_i = affine.apply #map(%idx)[%i]
 ///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
 ///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ca33636336bf0c..8e44ff60eec874 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  ArrayRef<bool> scalables, TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter) {
+                                 AsmParser::Delimiter delimiter,
+                                 bool hasSameTypeDynamicValues) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
       printer << "[";
     if (ShapedType::isDynamic(integer)) {
       printer << values[dynamicValIdx];
-      if (!valueTypes.empty())
+      if (!hasSameTypeDynamicValues && !valueTypes.empty())
         printer << " : " << valueTypes[dynamicValIdx];
       ++dynamicValIdx;
     } else {
@@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     scalableIndexIdx++;
   });
 
+  if (hasSameTypeDynamicValues && !valueTypes.empty()) {
+    assert(std::all_of(valueTypes.begin(), valueTypes.end(),
+                       [&](Type type) { return type == valueTypes[0]; }) &&
+           "Expected the same value types");
+    printer << " : " << valueTypes[0];
+  }
+
   printer << rightDelimiter;
 }
 
@@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
-    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
+    bool hasSameTypeDynamicValues) {
 
   SmallVector<int64_t, 4> integerVals;
   SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
       integerVals.push_back(ShapedType::kDynamic);
-      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+      if (!hasSameTypeDynamicValues && valueTypes &&
+          parser.parseColonType(valueTypes->emplace_back()))
         return failure();
     } else {
       int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
       return failure();
     return success();
   };
+  auto parseColonType = [&]() -> ParseResult {
+    if (hasSameTypeDynamicValues) {
+      assert(valueTypes && "Expected non-null value types");
+      assert(valueTypes->empty() && "Expected no parsed value types");
+
+      Type dynValType;
+      if (parser.parseOptionalColonType(dynValType))
+        return failure();
+
+      if (!dynValType && !values.empty())
+        return parser.emitError(parser.getNameLoc())
+               << "expected a type for dynamic indices";
+      if (dynValType) {
+        if (values.empty())
+          return parser.emitError(parser.getNameLoc())
+                 << "expected no type for constant indices";
+
+        // Broadcast the single type to all the dynamic values.
+        valueTypes->append(values.size(), dynValType);
+      }
+    }
+    return success();
+  };
   if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
-                                     " in dynamic index list"))
+                                     parseColonType, " in dynamic index list"))
     return parser.emitError(parser.getNameLoc())
-           << "expected SSA value or integer";
+           << "expected a valid list of SSA values or integers";
+
   integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
   scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
   return success();
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index ff7b4bcb5f65a8..c93dbf8836f6c4 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -151,7 +151,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 // CHECK-NOT: arm_sme.store_tile_slice
 func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
   %c0 = arith.constant 0 : index
-  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
   return
 }
@@ -202,7 +202,7 @@ func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index
 {
   // CHECK-NOT: arm_sve.psel
   %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
-  %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
+  %slice = vector.extract %mask[%index : index] : vector<[32]xi1> from vector<[4]x[32]xi1>
   return %slice : vector<[32]xi1>
 }
 
@@ -215,7 +215,7 @@ func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index
 {
   // CHECK-NOT: arm_sve.psel
   %mask = vector.create_mask %a, %b : vector<4x[8]xi1>
-  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
+  %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<4x[8]xi1>
   return %slice : vector<[8]xi1>
 }
 
@@ -227,7 +227,7 @@ func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index
 func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
 {
   // CHECK-NOT: arm_sve.psel
-  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+  %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<[4]x[8]xi1>
   return %slice : vector<[8]xi1>
 }
 
@@ -240,6 +240,6 @@ func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index
 {
   // CHECK-NOT: arm_sve.psel
   %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
-  %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
+  %el = vector.extract %mask[2, %index : index] : i1 from vector<[4]x[8]xi1>
   return %el : i1
 }
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 0f973af799634c..6ca19c5746ea15 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -345,7 +345,7 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
 // CHECK:         arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
 func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
   %c0 = arith.constant 0 : index
-  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
   return
 }
@@ -361,7 +361,7 @@ func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?
 // CHECK:         arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
 func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask: vector<[4]xi1>, %slice_index: index) {
   %c0 = arith.constant 0 : index
-  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
   return
 }
@@ -927,7 +927,7 @@ func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vect
   // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK-NEXT: arm_sme.insert_tile_slice %[[SLICE]], %[[TILE]][%[[INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[4]xi32> into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
 
@@ -937,7 +937,7 @@ func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vect
 func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8>
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[16]xi8> into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
 
@@ -947,7 +947,7 @@ func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vecto
 func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xi16> into vector<[8]x[8]xi16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xi16> into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
 
@@ -957,7 +957,7 @@ func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vect
 func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xi64> into vector<[2]x[2]xi64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[2]xi64> into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
 
@@ -967,7 +967,7 @@ func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vect
 func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[1]xi128> into vector<[1]x[1]xi128>
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[1]xi128> into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
 
@@ -977,7 +977,7 @@ func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> ve
 func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xf16> into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
 
@@ -987,7 +987,7 @@ func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vect
 func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
 
@@ -997,7 +997,7 @@ func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> ve
 func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[4]xf32> into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
 
@@ -1007,7 +1007,7 @@ func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vect
 func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
   // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xf64> into vector<[2]x[2]xf64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
-  %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
+  %new_tile = vector.insert %slice, %tile[%row : index] : vector<[2]xf64> into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
 
@@ -1020,10 +1020,10 @@ func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vect
 func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
   // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK-NEXT: %[[SLICE:.*]] = arm_sme.extract_tile_slice %[[TILE]][%[[ROW]]] : vector<[4]xi32> from vector<[4]x[4]xi32>
-  // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
+  // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]] : index] : i32 into vector<[4]xi32>
   // CHECK-NEXT: arm_sme.insert_tile_slice %[[NEW_SLICE]], %[[TILE]][%[[ROW]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : i32 into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
 
@@ -1035,7 +1035,7 @@ func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[16]xi8> into vector<[16]x[16]xi8>
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
-  %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : i8 into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
 
@@ -1047,7 +1047,7 @@ func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vect
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xi16> into vector<[8]x[8]xi16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : i16 into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
 
@@ -1059,7 +1059,7 @@ func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vect
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[2]xi64> into vector<[2]x[2]xi64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : i64 into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
 
@@ -1071,7 +1071,7 @@ func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> ve
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[1]xi128> into vector<[1]x[1]xi128>
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : i128 into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
 
@@ -1083,7 +1083,7 @@ func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vect
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xf16> into vector<[8]x[8]xf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
-  %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : f16 into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
 
@@ -1095,7 +1095,7 @@ func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> ve
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : bf16 into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
 
@@ -1107,7 +1107,7 @@ func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vect
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
-  %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : f32 into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
 
@@ -1119,7 +1119,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect
   // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
   // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[2]xf64> into vector<[2]x[2]xf64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
-  %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
+  %new_tile = vector.insert %el, %tile[%row, %col : index] : f64 into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
 
@@ -1135,7 +1135,7 @@ func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> {
   // CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK: arm_sme.extract_tile_slice %[[TILE]][%[[INDEX]]] : vector<[4]xi32> from vector<[4]x[4]xi32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32>
+  %slice = vector.extract %tile[%row : index] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
 
@@ -1145,7 +1145,7 @@ func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> {
 func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
-  %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
+  %slice = vector.extract %tile[%row : index] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
 
@@ -1155,7 +1155,7 @@ func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> {
 func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
+  %slice = vector.extract %tile[%row : index] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
 
@@ -1165,7 +1165,7 @@ func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> {
 func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64>
+  %slice = vector.extract %tile[%row : index] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
 
@@ -1175,7 +1175,7 @@ func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> {
 func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
+  %slice = vector.extract %tile[%row : index] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
 
@@ -1185,7 +1185,7 @@ func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> {
 func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
-  %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  %slice = vector.extract %tile[%row : index] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
 
@@ -1195,7 +1195,7 @@ func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> {
 func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
+  %slice = vector.extract %tile[%row : index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
 
@@ -1205,7 +1205,7 @@ func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> {
 func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
-  %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  %slice = vector.extract %tile[%row : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
 
@@ -1215,7 +1215,7 @@ func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> {
 func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> {
   // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
-  %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
+  %slice = vector.extract %tile[%row : index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
 
@@ -1227,9 +1227,9 @@ func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> {
 func.func @vector_extract_element(%row: index, %col: index) -> i32 {
   // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK-NEXT: %[[SLICE:.*]] = arm_sme.extract_tile_slice %[[TILE]][%[[ROW]]] : vector<[4]xi32> from vector<[4]x[4]xi32>
-  // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
+  // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]] : index] : i32 from vector<[4]xi32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
+  %el = vector.extract %tile[%row, %col : index] : i32 from vector<[4]x[4]xi32>
   return %el : i32
 }
 
@@ -1238,9 +1238,9 @@ func.func @vector_extract_element(%row: index, %col: index) -> i32 {
 // CHECK-LABEL: @vector_extract_element_i8
 func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i8 from vector<[16]xi8>
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
-  %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
+  %el = vector.extract %tile[%row, %col : index] : i8 from vector<[16]x[16]xi8>
   return %el : i8
 }
 
@@ -1249,9 +1249,9 @@ func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 {
 // CHECK-LABEL: @vector_extract_element_i16
 func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i16 from vector<[8]xi16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
+  %el = vector.extract %tile[%row, %col : index] : i16 from vector<[8]x[8]xi16>
   return %el : i16
 }
 
@@ -1260,9 +1260,9 @@ func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 {
 // CHECK-LABEL: @vector_extract_element_i64
 func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i64 from vector<[2]xi64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
+  %el = vector.extract %tile[%row, %col : index] : i64 from vector<[2]x[2]xi64>
   return %el : i64
 }
 
@@ -1271,9 +1271,9 @@ func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 {
 // CHECK-LABEL: @vector_extract_element_i128
 func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i128 from vector<[1]xi128>
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
+  %el = vector.extract %tile[%row, %col : index] : i128 from vector<[1]x[1]xi128>
   return %el : i128
 }
 
@@ -1282,9 +1282,9 @@ func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 {
 // CHECK-LABEL: @vector_extract_element_f16
 func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f16 from vector<[8]xf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
-  %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
+  %el = vector.extract %tile[%row, %col : index] : f16 from vector<[8]x[8]xf16>
   return %el : f16
 }
 
@@ -1293,9 +1293,9 @@ func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 {
 // CHECK-LABEL: @vector_extract_element_bf16
 func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : bf16 from vector<[8]xbf16>
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
+  %el = vector.extract %tile[%row, %col : index] : bf16 from vector<[8]x[8]xbf16>
   return %el : bf16
 }
 
@@ -1304,9 +1304,9 @@ func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 {
 // CHECK-LABEL: @vector_extract_element_f32
 func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f32 from vector<[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
-  %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
+  %el = vector.extract %tile[%row, %col : index] : f32 from vector<[4]x[4]xf32>
   return %el : f32
 }
 
@@ -1315,9 +1315,9 @@ func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 {
 // CHECK-LABEL: @vector_extract_element_f64
 func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
   // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64>
-  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
+  // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f64 from vector<[2]xf64>
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
-  %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
+  %el = vector.extract %tile[%row, %col : index] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }
 
@@ -1335,7 +1335,7 @@ func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: ind
   // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
   // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
   %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
-  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+  %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<[4]x[8]xi1>
   return %slice : vector<[8]xi1>
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 03bcb341efea2f..953d846dceb695 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1239,7 +1239,7 @@ func.func @extract_scalar_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>)
 // -----
 
 func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg1: index) -> f32 {
-  %0 = vector.extract %arg0[%arg1]: f32 from vector<16xf32>
+  %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<16xf32>
   return %0 : f32
 }
 // CHECK-LABEL: @extract_scalar_from_vec_1d_f32_dynamic_idx
@@ -1248,7 +1248,7 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %ar
 //       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
 
 func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 {
-  %0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32>
+  %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<[16]xf32>
   return %0 : f32
 }
 // CHECK-LABEL: @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable
@@ -1259,7 +1259,7 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16
 // -----
 
 func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
-  %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
+  %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x16xf32>
   return %0 : f32
 }
 
@@ -1269,7 +1269,7 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %
 //       CHECK:   vector.extract
 
 func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
-  %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32>
+  %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x[16]xf32>
   return %0 : f32
 }
 
@@ -1460,7 +1460,7 @@ func.func @insert_scalar_into_vec_3d_f32_scalable(%arg0: f32, %arg1: vector<4x8x
 
 func.func @insert_scalar_into_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg1: f32, %arg2: index)
                                       -> vector<16xf32> {
-  %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32>
+  %0 = vector.insert %arg1, %arg0[%arg2 : index] : f32 into vector<16xf32>
   return %0 : vector<16xf32>
 }
 
@@ -1471,7 +1471,7 @@ func.func @insert_scalar_into_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg
 
 func.func @insert_scalar_into_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: f32, %arg2: index)
                                       -> vector<[16]xf32> {
-  %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<[16]xf32>
+  %0 = vector.insert %arg1, %arg0[%arg2 : index] : f32 into vector<[16]xf32>
   return %0 : vector<[16]xf32>
 }
 
@@ -1484,7 +1484,7 @@ func.func @insert_scalar_into_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]
 
 func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: f32, %idx: index)
                                         -> vector<1x16xf32> {
-  %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x16xf32>
+  %0 = vector.insert %arg1, %arg0[0, %idx : index] : f32 into vector<1x16xf32>
   return %0 : vector<1x16xf32>
 }
 
@@ -1495,7 +1495,7 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
 
 func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: f32, %idx: index)
                                         -> vector<1x[16]xf32> {
-  %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x[16]xf32>
+  %0 = vector.insert %arg1, %arg0[0, %idx : index] : f32 into vector<1x[16]xf32>
   return %0 : vector<1x[16]xf32>
 }
 
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 5a6da3a06387a5..7d25d2b1c1e992 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -828,10 +828,10 @@ func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: mem
 // FULL-UNROLL:           %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
 // FULL-UNROLL:           scf.for %[[VAL_13:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
 // FULL-UNROLL:             %[[SLICE_I:.*]] = affine.apply #[[$SLICE_MAP]](%[[VAL_13]]){{\[}}%[[I]]]
-// FULL-UNROLL:             %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
-// FULL-UNROLL:             %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
-// FULL-UNROLL:             %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
-// FULL-UNROLL:             %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32>
 // FULL-UNROLL:             %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32>
 // FULL-UNROLL:             vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8796f153c4911b..dc8272c7c82a77 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -191,7 +191,7 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 //       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
 //       CHECK:   return %[[R]]
 func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 {
-  %0 = vector.extract %arg0[%id] : f32 from vector<1xf32>
+  %0 = vector.extract %arg0[%id : index] : f32 from vector<1xf32>
   return %0: f32
 }
 
@@ -202,7 +202,7 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
 //       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
-  %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+  %0 = vector.extract %arg0[%id : index] : f32 from vector<4xf32>
   return %0: f32
 }
 
@@ -211,7 +211,7 @@ func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
 //       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
 func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
   %idx = arith.constant 1 : index
-  %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+  %0 = vector.extract %arg0[%idx : index] : f32 from vector<4xf32>
   return %0: f32
 }
 
@@ -252,7 +252,7 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 //       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
 //       CHECK:   return %[[R]]
 func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> {
-  %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32>
+  %1 = vector.insert %arg1, %arg0[%id : index] : f32 into vector<1xf32>
   return %1 : vector<1xf32>
 }
 
@@ -263,7 +263,7 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
 //       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
 //       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
-  %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+  %0 = vector.insert %val, %arg0[%id : index] : f32 into vector<4xf32>
   return %0: vector<4xf32>
 }
 
@@ -274,7 +274,7 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect
 //       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
 func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
   %idx = arith.constant 2 : index
-  %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+  %0 = vector.insert %val, %arg0[%idx : index] : f32 into vector<4xf32>
   return %0: vector<4xf32>
 }
 
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 90005517835764..bac1c1cb5615eb 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -814,12 +814,12 @@ func.func @extract_from_arith_ext(%src: vector<4x[8]xi8>) -> vector<[8]xi32> {
 // CHECK-LABEL: @non_constant_extract_from_arith_ext(
 // CHECK-SAME:                                       %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>,
 // CHECK-SAME:                                       %[[DIM:[a-z0-9]+]]: index
-// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8>
+// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]] : index] : vector<[8]xi8> from vector<4x[8]xi8>
 // CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
 // CHECK: return %[[EXTEND]]
 func.func @non_constant_extract_from_arith_ext(%src: vector<4x[8]xi8>, %dim: index) -> vector<[8]xi32> {
   %0 = arith.extui %src : vector<4x[8]xi8> to vector<4x[8]xi32>
-  %1 = vector.extract %0[%dim] : vector<[8]xi32> from vector<4x[8]xi32>
+  %1 = vector.extract %0[%dim : index] : vector<[8]xi32> from vector<4x[8]xi32>
   return %1 : vector<[8]xi32>
 }
 
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 458906a1879829..61b6981b194a6a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -179,10 +179,10 @@ func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
   // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
-  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]] : index] : vector<[8]xf16> from vector<[8]x[8]xf16>
   // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
   // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
-  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]] : index] : vector<[8]xf16> from vector<[8]x[8]xf16>
   // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
@@ -224,20 +224,20 @@ func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0:
   // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
   // CHECK-DAG: %[[MASK:.*]] =  vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
   // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
-  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]] : index] : vector<[8]xi1> from vector<[8]x[8]xi1>
   // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
-  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
-  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
-  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]] : index] : vector<[8]xi1> from vector<[8]x[8]xi1>
   // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
-  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
-  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT: }
   %c0 = arith.constant 0 : index
@@ -313,16 +313,16 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
-  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
-  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
-  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
-  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
@@ -399,7 +399,7 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind
   // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
   // CHECK-NEXT: return %[[EXTRACT]]
   %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
-  %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[%index : index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
   return %extract : vector<[4]x[4]xi1>
 }
 
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 4e1035e038ca54..1f077409a6c666 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -734,7 +734,7 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL:  func.func @hoist_vector_broadcasts
 //       CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
-//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+//       CHECK:        %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]] : index] : vector<4xf32> from vector<3x4xf32>
 //       CHECK-NEXT:   %[[LOOP:.+]] = scf.for {{.*}} {
 //       CHECK-NEXT:     %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
 //       CHECK-NEXT:     scf.yield %[[USE]] : vector<4xf32>
@@ -744,7 +744,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
   %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
-    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+    %extract = vector.extract %iarg[%pos : index] : vector<4xf32> from vector<3x4xf32>
     %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
     %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
     scf.yield %broadcast : vector<3x4xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index fbebb97a11983e..fe108e47d5dd30 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -88,7 +88,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   %0 = transform.param.constant 2 : i64 -> !transform.param<i64>
   // expected-error at below {{expected ']' in dynamic index list}}
-  // expected-error at below {{custom op 'transform.structured.vectorize' expected SSA value or integer}}
+  // expected-error at below {{custom op 'transform.structured.vectorize' expected a valid list of SSA values or integers}}
   transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
 
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..db15a0562ef4e5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -126,7 +126,7 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
   %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
   // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
   // CHECK-NOT: vector.extract
-  %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
+  %extract = vector.extract %mask[2, %index : index] : vector<6xi1> from vector<4x4x6xi1>
   return %extract : vector<6xi1>
 }
 
@@ -140,7 +140,7 @@ func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %in
   %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
   // CHECK: arith.constant dense<false> : vector<6xi1>
   // CHECK-NOT: vector.extract
-  %extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1>
+  %extract = vector.extract %mask[0, %index : index] : vector<6xi1> from vector<1x4x6xi1>
   return %extract : vector<6xi1>
 }
 
@@ -153,8 +153,8 @@ func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %inde
   %mask = vector.create_mask %c2, %dim0 : vector<4x6xi1>
   // CHECK: %[[C2:.*]] = arith.constant 2 : index
   // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1>
-  // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1>
-  %extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1>
+  // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]] : index] : vector<6xi1> from vector<4x6xi1>
+  %extract = vector.extract %mask[%index : index] : vector<6xi1> from vector<4x6xi1>
   return %extract : vector<6xi1>
 }
 
@@ -167,8 +167,8 @@ func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index0
   %mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1>
   // CHECK: %[[C2:.*]] = arith.constant 2 : index
   // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1>
-  // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1>
-  %extract = vector.extract %mask[1, %index0] : vector<4xi1> from vector<2x4x4xi1>
+  // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]] : index] : vector<4xi1> from vector<2x4x4xi1>
+  %extract = vector.extract %mask[1, %index0 : index] : vector<4xi1> from vector<2x4x4xi1>
   return %extract : vector<4xi1>
 }
 
@@ -1918,10 +1918,10 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
 
 // CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
 //  CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
-//       CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]] : index] : vector<4xf32> from vector<2x4xf32>
 //       CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
 func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
-  %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
+  %0 = vector.extract %v[%index : index] : vector<4xf32> from vector<2x4xf32>
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..ae520c33dcb504 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -148,6 +148,39 @@ func.func @extract_vector_type(%arg0: index) {
   %1 = vector.extract %arg0[] : index from index
 }
 
+// -----
+func.func @extract_vector_mixed_index_types(%arg0 : vector<8x16xf32>,
+                                            %i32_idx: i32, %i8_idx: i8) {
+  // expected-error at +2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}}
+  // expected-note at -2 {{prior use here}}
+  %1 = vector.extract %arg0[%i32_idx, %i8_idx : i8] : f32 from vector<8x16xf32>
+}
+
+// -----
+func.func @extract_vector_index_vals_no_type(%arg0 : vector<8xf32>,
+                                             %i32_idx: i32) {
+  // expected-error at +2 {{expected a type for dynamic indices}}
+  // expected-error at +1 {{expected a valid list of SSA values or integers}}
+  %1 = vector.extract %arg0[%i32_idx] : f32 from vector<8x16xf32>
+}
+
+// -----
+func.func @extract_vector_index_vals_multiple_types(%arg0 : vector<8xf32>,
+                                                    %i8_idx : i8,
+                                                    %i32_idx : i32) {
+  // expected-error at +2 {{expected single type}}
+  // expected-error at +1 {{expected a valid list of SSA values or integers}}
+  %1 = vector.extract %arg0[%i8_idx, %i32_idx : i8, i32] : f32 from vector<8x16xf32>
+}
+
+// -----
+func.func @extract_vector_index_consts_type(%arg0 : vector<8x16xf32>,
+                                            %i32_idx: i32, %i8_idx: i8) {
+  // expected-error at +2 {{'vector.extract' expected no type for constant indices}}
+  // expected-error at +1 {{expected a valid list of SSA values or integers}}
+  %1 = vector.extract %arg0[5, 3 : index] : f32 from vector<8x16xf32>
+}
+
 // -----
 
 func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
@@ -271,6 +304,38 @@ func.func @insert_0d(%a: f32, %b: vector<f32>) {
   %1 = vector.insert %a, %b[0] : f32 into vector<f32>
 }
 
+// -----
+func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>,
+                                            %i32_idx: i32, %i8_idx: i8) {
+  // expected-error at +2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}}
+  // expected-note at -2 {{prior use here}}
+  %1 = vector.insert %arg0, %arg1[%i32_idx, %i8_idx : i8] : f32 into vector<8x16xf32>
+}
+
+// -----
+func.func @extract_vector_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>,
+                                             %i32_idx: i32) {
+  // expected-error at +2 {{expected a type for dynamic indices}}
+  // expected-error at +1 {{expected a valid list of SSA values or integers}}
+  %1 = vector.insert %arg0, %arg1[%i32_idx] : f32 into vector<8x16xf32>
+}
+
+// -----
+func.func @extract_vector_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>,
+                                                    %i8_idx : i8, %i32_idx : i32) {
+  // expected-error at +2 {{expected single type}}
+  // expected-error at +1 {{expected a valid list of SSA values or integers}}
+  %1 = vector.insert %arg0, %arg1[%i8_idx, %i32_idx : i8, i32] : f32 into vector<8x16xf32>
+}
+
+// -----
+func.func @extract_vector_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>,
+                                            %i32_idx: i32, %i8_idx: i8) {
+  // expected-error at +2 {{'vector.insert' expected no type for constant indices}}
+  // expected-error at +1 {{expected a valid list of SSA values or integers}}
+  %1 = vector.insert %arg0, %arg1[5, 3 : index] : f32 into vector<8x16xf32>
+}
+
 // -----
 
 func.func @outerproduct_num_operands(%arg0: f32) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..fb5769e7a61e7f 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -224,12 +224,26 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
 //  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index
 func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
                            -> (vector<8x16xf32>, vector<16xf32>, f32) {
-  // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32>
-  %0 = vector.extract %arg0[%idx] : vector<8x16xf32> from vector<4x8x16xf32>
-  // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<16xf32> from vector<4x8x16xf32>
-  %1 = vector.extract %arg0[%idx, %idx] : vector<16xf32> from vector<4x8x16xf32>
-  // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]]] : f32 from vector<4x8x16xf32>
-  %2 = vector.extract %arg0[%idx, 5, %idx] : f32 from vector<4x8x16xf32>
+  // CHECK: vector.extract %[[VEC]][%[[IDX]] : index] : vector<8x16xf32> from vector<4x8x16xf32>
+  %0 = vector.extract %arg0[%idx : index] : vector<8x16xf32> from vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]] : index] : vector<16xf32> from vector<4x8x16xf32>
+  %1 = vector.extract %arg0[%idx, %idx : index] : vector<16xf32> from vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]] : index] : f32 from vector<4x8x16xf32>
+  %2 = vector.extract %arg0[%idx, 5, %idx : index] : f32 from vector<4x8x16xf32>
+  return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32
+}
+
+// CHECK-LABEL: @extract_val_int
+//  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8
+func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32,
+                           %i8_idx: i8)
+                           -> (vector<8x16xf32>, vector<16xf32>, f32) {
+  // CHECK: vector.extract %[[VEC]][%[[I32_IDX]] : i32] : vector<8x16xf32> from vector<4x8x16xf32>
+  %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> from vector<4x8x16xf32>
+  %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 from vector<4x8x16xf32>
+  %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32>
   return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32
 }
 
@@ -274,12 +288,25 @@ func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 //  CHECK-SAME:   %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index
 func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
                           %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
-  // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]]] : vector<8x16xf32> into vector<4x8x16xf32>
-  %0 = vector.insert %c, %res[%idx] : vector<8x16xf32> into vector<4x8x16xf32>
-  // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]]] : vector<16xf32> into vector<4x8x16xf32>
-  %1 = vector.insert %b, %res[%idx, %idx] : vector<16xf32> into vector<4x8x16xf32>
-  // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]]] : f32 into vector<4x8x16xf32>
-  %2 = vector.insert %a, %res[%idx, 5, %idx] : f32 into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]] : index] : vector<8x16xf32> into vector<4x8x16xf32>
+  %0 = vector.insert %c, %res[%idx : index] : vector<8x16xf32> into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]] : index] : vector<16xf32> into vector<4x8x16xf32>
+  %1 = vector.insert %b, %res[%idx, %idx : index] : vector<16xf32> into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]] : index] : f32 into vector<4x8x16xf32>
+  %2 = vector.insert %a, %res[%idx, 5, %idx : index] : f32 into vector<4x8x16xf32>
+  return %2 : vector<4x8x16xf32>
+}
+
+// CHECK-LABEL: @insert_val_int
+//  CHECK-SAME:   %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8
+func.func @insert_val_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+                          %i32_idx: i32, %i8_idx: i8, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+  // CHECK: vector.insert %[[C]], %{{.*}}[%[[I32_IDX]] : i32] : vector<8x16xf32> into vector<4x8x16xf32>
+  %0 = vector.insert %c, %res[%i32_idx : i32] : vector<8x16xf32> into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[B]], %{{.*}}[%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> into vector<4x8x16xf32>
+  %1 = vector.insert %b, %res[%i8_idx, %i8_idx : i8] : vector<16xf32> into vector<4x8x16xf32>
+  // CHECK: vector.insert %[[A]], %{{.*}}[%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 into vector<4x8x16xf32>
+  %2 = vector.insert %a, %res[%i8_idx, 5, %i8_idx : i8] : f32 into vector<4x8x16xf32>
   return %2 : vector<4x8x16xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 0cecaddc5733e2..4bc84fcc9c31f6 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -91,13 +91,13 @@ func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector
 // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
-// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]] : index] : i2 from vector<8xi2>
 
 // -----
 
@@ -119,13 +119,13 @@ func.func @vector_load_i2_dynamic_indexing_mixed(%idx: index) -> vector<3xi2> {
 // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
-// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]] : index] : i2 from vector<8xi2>
 
 // -----
 
@@ -147,13 +147,13 @@ func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index)
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
-// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]] : index] : i2 from vector<8xi2>
 
 // -----
 
@@ -176,10 +176,10 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]] : index] : i2 from vector<8xi2>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
-// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]] : index] : i2 from vector<8xi2>

>From c854059144d214877b19013b8dcb1f1357f210ea Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Sat, 16 Nov 2024 13:59:36 -0800
Subject: [PATCH 2/2] Feedback

---
 .../VectorToLLVM/vector-to-llvm.mlir          | 68 +++++++++++++++++++
 .../VectorToSPIRV/vector-to-spirv.mlir        | 44 ++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir         | 24 +++----
 mlir/test/Dialect/Vector/ops.mlir             | 31 +++++----
 4 files changed, 142 insertions(+), 25 deletions(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 953d846dceb695..acbf0f71b38d2d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1119,6 +1119,38 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
 //       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32>
 //       CHECK:   return {{.*}} : f32
 
+// -----
+
+func.func @extract_i32_index(%arg0: vector<16xf32>, %arg1: i32) -> f32 {
+  %0 = vector.extract %arg0[%arg1 : i32]: f32 from vector<16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_i32_index
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i32] : vector<16xf32>
+//       CHECK:   return {{.*}} : f32
+
+// -----
+
+func.func @extract_i8_index(%arg0: vector<16xf32>, %arg1: i8) -> f32 {
+  %0 = vector.extract %arg0[%arg1 : i8]: f32 from vector<16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_i8_index
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i8] : vector<16xf32>
+//       CHECK:   return {{.*}} : f32
+
+// -----
+
+func.func @extract_i1_index(%arg0: vector<16xf32>, %arg1: i1) -> f32 {
+  %0 = vector.extract %arg0[%arg1 : i1]: f32 from vector<16xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_i1_index
+//       CHECK:   llvm.extractelement {{.*}}[{{.*}} : i1] : vector<16xf32>
+//       CHECK:   return {{.*}} : f32
+
+// -----
+
 func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
   %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
   return %0 : f32
@@ -1247,6 +1279,8 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %ar
 //       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
 //       CHECK:   llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32>
 
+// -----
+
 func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<[16]xf32>
   return %0 : f32
@@ -1268,6 +1302,8 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %
 // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx(
 //       CHECK:   vector.extract
 
+// -----
+
 func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 {
   %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x[16]xf32>
   return %0 : f32
@@ -1356,6 +1392,38 @@ func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> ve
 //       CHECK:   llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
 //       CHECK:   return {{.*}} : vector<4xf32>
 
+// -----
+
+func.func @insert_i32_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i32) -> vector<4xf32> {
+  %0 = vector.insert %arg0, %arg1[%arg2 : i32] : f32 into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+// CHECK-LABEL: @insert_i32_index
+//       CHECK:   llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i32] : vector<4xf32>
+//       CHECK:   return {{.*}} : vector<4xf32>
+
+// -----
+
+func.func @insert_i8_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i8) -> vector<4xf32> {
+  %0 = vector.insert %arg0, %arg1[%arg2 : i8] : f32 into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+// CHECK-LABEL: @insert_i8_index
+//       CHECK:   llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i8] : vector<4xf32>
+//       CHECK:   return {{.*}} : vector<4xf32>
+
+// -----
+
+func.func @insert_i1_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i1) -> vector<4xf32> {
+  %0 = vector.insert %arg0, %arg1[%arg2 : i1] : f32 into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+// CHECK-LABEL: @insert_i1_index
+//       CHECK:   llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i1] : vector<4xf32>
+//       CHECK:   return {{.*}} : vector<4xf32>
+
+// -----
+
 func.func @insert_scalar_into_vec_1d_f32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> {
   %0 = vector.insert %arg0, %arg1[3] : f32 into vector<[4]xf32>
   return %0 : vector<[4]xf32>
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index dc8272c7c82a77..7b7f128c1180bc 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -206,6 +206,28 @@ func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
   return %0: f32
 }
 
+// -----
+
+// CHECK-LABEL: @extract_i32_index
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_i32_index(%arg0 : vector<4xf32>, %id : i32) -> f32 {
+  %0 = vector.extract %arg0[%id : i32] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_i8_index
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i8
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i8
+func.func @extract_i8_index(%arg0 : vector<4xf32>, %id : i8) -> f32 {
+  %0 = vector.extract %arg0[%id : i8] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_dynamic_cst
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>
 //       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -269,6 +291,28 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect
 
 // -----
 
+// CHECK-LABEL: @insert_i32_index
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+//   CHECK-NOT:   builtin.unrealized_conversion_cast
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_i32_index(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> {
+  %0 = vector.insert %val, %arg0[%id : i32] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_i8_index
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i8
+//   CHECK-NOT:   builtin.unrealized_conversion_cast
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i8
+func.func @insert_i8_index(%val: f32, %arg0 : vector<4xf32>, %id : i8) -> vector<4xf32> {
+  %0 = vector.insert %val, %arg0[%id : i8] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_dynamic_cst
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
 //       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ae520c33dcb504..90a71b8e524255 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -149,7 +149,7 @@ func.func @extract_vector_type(%arg0: index) {
 }
 
 // -----
-func.func @extract_vector_mixed_index_types(%arg0 : vector<8x16xf32>,
+func.func @extract_mixed_index_types(%arg0 : vector<8x16xf32>,
                                             %i32_idx: i32, %i8_idx: i8) {
   // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}}
   // expected-note@-2 {{prior use here}}
@@ -157,7 +157,7 @@ func.func @extract_vector_mixed_index_types(%arg0 : vector<8x16xf32>,
 }
 
 // -----
-func.func @extract_vector_index_vals_no_type(%arg0 : vector<8xf32>,
+func.func @extract_index_vals_no_type(%arg0 : vector<8xf32>,
                                              %i32_idx: i32) {
   // expected-error@+2 {{expected a type for dynamic indices}}
   // expected-error@+1 {{expected a valid list of SSA values or integers}}
@@ -165,7 +165,7 @@ func.func @extract_vector_index_vals_no_type(%arg0 : vector<8xf32>,
 }
 
 // -----
-func.func @extract_vector_index_vals_multiple_types(%arg0 : vector<8xf32>,
+func.func @extract_index_vals_multiple_types(%arg0 : vector<8xf32>,
                                                     %i8_idx : i8,
                                                     %i32_idx : i32) {
   // expected-error@+2 {{expected single type}}
@@ -174,7 +174,7 @@ func.func @extract_vector_index_vals_multiple_types(%arg0 : vector<8xf32>,
 }
 
 // -----
-func.func @extract_vector_index_consts_type(%arg0 : vector<8x16xf32>,
+func.func @extract_index_consts_type(%arg0 : vector<8x16xf32>,
                                             %i32_idx: i32, %i8_idx: i8) {
   // expected-error@+2 {{'vector.extract' expected no type for constant indices}}
   // expected-error@+1 {{expected a valid list of SSA values or integers}}
@@ -305,32 +305,32 @@ func.func @insert_0d(%a: f32, %b: vector<f32>) {
 }
 
 // -----
-func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>,
-                                            %i32_idx: i32, %i8_idx: i8) {
+func.func @insert_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>,
+                                    %i32_idx: i32, %i8_idx: i8) {
   // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}}
   // expected-note@-2 {{prior use here}}
   %1 = vector.insert %arg0, %arg1[%i32_idx, %i8_idx : i8] : f32 into vector<8x16xf32>
 }
 
 // -----
-func.func @extract_vector_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>,
-                                             %i32_idx: i32) {
+func.func @insert_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>,
+                                     %i32_idx: i32) {
   // expected-error@+2 {{expected a type for dynamic indices}}
   // expected-error@+1 {{expected a valid list of SSA values or integers}}
   %1 = vector.insert %arg0, %arg1[%i32_idx] : f32 into vector<8x16xf32>
 }
 
 // -----
-func.func @extract_vector_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>,
-                                                    %i8_idx : i8, %i32_idx : i32) {
+func.func @insert_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>,
+                                            %i8_idx : i8, %i32_idx : i32) {
   // expected-error@+2 {{expected single type}}
   // expected-error@+1 {{expected a valid list of SSA values or integers}}
   %1 = vector.insert %arg0, %arg1[%i8_idx, %i32_idx : i8, i32] : f32 into vector<8x16xf32>
 }
 
 // -----
-func.func @extract_vector_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>,
-                                            %i32_idx: i32, %i8_idx: i8) {
+func.func @insert_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>,
+                                    %i32_idx: i32, %i8_idx: i8) {
   // expected-error@+2 {{'vector.insert' expected no type for constant indices}}
   // expected-error@+1 {{expected a valid list of SSA values or integers}}
   %1 = vector.insert %arg0, %arg1[5, 3 : index] : f32 into vector<8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index fb5769e7a61e7f..5cc2ba366febc4 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -222,8 +222,8 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>)
 
-// CHECK-LABEL: @extract_val_idx
+// CHECK-LABEL: @extract_index_as_index
 //  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index
-func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
-                           -> (vector<8x16xf32>, vector<16xf32>, f32) {
+func.func @extract_index_as_index(%arg0: vector<4x8x16xf32>, %idx: index)
+                                  -> (vector<8x16xf32>, vector<16xf32>, f32) {
   // CHECK: vector.extract %[[VEC]][%[[IDX]] : index] : vector<8x16xf32> from vector<4x8x16xf32>
   %0 = vector.extract %arg0[%idx : index] : vector<8x16xf32> from vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]] : index] : vector<16xf32> from vector<4x8x16xf32>
@@ -234,17 +234,19 @@ func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index)
 }
 
-// CHECK-LABEL: @extract_val_int
+// CHECK-LABEL: @extract_index_as_int
-//  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8
-func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32,
-                           %i8_idx: i8)
-                           -> (vector<8x16xf32>, vector<16xf32>, f32) {
+//  CHECK-SAME:   %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8, %[[I1_IDX:.+]]: i1
+func.func @extract_index_as_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32,
+                                %i8_idx: i8, %i1_idx: i1)
+                           -> (vector<8x16xf32>, vector<16xf32>, f32, vector<16xf32>) {
   // CHECK: vector.extract %[[VEC]][%[[I32_IDX]] : i32] : vector<8x16xf32> from vector<4x8x16xf32>
   %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> from vector<4x8x16xf32>
   %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 from vector<4x8x16xf32>
   %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32>
-  return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32
+  // CHECK-NEXT: vector.extract %[[VEC]][%[[I1_IDX]], 2 : i1] : vector<16xf32> from vector<4x8x16xf32>
+  %3 = vector.extract %arg0[%i1_idx, 2 : i1] : vector<16xf32> from vector<4x8x16xf32>
+  return %0, %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32, vector<16xf32>
 }
 
 // CHECK-LABEL: @extract_0d
@@ -286,8 +288,8 @@ func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 
-// CHECK-LABEL: @insert_val_idx
+// CHECK-LABEL: @insert_index_as_index
 //  CHECK-SAME:   %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index
-func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
-                          %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+func.func @insert_index_as_index(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+                                 %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
   // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]] : index] : vector<8x16xf32> into vector<4x8x16xf32>
   %0 = vector.insert %c, %res[%idx : index] : vector<8x16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]] : index] : vector<16xf32> into vector<4x8x16xf32>
@@ -298,16 +300,19 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
-// CHECK-LABEL: @insert_val_int
+// CHECK-LABEL: @insert_index_as_int
-//  CHECK-SAME:   %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8
-func.func @insert_val_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
-                          %i32_idx: i32, %i8_idx: i8, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
+//  CHECK-SAME:   %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8, %[[I1_IDX:.+]]: i1
+func.func @insert_index_as_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
+                               %i32_idx: i32, %i8_idx: i8, %i1_idx: i1, %res: vector<4x8x16xf32>)
+                               -> (vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>) {
   // CHECK: vector.insert %[[C]], %{{.*}}[%[[I32_IDX]] : i32] : vector<8x16xf32> into vector<4x8x16xf32>
   %0 = vector.insert %c, %res[%i32_idx : i32] : vector<8x16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %[[B]], %{{.*}}[%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> into vector<4x8x16xf32>
   %1 = vector.insert %b, %res[%i8_idx, %i8_idx : i8] : vector<16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %[[A]], %{{.*}}[%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 into vector<4x8x16xf32>
   %2 = vector.insert %a, %res[%i8_idx, 5, %i8_idx : i8] : f32 into vector<4x8x16xf32>
-  return %2 : vector<4x8x16xf32>
+  // CHECK-NEXT: vector.insert %[[B]], %{{.*}}[%[[I1_IDX]], 2 : i1] : vector<16xf32> into vector<4x8x16xf32>
+  %3 = vector.insert %b, %res[%i1_idx, 2 : i1] : vector<16xf32> into vector<4x8x16xf32>
+  return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>
 }
 
 // CHECK-LABEL: @insert_0d



More information about the Mlir-commits mailing list