[Mlir-commits] [mlir] [mlir][vector] Allow integer indices in vector.extract/insert ops (PR #115808)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 11 19:33:56 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-sme

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

`vector.extract` and `vector.insert` can currently take an `i64` constant or an `index`-typed value as indices. The `index` type usually lowers to `i32` or `i64`. However, the vector dimensions being indexed are often very small, so narrower integer types would suffice. This PR extends both ops to accept any signless integer value as indices. For example:

```
  %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32>
  %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32>
  %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32>
```

This required changes to the ops' parser and printer. When at least one dynamic (SSA value) index is provided, its type is printed once, at the end of the index list, and all dynamic indices must have that type. When only constant indices are provided, no index type is printed.

---

Patch is 84.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115808.diff


22 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+13-10) 
- (modified) mlir/include/mlir/IR/OpImplementation.h (+17-4) 
- (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+25-4) 
- (modified) mlir/lib/AsmParser/AsmParserImpl.h (+19-4) 
- (modified) mlir/lib/AsmParser/Parser.cpp (+7-4) 
- (modified) mlir/lib/AsmParser/Parser.h (+8-1) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+6-4) 
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-4) 
- (modified) mlir/lib/Interfaces/ViewLikeInterface.cpp (+40-6) 
- (modified) mlir/test/Conversion/VectorToArmSME/unsupported.mlir (+5-5) 
- (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+49-49) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+8-8) 
- (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+4-4) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+6-6) 
- (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+2-2) 
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+13-13) 
- (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+2-2) 
- (modified) mlir/test/Dialect/Linalg/transform-ops-invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+8-8) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+65) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+39-12) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+12-12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..dad08305b2a645 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -695,14 +695,14 @@ def Vector_ExtractOp :
     %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
     %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
     %3 = vector.extract %1[]: vector<f32> from vector<f32>
-    %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
-    %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
+    %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32>
+    %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32>
     ```
   }];
 
   let arguments = (ins
     AnyVectorOfAnyRank:$vector,
-    Variadic<Index>:$dynamic_position,
+    Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
   let results = (outs AnyType:$result);
@@ -737,7 +737,8 @@ def Vector_ExtractOp :
 
   let assemblyFormat = [{
     $vector ``
-    custom<DynamicIndexList>($dynamic_position, $static_position)
+    custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+                                     type($dynamic_position))
     attr-dict `:` type($result) `from` type($vector)
   }];
 
@@ -883,15 +884,15 @@ def Vector_InsertOp :
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
     %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
     %8 = vector.insert %6, %7[] : f32 into vector<f32>
-    %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
-    %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
+    %11 = vector.insert %9, %10[%a, %b, %c : index] : vector<f32> into vector<4x8x16xf32>
+    %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32>
     ```
   }];
 
   let arguments = (ins
     AnyType:$source,
     AnyVectorOfAnyRank:$dest,
-    Variadic<Index>:$dynamic_position,
+    Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
   let results = (outs AnyVectorOfAnyRank:$result);
@@ -926,7 +927,9 @@ def Vector_InsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+    $source `,` $dest
+    custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+                                     type($dynamic_position))
     attr-dict `:` type($source) `into` type($dest)
   }];
 
@@ -1344,7 +1347,7 @@ def Vector_TransferReadOp :
           %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
           // Update the temporary gathered slice with the individual element
           %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
-          %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+          %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32>
           memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
     }}}
     // At this point we gathered the elements from the original
@@ -1367,7 +1370,7 @@ def Vector_TransferReadOp :
         %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
         %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
         // Here we only store to the first element in dimension one
-        %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+        %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32>
         memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
     }}
     // At this point we gathered the elements from the original
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a7222794f320b2..699dd1da863b6f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -794,16 +794,26 @@ class AsmParser {
   };
 
   /// Parse a list of comma-separated items with an optional delimiter.  If a
-  /// delimiter is provided, then an empty list is allowed.  If not, then at
+  /// delimiter is provided, then an empty list is allowed. If not, then at
   /// least one element will be parsed.
   ///
+  /// `parseSuffixFn` is an optional function that parses a suffix after the
+  /// last list element and before the closing delimiter (if any).
+  ///
   /// contextMessage is an optional message appended to "expected '('" sorts of
   /// diagnostics when parsing the delimeters.
-  virtual ParseResult
+  virtual ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef()) = 0;
+  ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef()) = 0;
-
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn,
+                                   /*parseSuffixFn=*/std::nullopt,
+                                   contextMessage);
+  }
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
   ParseResult
@@ -1319,6 +1329,9 @@ class AsmParser {
   virtual ParseResult
   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
+  /// Parse an optional colon followed by a type.
+  virtual ParseResult parseOptionalColonType(Type &result) = 0;
+
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
     return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 3dcbd2f1af1936..1971c25a8f20b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
 /// is non-empty, it is expected to contain as many elements as `values`
 /// indicating their types. This allows idiomatic printing of mixed value and
-/// integer attributes in a list. E.g.
-/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// If `hasSameTypeDynamicValues` is `true`, all `valueTypes` must be identical
+/// and that type is printed only once, at the end of the list. E.g.,
+/// `[0, %arg2, 3, %arg42, 2 : i8]`.
 ///
 /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
 /// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                   OperandRange values,
                                   ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
   return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
                                delimiter);
 }
+inline void printSameTypeDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
+                               delimiter, /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline ParseResult
 parseDynamicIndexList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
   return parseDynamicIndexList(parser, values, integers, scalableVals,
                                &valueTypes, delimiter);
 }
+inline ParseResult parseSameTypeDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               &valueTypes, delimiter,
+                               /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..4d5b93ec09d175 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
-  ParseResult parseCommaSeparatedList(Delimiter delimiter,
-                                      function_ref<ParseResult()> parseElt,
-                                      StringRef contextMessage) override {
-    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElt,
+      std::optional<function_ref<ParseResult()>> parseSuffix,
+      StringRef contextMessage) override {
+    return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
+                                          contextMessage);
   }
 
+  using BaseT::parseCommaSeparatedList;
+
   //===--------------------------------------------------------------------===//
   // Keyword Parsing
   //===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
     return parser.parseTypeListNoParens(result);
   }
 
+  /// Parse an optional colon followed by a type.
+  ParseResult parseOptionalColonType(Type &result) override {
+    SmallVector<Type, 1> types;
+    ParseResult parseResult = parseOptionalColonTypeList(types);
+    if (llvm::succeeded(parseResult) && types.size() > 1)
+      return emitError(getCurrentLocation(), "expected single type");
+    if (!types.empty())
+      result = types[0];
+    return parseResult;
+  }
+
   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                  bool allowDynamic,
                                  bool withTrailingX) override {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..6476910f71eb7f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
 /// Parse a list of comma-separated items with an optional delimiter.  If a
 /// delimiter is provided, then an empty list is allowed.  If not, then at
 /// least one element will be parsed.
-ParseResult
-Parser::parseCommaSeparatedList(Delimiter delimiter,
-                                function_ref<ParseResult()> parseElementFn,
-                                StringRef contextMessage) {
+ParseResult Parser::parseCommaSeparatedList(
+    Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+    std::optional<function_ref<ParseResult()>> parseSuffixFn,
+    StringRef contextMessage) {
   switch (delimiter) {
   case Delimiter::None:
     break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
       return failure();
   }
 
+  if (parseSuffixFn && (*parseSuffixFn)())
+    return failure();
+
   switch (delimiter) {
   case Delimiter::None:
     return success();
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..1ebca05bbcb2ef 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -46,10 +46,17 @@ class Parser {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef());
   ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef());
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
+                                   contextMessage);
+  }
 
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 55965d9c2a531d..c5c3353bf0477f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
 ///
 /// Example:
 /// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col : index]
+///           : i32 from vector<[4]x[4]xi32>
 /// ```
 /// Becomes:
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
 /// ```
 struct VectorExtractToArmSMELowering
     : public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
+/// %new_slice = vector.insert %el, %slice[%col : index]
+///               : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
 ///               : vector<[4]xi32> into vector<[4]x[4]xi32>
 /// ```
 struct VectorInsertToArmSMELowering
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..b623a86c53ee71 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
 /// %vscale = vector.vscale
 /// %c4_vscale = arith.muli %vscale, %c4 : index
 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
-///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
-///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
-///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
-///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
 ///   %slice_i = affine.apply #map(%idx)[%i]
 ///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
 ///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ca33636336bf0c..8e44ff60eec874 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  ArrayRef<bool> scalables, TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter) {
+                                 AsmParser::Delimiter delimiter,
+                                 bool hasSameTypeDynamicValues) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
       printer << "[";
     if (ShapedType::isDynamic(integer)) {
       printer << values[dynamicValIdx];
-      if (!valueTypes.empty())
+      if (!hasSameTypeDynamicValues && !valueTypes.empty())
         printer << " : " << valueTypes[dynamicValIdx];
       ++dynamicValIdx;
     } else {
@@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     scalableIndexIdx++;
   });
 
+  if (hasSameTypeDynamicValues && !valueTypes.empty()) {
+    assert(std::all_of(valueTypes.begin(), valueTypes.end(),
+                       [&](Type type) { return type == valueTypes[0]; }) &&
+           "Expected the same value types");
+    printer << " : " << valueTypes[0];
+  }
+
   printer << rightDelimiter;
 }
 
@@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
-    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
+    bool hasSameTypeDynamicValues) {
 
   SmallVector<int64_t, 4> integerVals;
   SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
       integerVals.push_back(ShapedType::kDynamic);
-      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+      if (!hasSameTypeDynamicValues && valueTypes &&
+          parser.parseColonType(valueTypes->emplace_back()))
         return failure();
     } else {
       int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
       return failure();
     return success();
   };
+  auto parseColonType = [&]() -> ParseResult {
+    if (hasSameTypeDynamicValues) {
+      assert(valueTypes && "Expected non-null value types");
+      assert(valueTypes->empty() && "Expected no parsed value types");
+
+      Type dynValType;
+      if (parser.parseOptionalColonType(dynValType))
+        return failure();
+
+      if (!dynValType && !values.empty())
+        return parser.emitError(parser.getNameLoc())
+               << "expected a type for dynamic indices";
+      if (dynValType) {
+        if (values.empty())
+          return parser.emitError(parser.getNameLoc())
+                 << "expected no type for constant indices";
+
+        // Broadcast the single type to all the dynamic values.
+        valueTypes->append(values.size(), dynValType);
+   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/115808


More information about the Mlir-commits mailing list