[Mlir-commits] [mlir] [mlir][vector] Update syntax and representation of insert/extract_strided_slice (PR #101850)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Aug 6 11:11:07 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/101850

>From c5b35e2d876ae67f81130e6045274b448b48e660 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 6 Aug 2024 15:49:31 +0000
Subject: [PATCH] [mlir][vector] Update representation of
 insert/extract_strided_slice

This commit changes the representation of the slice parameters (offsets,
strides, and, for extract_strided_slice, sizes) of both extract_strided_slice
and insert_strided_slice from ArrayAttrs of IntegerAttrs to plain arrays of
int64_t. This removes a lot of boilerplate conversion between IntegerAttr and
int64_t.

This is done by adding a new `StridedSliceAttr`, whose assembly format
matches the previous syntax and which is shared by both operations.

In the future, an alternative slice syntax (e.g. a range form similar to
Python slices) could be explored for `StridedSliceAttr`.
---
 .../Dialect/Vector/IR/VectorAttributes.td     |  43 +++
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  39 +--
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  11 +-
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  13 +-
 .../Dialect/Arith/Transforms/IntNarrowing.cpp |   5 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 249 +++++++-----------
 .../Vector/Transforms/LowerVectorScan.cpp     |   9 +-
 .../Transforms/VectorDropLeadUnitDim.cpp      |  24 +-
 ...sertExtractStridedSliceRewritePatterns.cpp |  72 ++---
 .../Vector/Transforms/VectorLinearize.cpp     |  19 +-
 .../Vector/Transforms/VectorTransforms.cpp    |  49 +---
 mlir/test/Dialect/Vector/invalid.mlir         |   2 +-
 mlir/test/Dialect/Vector/linearize.mlir       |   8 +-
 13 files changed, 228 insertions(+), 315 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b257..e974e5bd046a9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/Vector/IR/Vector.td"
 include "mlir/IR/EnumAttr.td"
 
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<Vector_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
 // The "kind" of combining function for contractions and reductions.
 def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
 def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,42 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
   let assemblyFormat = "`<` $value `>`";
 }
 
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+  let summary = "strided vector slice";
+
+  let description = [{
+    An attribute that represents a strided slice of a vector.
+
+    *Examples:*
+
+    Without sizes:
+
+    `{offsets = [0, 0, 2], strides = [1, 1]}`
+
+    With sizes (used for extract_strided_slice):
+
+    `{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}`
+
+    TODO? Come up with a range syntax (similar to Python slices).
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$offsets,
+    OptionalArrayRefParameter<"int64_t">:$sizes,
+    ArrayRefParameter<"int64_t">:$strides
+  );
+
+  let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+      return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+    }]>
+  ];
+
+  let assemblyFormat = [{
+    `{` `offsets` `=` `[` $offsets `]` `,`
+    (`sizes` `=` `[` $sizes^ `]` `,`)?
+    `strides` `=` `[` $strides `]` `}`
+  }];
+}
+
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cd19d356a6739..5f9b4f6b29b0f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$source, AnyVector:$dest,
+               Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector:$res)> {
   let summary = "strided_slice operation";
   let description = [{
@@ -1060,13 +1060,13 @@ def Vector_InsertStridedSliceOp :
 
     ```mlir
     %2 = vector.insert_strided_slice %0, %1
-        {offsets = [0, 0, 2], strides = [1, 1]}:
-      vector<2x4xf32> into vector<16x4x8xf32>
+        {offsets = [0, 0, 2], strides = [1, 1]}
+        : vector<2x4xf32> into vector<16x4x8xf32>
     ```
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+    $source `,` $dest $strided_slice attr-dict `:` type($source) `into` type($dest)
   }];
 
   let builders = [
@@ -1081,10 +1081,13 @@ def Vector_InsertStridedSliceOp :
       return ::llvm::cast<VectorType>(getDest().getType());
     }
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
 
   let hasFolder = 1;
@@ -1182,8 +1185,7 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
-               I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+    Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
     Results<(outs AnyVector)> {
   let summary = "extract_strided_slice operation";
   let description = [{
@@ -1201,12 +1203,8 @@ def Vector_ExtractStridedSliceOp :
 
     ```mlir
     %1 = vector.extract_strided_slice %0
-        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
-      vector<4x8x16xf32> to vector<2x4x16xf32>
-
-    // TODO: Evolve to a range form syntax similar to:
-    %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
-      vector<4x8x16xf32> to vector<2x4x16xf32>
+        {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
+        : vector<4x8x16xf32> to vector<2x4x16xf32>
     ```
   }];
   let builders = [
@@ -1217,17 +1215,20 @@ def Vector_ExtractStridedSliceOp :
     VectorType getSourceVectorType() {
       return ::llvm::cast<VectorType>(getVector().getType());
     }
-    void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
-      return llvm::any_of(getStrides(), [](Attribute attr) {
-        return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+      return llvm::any_of(getStrides(), [](int64_t stride) {
+        return stride != 1;
       });
     }
+
+    ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+    ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+    ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+  let assemblyFormat = "$vector $strided_slice attr-dict `:` type($vector) `to` type(results)";
 }
 
 // TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index e059d31ca5842..682414c63c06a 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
   return success();
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
 static LogicalResult
 convertExtractStridedSlice(RewriterBase &rewriter,
                            vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
   auto sourceVector = it->second;
 
   // offset and sizes at warp-level of onwership.
-  SmallVector<int64_t> offsets;
-  populateFromInt64AttrArray(op.getOffsets(), offsets);
+  ArrayRef<int64_t> offsets = op.getOffsets();
 
-  SmallVector<int64_t> sizes;
-  populateFromInt64AttrArray(op.getSizes(), sizes);
   ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..4d4e5ebb4f428 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
 static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
   return cast<IntegerAttr>(attr[0]).getInt();
 }
-static uint64_t getFirstIntValue(ArrayAttr attr) {
-  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
 static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
   auto attr = foldResults[0].dyn_cast<Attribute>();
   if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
     if (!dstType)
       return failure();
 
-    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
-    uint64_t size = getFirstIntValue(extractOp.getSizes());
-    uint64_t stride = getFirstIntValue(extractOp.getStrides());
+    int64_t offset = extractOp.getOffsets().front();
+    int64_t size = extractOp.getSizes().front();
+    int64_t stride = extractOp.getStrides().front();
     if (stride != 1)
       return failure();
 
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    uint64_t stride = getFirstIntValue(insertOp.getStrides());
+    uint64_t stride = insertOp.getStrides().front();
     if (stride != 1)
       return failure();
-    uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+    uint64_t offset = insertOp.getOffsets().front();
 
     if (isa<spirv::ScalarType>(srcVector.getType())) {
       assert(!isa<spirv::ScalarType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 70fd9bc0a1e68..2b0e6445dfda1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -549,11 +549,8 @@ struct ExtensionOverExtractStridedSlice final
     if (failed(ext))
       return failure();
 
-    VectorType origTy = op.getType();
-    VectorType extractTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+        op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
         op.getStrides());
     ext->recreateAndReplace(rewriter, op, newExtract);
     return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab3..3a0d30098c369 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
   return success();
 }
 
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(llvm::map_range(
-      arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     return Value();
 
   // Trim offsets for dimensions fully extracted.
-  auto sliceOffsets =
-      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+  SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
   while (!sliceOffsets.empty()) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
-    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+    ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
     ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
 
-    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getInt() != 1;
-        }))
+    if (insertOp.hasNonUnitStrides())
       return Value();
     bool disjoint = false;
     SmallVector<int64_t, 4> offsetDiffs;
@@ -2195,12 +2185,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(foldExtractFromFromElements);
 }
 
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
-                                       SmallVectorImpl<int64_t> &results) {
-  for (auto attr : arrayAttr)
-    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
-}
-
 //===----------------------------------------------------------------------===//
 // FmaOp
 //===----------------------------------------------------------------------===//
@@ -2907,26 +2891,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> strides) {
-  result.addOperands({source, dest});
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(dest.getType());
-  result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
-                                                        ArrayAttr arrayAttr,
-                                                        ArrayRef<int64_t> shape,
-                                                        StringRef attrName) {
-  if (arrayAttr.size() > shape.size())
-    return op.emitOpError("expected ")
-           << attrName << " attribute of rank no greater than vector rank";
-  return success();
+  build(builder, result, source, dest,
+        StridedSliceAttr::get(builder.getContext(), offsets, strides));
 }
 
 // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
@@ -2934,16 +2900,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
-                                  int64_t max, StringRef attrName,
-                                  bool halfOpen = true) {
-  for (auto attr : arrayAttr) {
-    auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
+                          int64_t max, StringRef arrayName,
+                          bool halfOpen = true) {
+  for (int64_t val : array) {
     auto upper = max;
     if (!halfOpen)
       upper += 1;
     if (val < min || val >= upper)
-      return op.emitOpError("expected ") << attrName << " to be confined to ["
+      return op.emitOpError("expected ") << arrayName << " to be confined to ["
                                          << min << ", " << upper << ")";
   }
   return success();
@@ -2954,13 +2919,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
 // Otherwise, the admissible interval is [min, max].
 template <typename OpType>
 static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
-                                  ArrayRef<int64_t> shape, StringRef attrName,
-                                  bool halfOpen = true, int64_t min = 0) {
-  for (auto [index, attrDimPair] :
-       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
-    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
-    int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+                          ArrayRef<int64_t> shape, StringRef attrName,
+                          bool halfOpen = true, int64_t min = 0) {
+  for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+    int64_t val, max;
+    std::tie(val, max) = dimPair;
     if (!halfOpen)
       max += 1;
     if (val < min || val >= max)
@@ -2977,40 +2941,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
 // the admissible interval is [min, max].
 template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
-    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
-    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+    OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+    ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
     bool halfOpen = true, int64_t min = 1) {
-  assert(arrayAttr1.size() <= shape.size());
-  assert(arrayAttr2.size() <= shape.size());
-  for (auto [index, it] :
-       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
-    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
-    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
-    int64_t max = std::get<2>(it);
+  assert(array1.size() <= shape.size());
+  assert(array2.size() <= shape.size());
+  for (auto [index, it] : llvm::enumerate(llvm::zip(array1, array2, shape))) {
+    int64_t val1, val2, max;
+    std::tie(val1, val2, max) = it;
     if (!halfOpen)
       max += 1;
     if (val1 + val2 < 0 || val1 + val2 >= max)
       return op.emitOpError("expected sum(")
-             << attrName1 << ", " << attrName2 << ") dimension " << index
+             << arrayName1 << ", " << arrayName2 << ") dimension " << index
              << " to be confined to [" << min << ", " << max << ")";
   }
   return success();
 }
 
-static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
-                                  MLIRContext *context) {
-  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
-    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
-  });
-  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
-}
-
 LogicalResult InsertStridedSliceOp::verify() {
   auto sourceVectorType = getSourceVectorType();
   auto destVectorType = getDestVectorType();
-  auto offsets = getOffsetsAttr();
-  auto strides = getStridesAttr();
+  auto offsets = getOffsets();
+  auto strides = getStrides();
+  if (!getStridedSlice().getSizes().empty())
+    return emitOpError("slice sizes not supported");
   if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
     return emitOpError(
         "expected offsets of same size as destination vector rank");
@@ -3025,18 +2981,14 @@ LogicalResult InsertStridedSliceOp::verify() {
   SmallVector<int64_t, 4> sourceShapeAsDestShape(
       destShape.size() - sourceShape.size(), 0);
   sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
-  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
-  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
-  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
-                                               offName)) ||
-      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
-                                               /*max=*/1, stridesName,
-                                               /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(
-          *this, offsets,
-          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
-          offName, "source vector shape",
-          /*halfOpen=*/false, /*min=*/1)))
+  if (failed(isIntArrayConfinedToShape(*this, offsets, destShape, "offsets")) ||
+      failed(isIntArrayConfinedToRange(*this, strides, /*min=*/1,
+                                       /*max=*/1, "strides",
+                                       /*halfOpen=*/false)) ||
+      failed(isSumOfIntArrayConfinedToShape(*this, offsets,
+                                            sourceShapeAsDestShape, destShape,
+                                            "offsets", "source vector shape",
+                                            /*halfOpen=*/false, /*min=*/1)))
     return failure();
 
   unsigned rankDiff = destShape.size() - sourceShape.size();
@@ -3161,7 +3113,7 @@ class InsertStridedSliceConstantFolder final
     VectorType sliceVecTy = sourceValue.getType();
     ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
     int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
-    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
+    ArrayRef<int64_t> offsets = op.getOffsets();
     SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
 
     // Calcualte the destination element indices by enumerating all slice
@@ -3336,14 +3288,15 @@ Type OuterProductOp::getExpectedMaskType() {
 //   2. Add sizes from 'vectorType' for remaining dims.
 // Scalable flags are inherited from 'vectorType'.
 static Type inferStridedSliceOpResultType(VectorType vectorType,
-                                          ArrayAttr offsets, ArrayAttr sizes,
-                                          ArrayAttr strides) {
+                                          ArrayRef<int64_t> offsets,
+                                          ArrayRef<int64_t> sizes,
+                                          ArrayRef<int64_t> strides) {
   assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
   SmallVector<int64_t, 4> shape;
   shape.reserve(vectorType.getRank());
   unsigned idx = 0;
   for (unsigned e = offsets.size(); idx < e; ++idx)
-    shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
+    shape.push_back(sizes[idx]);
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 
@@ -3356,51 +3309,48 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                   ArrayRef<int64_t> sizes,
                                   ArrayRef<int64_t> strides) {
   result.addOperands(source);
-  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
-  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
-  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
-  result.addTypes(
-      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
-                                    offsetsAttr, sizesAttr, stridesAttr));
-  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
-                      offsetsAttr);
-  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
-                      sizesAttr);
-  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
-                      stridesAttr);
+  auto stridedSliceAttr =
+      StridedSliceAttr::get(builder.getContext(), offsets, sizes, strides);
+  result.addTypes(inferStridedSliceOpResultType(
+      llvm::cast<VectorType>(source.getType()), offsets, sizes, strides));
+  result.addAttribute(
+      ExtractStridedSliceOp::getStridedSliceAttrName(result.name),
+      stridedSliceAttr);
 }
 
 LogicalResult ExtractStridedSliceOp::verify() {
   auto type = getSourceVectorType();
-  auto offsets = getOffsetsAttr();
-  auto sizes = getSizesAttr();
-  auto strides = getStridesAttr();
+  auto offsets = getOffsets();
+  auto sizes = getSizes();
+  auto strides = getStrides();
   if (offsets.size() != sizes.size() || offsets.size() != strides.size())
     return emitOpError(
         "expected offsets, sizes and strides attributes of same size");
 
   auto shape = type.getShape();
-  auto offName = getOffsetsAttrName();
-  auto sizesName = getSizesAttrName();
-  auto stridesName = getStridesAttrName();
-  if (failed(
-          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
-      failed(
-          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
-      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
-                                                stridesName)) ||
-      failed(
-          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
-      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
-                                               /*halfOpen=*/false,
-                                               /*min=*/1)) ||
-      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
-                                               /*max=*/1, stridesName,
-                                               /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
-                                                    shape, offName, sizesName,
-                                                    /*halfOpen=*/false)))
+  auto isIntArraySmallerThanShape = [&](ArrayRef<int64_t> array,
+                                        StringRef arrayName) -> LogicalResult {
+    if (array.size() > shape.size())
+      return emitOpError("expected ")
+             << arrayName << " to have rank no greater than vector rank";
+    return success();
+  };
+
+  if (failed(isIntArraySmallerThanShape(offsets, "offsets")) ||
+      failed(isIntArraySmallerThanShape(sizes, "sizes")) ||
+      failed(isIntArraySmallerThanShape(strides, "strides")) ||
+      failed(isIntArrayConfinedToShape(*this, offsets, shape, "offsets")) ||
+      failed(isIntArrayConfinedToShape(*this, sizes, shape, "sizes",
+                                       /*halfOpen=*/false,
+                                       /*min=*/1)) ||
+      failed(isIntArrayConfinedToRange(*this, strides, /*min=*/1,
+                                       /*max=*/1, "strides",
+                                       /*halfOpen=*/false)) ||
+      failed(isSumOfIntArrayConfinedToShape(*this, offsets, sizes, shape,
+                                            "offsets", "sizes",
+                                            /*halfOpen=*/false))) {
     return failure();
+  }
 
   auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                   offsets, sizes, strides);
@@ -3410,7 +3360,7 @@ LogicalResult ExtractStridedSliceOp::verify() {
   for (unsigned idx = 0; idx < sizes.size(); ++idx) {
     if (type.getScalableDims()[idx]) {
       auto inputDim = type.getShape()[idx];
-      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+      auto inputSize = sizes[idx];
       if (inputDim != inputSize)
         return emitOpError("expected size at idx=")
                << idx
@@ -3428,20 +3378,16 @@ LogicalResult ExtractStridedSliceOp::verify() {
 // extracted vector is a subset of one of the vector inserted.
 static LogicalResult
 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
-  // Helper to extract integer out of ArrayAttr.
-  auto getElement = [](ArrayAttr array, int idx) {
-    return llvm::cast<IntegerAttr>(array[idx]).getInt();
-  };
-  ArrayAttr extractOffsets = op.getOffsets();
-  ArrayAttr extractStrides = op.getStrides();
-  ArrayAttr extractSizes = op.getSizes();
+  ArrayRef<int64_t> extractOffsets = op.getOffsets();
+  ArrayRef<int64_t> extractStrides = op.getStrides();
+  ArrayRef<int64_t> extractSizes = op.getSizes();
   auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
   while (insertOp) {
     if (op.getSourceVectorType().getRank() !=
         insertOp.getSourceVectorType().getRank())
       return failure();
-    ArrayAttr insertOffsets = insertOp.getOffsets();
-    ArrayAttr insertStrides = insertOp.getStrides();
+    ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
+    ArrayRef<int64_t> insertStrides = insertOp.getStrides();
     // If the rank of extract is greater than the rank of insert, we are likely
     // extracting a partial chunk of the vector inserted.
     if (extractOffsets.size() > insertOffsets.size())
@@ -3450,12 +3396,12 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
     bool disjoint = false;
     SmallVector<int64_t, 4> offsetDiffs;
     for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
-      if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
+      if (extractStrides[dim] != insertStrides[dim])
         return failure();
-      int64_t start = getElement(insertOffsets, dim);
+      int64_t start = insertOffsets[dim];
       int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
-      int64_t offset = getElement(extractOffsets, dim);
-      int64_t size = getElement(extractSizes, dim);
+      int64_t offset = extractOffsets[dim];
+      int64_t size = extractSizes[dim];
       // Check if the start of the extract offset is in the interval inserted.
       if (start <= offset && offset < end) {
         // If the extract interval overlaps but is not fully included we may
@@ -3473,7 +3419,9 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
       op.setOperand(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
       OpBuilder b(op.getContext());
-      op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
+      auto stridedSliceAttr = StridedSliceAttr::get(
+          op.getContext(), offsetDiffs, op.getSizes(), op.getStrides());
+      op.setStridedSliceAttr(stridedSliceAttr);
       return success();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep looking
@@ -3496,11 +3444,6 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getResult();
   return {};
 }
-
-void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(getOffsets(), results);
-}
-
 namespace {
 
 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
@@ -3524,11 +3467,8 @@ class StridedSliceConstantMaskFolder final
     // Gather constant mask dimension sizes.
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     // Gather strided slice offsets and sizes.
-    SmallVector<int64_t, 4> sliceOffsets;
-    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
-                               sliceOffsets);
-    SmallVector<int64_t, 4> sliceSizes;
-    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+    ArrayRef<int64_t> sliceOffsets = extractStridedSliceOp.getOffsets();
+    ArrayRef<int64_t> sliceSizes = extractStridedSliceOp.getSizes();
 
     // Compute slice of vector mask region.
     SmallVector<int64_t, 4> sliceMaskDimSizes;
@@ -3620,10 +3560,10 @@ class StridedSliceNonSplatConstantFolder final
 
     // Expand offsets and sizes to match the vector rank.
     SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
+    copy(extractStridedSliceOp.getOffsets(), offsets.begin());
 
     SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+    copy(extractStridedSliceOp.getSizes(), sizes.begin());
 
     // Calculate the slice elements by enumerating all slice positions and
     // linearizing them. The enumeration order is lexicographic which yields a
@@ -3686,10 +3626,9 @@ class StridedSliceBroadcast final
     bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
     if (!lowerDimMatch && !isScalarSrc) {
       source = rewriter.create<ExtractStridedSliceOp>(
-          op->getLoc(), source,
-          getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+          op->getLoc(), source, op.getOffsets().drop_front(rankDiff),
+          op.getSizes().drop_front(rankDiff),
+          op.getStrides().drop_front(rankDiff));
     }
     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index a1f67bd0e9ed3..9d18fe74178c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -130,23 +130,16 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
     VectorType initialValueType = scanOp.getInitialValueType();
     int64_t initialValueRank = initialValueType.getRank();
 
-    SmallVector<int64_t> reductionShape(destShape);
-    reductionShape[reductionDim] = 1;
-    VectorType reductionType = VectorType::get(reductionShape, elType);
     SmallVector<int64_t> offsets(destRank, 0);
     SmallVector<int64_t> strides(destRank, 1);
     SmallVector<int64_t> sizes(destShape);
     sizes[reductionDim] = 1;
-    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
-    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
 
     Value lastOutput, lastInput;
     for (int i = 0; i < destShape[reductionDim]; i++) {
       offsets[reductionDim] = i;
-      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
-          scanStrides);
+          loc, scanOp.getSource(), offsets, sizes, strides);
       Value output;
       if (i == 0) {
         if (inclusive) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 42ac717b44c4b..0b8a2ab6b2fa0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -71,11 +71,6 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
 
     VectorType oldDstType = extractOp.getType();
-    VectorType newDstType =
-        VectorType::get(oldDstType.getShape().drop_front(dropCount),
-                        oldDstType.getElementType(),
-                        oldDstType.getScalableDims().drop_front(dropCount));
-
     Location loc = extractOp.getLoc();
 
     Value newSrcVector = rewriter.create<vector::ExtractOp>(
@@ -83,15 +78,12 @@ struct CastAwayExtractStridedSliceLeadingOneDim
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
-    auto newOffsets = rewriter.getArrayAttr(
-        extractOp.getOffsets().getValue().drop_front(dropCount));
-    auto newSizes = rewriter.getArrayAttr(
-        extractOp.getSizes().getValue().drop_front(dropCount));
-    auto newStrides = rewriter.getArrayAttr(
-        extractOp.getStrides().getValue().drop_front(dropCount));
+    auto newOffsets = extractOp.getOffsets().drop_front(dropCount);
+    auto newSizes = extractOp.getSizes().drop_front(dropCount);
+    auto newStrides = extractOp.getStrides().drop_front(dropCount);
 
     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+        loc, newSrcVector, newOffsets, newSizes, newStrides);
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                      newExtractOp);
@@ -126,13 +118,11 @@ struct CastAwayInsertStridedSliceLeadingOneDim
     Value newDstVector = rewriter.create<vector::ExtractOp>(
         loc, insertOp.getDest(), splatZero(dstDropCount));
 
-    auto newOffsets = rewriter.getArrayAttr(
-        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
-    auto newStrides = rewriter.getArrayAttr(
-        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
+    auto newOffsets = insertOp.getOffsets().take_back(newDstType.getRank());
+    auto newStrides = insertOp.getStrides().take_back(newSrcType.getRank());
 
     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
-        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+        loc, newSrcVector, newDstVector, newOffsets, newStrides);
 
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c..4de58ed7526a9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -63,7 +63,7 @@ class DecomposeDifferentRankInsertStridedSlice
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
 
-    if (op.getOffsets().getValue().empty())
+    if (op.getOffsets().empty())
       return failure();
 
     auto loc = op.getLoc();
@@ -76,21 +76,17 @@ class DecomposeDifferentRankInsertStridedSlice
     // Extract / insert the subvector of matching rank and InsertStridedSlice
     // on it.
     Value extracted = rewriter.create<ExtractOp>(
-        loc, op.getDest(),
-        getI64SubArray(op.getOffsets(), /*dropFront=*/0,
-                       /*dropBack=*/rankRest));
+        loc, op.getDest(), op.getOffsets().drop_back(rankRest));
 
     // A different pattern will kick in for InsertStridedSlice with matching
     // ranks.
     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
-        loc, op.getSource(), extracted,
-        getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
-        getI64SubArray(op.getStrides(), /*dropFront=*/0));
-
-    rewriter.replaceOpWithNewOp<InsertOp>(
-        op, stridedSliceInnerOp.getResult(), op.getDest(),
-        getI64SubArray(op.getOffsets(), /*dropFront=*/0,
-                       /*dropBack=*/rankRest));
+        loc, op.getSource(), extracted, op.getOffsets().drop_front(rankDiff),
+        op.getStrides());
+
+    rewriter.replaceOpWithNewOp<InsertOp>(op, stridedSliceInnerOp.getResult(),
+                                          op.getDest(),
+                                          op.getOffsets().drop_back(rankRest));
     return success();
   }
 };
@@ -119,7 +115,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
 
-    if (op.getOffsets().getValue().empty())
+    if (op.getOffsets().empty())
       return failure();
 
     int64_t srcRank = srcType.getRank();
@@ -133,11 +129,9 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
       return success();
     }
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
     int64_t size = srcType.getShape().front();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t stride = op.getStrides().front();
 
     auto loc = op.getLoc();
     Value res = op.getDest();
@@ -181,9 +175,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
         // smaller rank.
         extractedSource = rewriter.create<InsertStridedSliceOp>(
-            loc, extractedSource, extractedDest,
-            getI64SubArray(op.getOffsets(), /* dropFront=*/1),
-            getI64SubArray(op.getStrides(), /* dropFront=*/1));
+            loc, extractedSource, extractedDest, op.getOffsets().drop_front(1),
+            op.getStrides().drop_front(1));
       }
       // 4. Insert the extractedSource into the res vector.
       res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -205,18 +198,16 @@ class Convert1DExtractStridedSliceIntoShuffle
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
 
-    assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
+    assert(!op.getOffsets().empty() && "Unexpected empty offsets");
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
-    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
+    int64_t size = op.getSizes().front();
+    int64_t stride = op.getStrides().front();
 
     assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
 
     // Single offset can be more efficiently shuffled.
-    if (op.getOffsets().getValue().size() != 1)
+    if (op.getOffsets().size() != 1)
       return failure();
 
     SmallVector<int64_t, 4> offsets;
@@ -248,14 +239,12 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
       return failure();
 
     // Only handle 1-D cases.
-    if (op.getOffsets().getValue().size() != 1)
+    if (op.getOffsets().size() != 1)
       return failure();
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
-    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
+    int64_t size = op.getSizes().front();
+    int64_t stride = op.getStrides().front();
 
     Location loc = op.getLoc();
     SmallVector<Value> elements;
@@ -294,13 +283,11 @@ class DecomposeNDExtractStridedSlice
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
 
-    assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
+    assert(!op.getOffsets().empty() && "Unexpected empty offsets");
 
-    int64_t offset =
-        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
-    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
-    int64_t stride =
-        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+    int64_t offset = op.getOffsets().front();
+    int64_t size = op.getSizes().front();
+    int64_t stride = op.getStrides().front();
 
     auto loc = op.getLoc();
     auto elemType = dstType.getElementType();
@@ -308,7 +295,7 @@ class DecomposeNDExtractStridedSlice
 
     // Single offset can be more efficiently shuffled. It's handled in
     // Convert1DExtractStridedSliceIntoShuffle.
-    if (op.getOffsets().getValue().size() == 1)
+    if (op.getOffsets().size() == 1)
       return failure();
 
     // Extract/insert on a lower ranked extract strided slice op.
@@ -319,9 +306,8 @@ class DecomposeNDExtractStridedSlice
          off += stride, ++idx) {
       Value one = extractOne(rewriter, loc, op.getVector(), off);
       Value extracted = rewriter.create<ExtractStridedSliceOp>(
-          loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
-          getI64SubArray(op.getSizes(), /* dropFront=*/1),
-          getI64SubArray(op.getStrides(), /* dropFront=*/1));
+          loc, one, op.getOffsets().drop_front(), op.getSizes().drop_front(),
+          op.getStrides().drop_front());
       res = insertOne(rewriter, loc, extracted, res, idx);
     }
     rewriter.replaceOp(op, res);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 868397f2daaae..dbcdc6ea8f31a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -160,10 +160,10 @@ struct LinearizeVectorExtractStridedSlice final
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
-    ArrayAttr offsets = extractOp.getOffsets();
-    ArrayAttr sizes = extractOp.getSizes();
-    ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
+    ArrayRef<int64_t> offsets = extractOp.getOffsets();
+    ArrayRef<int64_t> sizes = extractOp.getSizes();
+    ArrayRef<int64_t> strides = extractOp.getStrides();
+    if (strides[0] != 1)
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
     Value srcVector = adaptor.getVector();
@@ -185,8 +185,8 @@ struct LinearizeVectorExtractStridedSlice final
     }
     // Get total number of extracted slices.
     int64_t nExtractedSlices = 1;
-    for (Attribute size : sizes) {
-      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+    for (int64_t size : sizes) {
+      nExtractedSlices *= size;
     }
     // Compute the strides of the source vector considering first k dimensions.
     llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
@@ -202,8 +202,7 @@ struct LinearizeVectorExtractStridedSlice final
     llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
     // Compute extractedStrides.
     for (int i = kD - 2; i >= 0; --i) {
-      extractedStrides[i] =
-          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+      extractedStrides[i] = extractedStrides[i + 1] * sizes[i + 1];
     }
     // Iterate over all extracted slices from 0 to nExtractedSlices - 1
     // and compute the multi-dimensional index and the corresponding linearized
@@ -220,9 +219,7 @@ struct LinearizeVectorExtractStridedSlice final
       // i.e. shift the multiDimIndex by the offsets.
       int64_t linearizedIndex = 0;
       for (int64_t j = 0; j < kD; ++j) {
-        linearizedIndex +=
-            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
-            sourceStrides[j];
+        linearizedIndex += (offsets[j] + multiDimIndex[j]) * sourceStrides[j];
       }
       // Fill the indices array form linearizedIndex to linearizedIndex +
       // extractGranularitySize.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8..75820162dd9d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -548,13 +548,6 @@ struct ReorderElementwiseOpsOnTranspose final
   }
 };
 
-// Returns the values in `arrayAttr` as an integer vector.
-static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
-                      [](IntegerAttr attr) { return attr.getInt(); }));
-}
-
 // Shuffles vector.bitcast op after vector.extract op.
 //
 // This transforms IR like:
@@ -661,8 +654,7 @@ struct BubbleDownBitCastForStridedSliceExtract
       return failure();
 
     // Only accept all one strides for now.
-    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
-                     [](const APInt &val) { return !val.isOne(); }))
+    if (extractOp.hasNonUnitStrides())
       return failure();
 
     unsigned rank = extractOp.getSourceVectorType().getRank();
@@ -673,34 +665,24 @@ struct BubbleDownBitCastForStridedSliceExtract
     // are selecting the full range for the last bitcasted dimension; other
     // dimensions aren't affected. Otherwise, we need to scale down the last
     // dimension's offset given we are extracting from less elements now.
-    ArrayAttr newOffsets = extractOp.getOffsets();
+    SmallVector<int64_t> newOffsets(extractOp.getOffsets());
     if (newOffsets.size() == rank) {
-      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
-      if (offsets.back() % expandRatio != 0)
+      if (newOffsets.back() % expandRatio != 0)
         return failure();
-      offsets.back() = offsets.back() / expandRatio;
-      newOffsets = rewriter.getI64ArrayAttr(offsets);
+      newOffsets.back() = newOffsets.back() / expandRatio;
     }
 
     // Similarly for sizes.
-    ArrayAttr newSizes = extractOp.getSizes();
+    SmallVector<int64_t> newSizes(extractOp.getSizes());
     if (newSizes.size() == rank) {
-      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
-      if (sizes.back() % expandRatio != 0)
+      if (newSizes.back() % expandRatio != 0)
         return failure();
-      sizes.back() = sizes.back() / expandRatio;
-      newSizes = rewriter.getI64ArrayAttr(sizes);
+      newSizes.back() = newSizes.back() / expandRatio;
     }
 
-    SmallVector<int64_t> dims =
-        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
-    dims.back() = dims.back() / expandRatio;
-    VectorType newExtractType =
-        VectorType::get(dims, castSrcType.getElementType());
-
     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
-        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
-        newSizes, extractOp.getStrides());
+        extractOp.getLoc(), castOp.getSource(), newOffsets, newSizes,
+        extractOp.getStrides());
 
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
         extractOp, extractOp.getType(), newExtractOp);
@@ -818,8 +800,7 @@ struct BubbleUpBitCastForStridedSliceInsert
       return failure();
 
     // Only accept all one strides for now.
-    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
-                     [](const APInt &val) { return !val.isOne(); }))
+    if (insertOp.hasNonUnitStrides())
       return failure();
 
     unsigned rank = insertOp.getSourceVectorType().getRank();
@@ -836,13 +817,11 @@ struct BubbleUpBitCastForStridedSliceInsert
     if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
       return failure();
 
-    ArrayAttr newOffsets = insertOp.getOffsets();
+    SmallVector<int64_t> newOffsets(insertOp.getOffsets());
     assert(newOffsets.size() == rank);
-    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
-    if (offsets.back() % shrinkRatio != 0)
+    if (newOffsets.back() % shrinkRatio != 0)
       return failure();
-    offsets.back() = offsets.back() / shrinkRatio;
-    newOffsets = rewriter.getI64ArrayAttr(offsets);
+    newOffsets.back() = newOffsets.back() / shrinkRatio;
 
     SmallVector<int64_t> srcDims =
         llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
@@ -863,7 +842,7 @@ struct BubbleUpBitCastForStridedSliceInsert
         bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
 
     rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
-        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
+        bitcastOp, newCastSrcOp, newCastDstOp, newOffsets,
         insertOp.getStrides());
 
     return success();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 10ba895a1b3a4..3d1862a9a889b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -683,7 +683,7 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 // -----
 
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
-  // expected-error at +1 {{expected offsets attribute of rank no greater than vector rank}}
+  // expected-error at +1 {{op expected offsets to have rank no greater than vector rank}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 916e3e5fd2529..93a7836b0192e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -172,18 +172,18 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 
   // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
   // BW-0: return %[[RES]] : vector<2x2xf32>
-  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+  %0 = vector.extract_strided_slice %arg0 { offsets = [0, 4], sizes = [2, 2], strides = [1, 1] }
      : vector<4x8xf32> to vector<2x2xf32>
   return %0 : vector<2x2xf32>
 }
 
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {  
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
   // ALL-NOT: vector.shuffle
   // ALL-NOT: vector.shape_cast
   // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
-  %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 0], sizes = [2, 8], strides = [1, 1] } : vector<4x[8]xf32> to vector<2x[8]xf32>
   // ALL: return %[[RES]] : vector<2x[8]xf32>
   return %0 : vector<2x[8]xf32>
 }
@@ -206,7 +206,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
 
   // BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
   // BW-0: return %[[RES]] : vector<1x4x2xf32>
-  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], sizes = [1, 4], strides = [1, 1] }
     : vector<2x8x2xf32> to vector<1x4x2xf32>
   return %0 : vector<1x4x2xf32>
 }



More information about the Mlir-commits mailing list