[Mlir-commits] [mlir] [mlir][vector] Update syntax and representation of insert/extract_strided_slice (PR #101850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 3 14:10:22 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This commit updates the representation of both `extract_strided_slice` and `insert_strided_slice` to plain arrays of `int64_t`, rather than `ArrayAttr`s of `IntegerAttr`s. This avoids a lot of boilerplate conversions between `IntegerAttr` and `int64_t`.

Because previously the offsets, strides, and sizes were in the attribute dictionary (with no special syntax), simply replacing the attribute types with `DenseI64ArrayAttr` would be a syntax break.

So, since a syntax break is mostly unavoidable, this commit also tackles a long-standing TODO:

```mlir
// TODO: Evolve to a range form syntax similar to:
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
  : vector<4x8x16xf32> to vector<2x4x16xf32>
```

This is done by introducing a new `StridedSliceAttr` attribute that can be used for both operations, with syntax based on the above example (see the attribute documentation in `VectorAttributes.td` for a full syntax overview).
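To illustrate the grammar, here is a minimal Python sketch (the function name `parse_strided_slice` is hypothetical; the real parser is the C++ `StridedSliceAttr::parse` in the patch) that decomposes the new bracket syntax into the attribute's three integer arrays:

```python
import re

def parse_strided_slice(text):
    """Split a strided-slice string like "[0:2:1][2:4:1]" or "[0, 1][3:1]"
    into (offsets, sizes, strides). A group of plain integers is an offset
    list, two colon-separated integers mean offset:stride, and three mean
    offset:size:stride."""
    offsets, sizes, strides = [], [], []
    for group in re.findall(r"\[([^\]]*)\]", text):
        parts = [int(p) for p in re.split(r"[:,]", group)]
        if ":" not in group:      # Case 1: [offset, offset, ...]
            offsets.extend(parts)
        elif len(parts) == 2:     # Case 2: [offset:stride]
            offsets.append(parts[0])
            strides.append(parts[1])
        else:                     # Case 3: [offset:size:stride]
            offsets.append(parts[0])
            sizes.append(parts[1])
            strides.append(parts[2])
    return offsets, sizes, strides
```

For example, `[0][0:1][2:1]` (the new `insert_strided_slice` form below) decomposes into offsets `[0, 0, 2]`, no sizes, and strides `[1, 1]`.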


With this:

`extract_strided_slice` goes from:
```mlir
%1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
     : vector<4x8x16xf32> to vector<2x4x16xf32>
```
To:
```mlir
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
     : vector<4x8x16xf32> to vector<2x4x16xf32>
```
(matching the TODO)
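In both spellings, the result type follows from the sizes: each sliced leading dimension has extent equal to its size, and the remaining trailing dimensions of the source are unchanged. A hedged sketch (the helper name is made up for illustration):

```python
def extract_result_shape(source_shape, sizes):
    """Result shape of vector.extract_strided_slice: the sliced leading
    dimensions take their extents from `sizes`; the trailing source
    dimensions are kept as-is."""
    assert len(sizes) <= len(source_shape)
    return list(sizes) + list(source_shape[len(sizes):])
```

So slicing `vector<4x8x16xf32>` with sizes `[2, 4]` yields `vector<2x4x16xf32>`, as in the example above.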

---

And `insert_strided_slice` goes from:
```mlir
%2 = vector.insert_strided_slice %0, %1
     {offsets = [0, 0, 2], strides = [1, 1]}
     : vector<2x4xf32> into vector<16x4x8xf32>
```

To:
```mlir
%2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
     : vector<2x4xf32> into vector<16x4x8xf32>
```
(inspired by the TODO)
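The printed form follows mechanically from the attribute: leading offsets with no corresponding stride are grouped into one plain `[...]` index list, and the remaining dimensions print as `[offset:stride]` (or `[offset:size:stride]` when sizes are present). A hedged Python sketch mirroring the `StridedSliceAttr::print` logic in the patch (the function name is illustrative):

```python
def print_strided_slice(offsets, sizes, strides):
    """Render (offsets, sizes, strides) in the new slice syntax, e.g.
    ([0, 0, 2], [], [1, 1]) renders as "[0][0:1][2:1]"."""
    # Leading offsets that have no stride are plain (non-strided) indices.
    n = len(offsets) - len(strides)
    out = ""
    if n > 0:
        out += "[" + ", ".join(map(str, offsets[:n])) + "]"
    for i, off in enumerate(offsets[n:]):
        if sizes:
            out += f"[{off}:{sizes[i]}:{strides[i]}]"
        else:
            out += f"[{off}:{strides[i]}]"
    return out
```

This is why the insert example above prints `[0][0:1][2:1]`: the old `offsets = [0, 0, 2]` has one more entry than `strides = [1, 1]`, so the first offset prints without a stride.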

---

Almost all test changes were done automatically via `auto-upgrade-insert-extract-slice.py`, available at: https://gist.github.com/MacDue/ca84d3ec19cf83ae71aab2be8f09c3c5 (use at your own risk).
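For a sense of what such an automatic upgrade involves, here is a minimal, hypothetical sketch of the rewrite for single-line `extract_strided_slice` ops (this is *not* the linked script, which handles many more cases):

```python
import re

def upgrade_extract(line):
    """Rewrite the old attribute-dict syntax of extract_strided_slice into
    the new bracket syntax, for the simple single-line case only."""
    m = re.search(
        r"\{offsets = \[([^\]]*)\], sizes = \[([^\]]*)\], strides = \[([^\]]*)\]\}",
        line)
    if not m:
        return line
    offsets, sizes, strides = ([int(v) for v in g.split(",")] for g in m.groups())
    slice_str = "".join(
        f"[{o}:{sz}:{st}]" for o, sz, st in zip(offsets, sizes, strides))
    # Attach the slice directly to the operand and drop the attribute dict.
    return line.replace(" " + m.group(0), slice_str)
```

Applied to the old-syntax example above, this produces exactly the new `%0[0:2:1][2:4:1]` form.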

This PR is split into multiple commits to make the changes more understandable.
- The first commit is code changes
- The second commit is **automatic** test changes
- The final commit is manual test changes




---

Patch is 354.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101850.diff


44 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td (+64) 
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+19-20) 
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-10) 
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+5-8) 
- (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-4) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+184-149) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp (+1-8) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+7-17) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+29-43) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+8-11) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+14-35) 
- (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+1-1) 
- (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (+15-15) 
- (modified) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+54-54) 
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1) 
- (modified) mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir (+4-4) 
- (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir (+12-12) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10-10) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+4-4) 
- (modified) mlir/test/Dialect/Arith/emulate-wide-int.mlir (+19-19) 
- (modified) mlir/test/Dialect/Arith/int-narrowing.mlir (+16-16) 
- (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+83-83) 
- (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+2-2) 
- (modified) mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir (+6-6) 
- (modified) mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir (+10-10) 
- (modified) mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (+19-19) 
- (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+96-96) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+63-63) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+17-24) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+7-9) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8) 
- (modified) mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir (+12-12) 
- (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+8-8) 
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+10-10) 
- (modified) mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/vector-scan-transforms.mlir (+20-20) 
- (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+4-4) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-unroll.mlir (+102-102) 
- (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (+36-36) 
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+70-70) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/contraction.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/extract-strided-slice.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/insert-strided-slice.mlir (+4-4) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/transpose.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b257..7fa20b950e7c6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/Vector/IR/Vector.td"
 include "mlir/IR/EnumAttr.td"
 
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<Vector_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 // The "kind" of combining function for contractions and reductions.
 def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
 def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,63 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
   let assemblyFormat = "`<` $value `>`";
 }
 
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+  let summary = "strided vector slice";
+
+  let description = [{
+    An attribute that represents a strided slice of a vector.
+
+    *Syntax:*
+
+    ```
+    offset = integer-literal
+    stride = integer-literal
+    size = integer-literal
+    offset-list = offset (`,` offset)*
+
+    // Without sizes (used for insert_strided_slice)
+    strided-slice-without-sizes = offset-list? (`[` offset `:` stride `]`)+
+
+    // With sizes (used for extract_strided_slice)
+    strided-slice-with-sizes = (`[` offset `:` size `:` stride `]`)+
+    ```
+
+    *Examples:*
+
+    Without sizes:
+
+    `[0:1][4:2]`
+
+    - The first dimension starts at offset 0 and is strided by 1
+    - The second dimension starts at offset 4 and is strided by 2
+
+    `[0, 1, 2][3:1][4:8]`
+
+    - The first three dimensions are indexed without striding (offsets 0, 1, 2)
+    - The fourth dimension starts at offset 3 and is strided by 1
+    - The fifth dimension starts at offset 4 and is strided by 8
+
+    With sizes (used for extract_strided_slice)
+
+    `[0:2:4][2:4:3]`
+
+    - The first dimension starts at offset 0, has size 2, and is strided by 4
+    - The second dimension starts at offset 2, has size 4, and is strided by 3
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$offsets,
+    OptionalArrayRefParameter<"int64_t">:$sizes,
+    ArrayRefParameter<"int64_t">:$strides
+  );
+
+  let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+      return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+    }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..45edb75c1989a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$source, AnyVector:$dest,
+               Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector:$res)> {
   let summary = "strided_slice operation";
   let description = [{
@@ -1059,14 +1059,13 @@ def Vector_InsertStridedSliceOp :
     Example:
 
     ```mlir
-    %2 = vector.insert_strided_slice %0, %1
-        {offsets = [0, 0, 2], strides = [1, 1]}:
-      vector<2x4xf32> into vector<16x4x8xf32>
+    %2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
+      : vector<2x4xf32> into vector<16x4x8xf32>
     ```
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $source `,` $dest `` $strided_slice attr-dict `:` type($source) `into` type($dest)
   }];
 
   let builders = [
@@ -1081,10 +1080,13 @@ def Vector_InsertStridedSliceOp :
       return ::llvm::cast<VectorType>(getDest().getType());
     }
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
 
   let hasFolder = 1;
@@ -1298,8 +1300,7 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector)> {
   let summary = "extract_strided_slice operation";
   let description = [{
@@ -1316,13 +1317,8 @@ def Vector_ExtractStridedSliceOp :
     Example:
 
     ```mlir
-    %1 = vector.extract_strided_slice %0
-        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
-      vector<4x8x16xf32> to vector<2x4x16xf32>
-
-    // TODO: Evolve to a range form syntax similar to:
     %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
-      vector<4x8x16xf32> to vector<2x4x16xf32>
+      : vector<4x8x16xf32> to vector<2x4x16xf32>
     ```
   }];
   let builders = [
@@ -1333,17 +1329,20 @@ def Vector_ExtractStridedSliceOp :
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
-    void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+  let assemblyFormat = "$vector `` $strided_slice attr-dict `:` type($vector) `to` type(results)";
 }
 
 // TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 0150ff667e4ef..a2647e2b647c1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
   return success();
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
 static LogicalResult
 convertExtractStridedSlice(RewriterBase &rewriter,
                            vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
   auto sourceVector = it->second;
 
   // Offsets and sizes at the warp level of ownership.
-  SmallVector<int64_t> offsets;
-  populateFromInt64AttrArray(op.getOffsets(), offsets);
+  ArrayRef<int64_t> offsets = op.getOffsets();
 
-  SmallVector<int64_t> sizes;
-  populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..4d4e5ebb4f428 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
   return cast<IntegerAttr>(attr[0]).getInt();
 }
-static uint64_t getFirstIntValue(ArrayAttr attr) {
-  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
   auto attr = foldResults[0].dyn_cast<Attribute>();
   if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
     if (!dstType)
       return failure();
 
-    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
-    uint64_t size = getFirstIntValue(extractOp.getSizes());
-    uint64_t stride = getFirstIntValue(extractOp.getStrides());
+    int64_t offset = extractOp.getOffsets().front();
+    int64_t size = extractOp.getSizes().front();
+    int64_t stride = extractOp.getStrides().front();
     if (stride != 1)
       return failure();
 
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    uint64_t stride = getFirstIntValue(insertOp.getStrides());
+    uint64_t stride = insertOp.getStrides().front();
     if (stride != 1)
       return failure();
-    uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+    uint64_t offset = insertOp.getOffsets().front();
 
     if (isa<spirv::ScalarType>(srcVector.getType())) {
       assert(!isa<spirv::ScalarType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index e2d42e961c576..941644e1116fc 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -550,11 +550,8 @@ struct ExtensionOverExtractStridedSlice final
     if (failed(ext))
       return failure();
 
-    VectorType origTy = op.getType();
-    VectorType extractTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+        op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
         op.getStrides());
     ext->recreateAndReplace(rewriter, op, newExtract);
     return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..dda6b916176fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
   return success();
 }
 
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(llvm::map_range(
-      arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     return Value();
 
   // Trim offsets for dimensions fully extracted.
-  auto sliceOffsets =
-      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+  SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
   while (!sliceOffsets.empty()) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
-    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+    ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
     ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
 
-    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
-        }))
+    if (insertOp.hasNonUnitStrides())
       return Value();
     bool disjoint = false;
     SmallVector<int64_t, 4> offsetDiffs;
@@ -2899,6 +2889,95 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// StridedSliceAttr
+//===----------------------------------------------------------------------===//
+
+Attribute StridedSliceAttr::parse(AsmParser &parser, Type attrType) {
+  SmallVector<int64_t> offsets;
+  SmallVector<int64_t> sizes;
+  SmallVector<int64_t> strides;
+  bool parsedNonStridedOffsets = false;
+  while (succeeded(parser.parseOptionalLSquare())) {
+    int64_t offset = 0;
+    if (parser.parseInteger(offset))
+      return {};
+    if (parser.parseOptionalColon()) {
+      // Case 1: [Offset, ...]
+      if (!strides.empty() || parsedNonStridedOffsets) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "expected slice stride or size");
+        return {};
+      }
+      offsets.push_back(offset);
+      if (succeeded(parser.parseOptionalComma())) {
+        if (parser.parseCommaSeparatedList(
+                AsmParser::Delimiter::None, [&]() -> ParseResult {
+                  if (parser.parseInteger(offset))
+                    return failure();
+                  offsets.push_back(offset);
+                  return success();
+                })) {
+          return {};
+        }
+      }
+      if (parser.parseRSquare())
+        return {};
+      parsedNonStridedOffsets = true;
+      continue;
+    }
+    int64_t sizeOrStride = 0;
+    if (parser.parseInteger(sizeOrStride)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "expected slice stride or size");
+      return {};
+    }
+    if (parser.parseOptionalColon()) {
+      // Case 2: [Offset:Stride]
+      if (!sizes.empty() || parser.parseRSquare()) {
+        parser.emitError(parser.getCurrentLocation(), "expected slice size");
+        return {};
+      }
+      offsets.push_back(offset);
+      strides.push_back(sizeOrStride);
+      continue;
+    }
+    // Case 3: [Offset:Size:Stride]
+    if (sizes.size() < strides.size()) {
+      parser.emitError(parser.getCurrentLocation(), "unexpected slice size");
+      return {};
+    }
+    int64_t stride = 0;
+    if (parser.parseInteger(stride) || parser.parseRSquare()) {
+      parser.emitError(parser.getCurrentLocation(), "expected slice stride");
+      return {};
+    }
+    offsets.push_back(offset);
+    sizes.push_back(sizeOrStride);
+    strides.push_back(stride);
+  }
+  return StridedSliceAttr::get(parser.getContext(), offsets, sizes, strides);
+}
+
+void StridedSliceAttr::print(AsmPrinter &printer) const {
+  ArrayRef<int64_t> offsets = getOffsets();
+  ArrayRef<int64_t> sizes = getSizes();
+  ArrayRef<int64_t> strides = getStrides();
+  int nonStridedOffsets = offsets.size() - strides.size();
+  if (nonStridedOffsets > 0) {
+    printer << '[';
+    llvm::interleaveComma(offsets.take_front(nonStridedOffsets), printer);
+    printer << ']';
+  }
+  for (int d = nonStridedOffsets, e = offsets.size(); d < e; ++d) {
+    int strideIdx = d - nonStridedOffsets;
+    printer << '[' << offsets[d] << ':';
+    if (!sizes.empty())
+      printer << sizes[strideIdx] << ':';
+    printer << strides[strideIdx] << ']';
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//
@@ -2907,26 +2986,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> strides) {
-  result.addOperands({source, dest});
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(dest.getType());
-  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
-                                                        ArrayAttr arrayAttr,
-                                                        ArrayRef<int64_t> shape,
-                                                        StringRef attrName) {
-  if (arrayAttr.size() > shape.size())
-    return op.emitOpError("expected ")
-           << attrName << " attribute of rank no greater than vector rank";
-  return success();
+  build(builder, result, source, dest,
+        StridedSliceAttr::get(builder.getContext(), offsets, strides));
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max)
@@ -2934,16 +2995,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
-                                  int64_t max, StringRef attrName,
-                                  bool halfOpen = true) {
-  for (auto attr : arrayAttr) {
-    auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
+                          int64_t max, StringRef arrayName,
+                          bool halfOpen = true) {
+  for (int64_t val : array) {
     auto upper = max;
     if (!halfOpen)
       upper += 1;
     if (val < min || val >= upper)
-      return op.emitOpError("expected ") << attrName << " to be confined to ["
+      return op.emitOpError("expected ") << arrayName << " to be confined to ["
                                          << min << ", " << upper << ")";
   }
   return success();
@@ -2954,13 +3014,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
-                                  ArrayRef<int64_t> shape, StringRef attrName,
-                                  bool halfOpen = true, int64_t min = 0) {
-  for (auto [index, attrDimPair] :
-       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
-    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
-    int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+                          ArrayRef<int64_t> shape, StringRef attrName,
+                          bool halfOpen = true, int64_t min = 0) {
+  for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+    int64_t val, max;
+    std::tie(val, max) = dimPair;
     if (!halfOpen)
       max += 1;
     if (val < min || val >= max)
@@ -2977,40 +3036,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
 // the admissible interval is [min, max].
 template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
-    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
-    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+    OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+    ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
     bool halfOpen = true, int64_t min = 1) {
-  assert(arrayAttr1.size() <= shape.size());
-  assert(arrayAttr2.size() <= sh...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/101850


More information about the Mlir-commits mailing list