[Mlir-commits] [mlir] [mlir][vector] Update syntax and representation of insert/extract_strided_slice (PR #101850)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Aug 6 11:11:07 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/101850
>From c5b35e2d876ae67f81130e6045274b448b48e660 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 6 Aug 2024 15:49:31 +0000
Subject: [PATCH] [mlir][vector] Update representation of
insert/extract_strided_slice
This commit updates the representation of both extract_strided_slice and
insert_strided_slice to use primitive `int64_t` arrays rather than
ArrayAttrs of IntegerAttrs. This removes a lot of boilerplate
conversion between IntegerAttr and int64_t.
This is done by adding a new `StridedSliceAttr`, which matches the
previous syntax and can be used for both operations.
It may also be possible to explore an alternate slice syntax for the
`StridedSliceAttr` in the future.
---
.../Dialect/Vector/IR/VectorAttributes.td | 43 +++
.../mlir/Dialect/Vector/IR/VectorOps.td | 39 +--
.../Conversion/VectorToGPU/VectorToGPU.cpp | 11 +-
.../VectorToSPIRV/VectorToSPIRV.cpp | 13 +-
.../Dialect/Arith/Transforms/IntNarrowing.cpp | 5 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 249 +++++++-----------
.../Vector/Transforms/LowerVectorScan.cpp | 9 +-
.../Transforms/VectorDropLeadUnitDim.cpp | 24 +-
...sertExtractStridedSliceRewritePatterns.cpp | 72 ++---
.../Vector/Transforms/VectorLinearize.cpp | 19 +-
.../Vector/Transforms/VectorTransforms.cpp | 49 +---
mlir/test/Dialect/Vector/invalid.mlir | 2 +-
mlir/test/Dialect/Vector/linearize.mlir | 8 +-
13 files changed, 228 insertions(+), 315 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b257..e974e5bd046a9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/IR/EnumAttr.td"
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+ : AttrDef<Vector_Dialect, attrName, traits> {
+ let mnemonic = attrMnemonic;
+}
+
// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,42 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
let assemblyFormat = "`<` $value `>`";
}
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+ let summary = "strided vector slice";
+
+ let description = [{
+ An attribute that represents a strided slice of a vector.
+
+ *Examples:*
+
+ Without sizes:
+
+ `{offsets = [0, 0, 2], strides = [1, 1]}`
+
+ With sizes (used for extract_strided_slice):
+
+ `{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}`
+
+ TODO? Come up with a range syntax (similar to Python slices).
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$offsets,
+ OptionalArrayRefParameter<"int64_t">:$sizes,
+ ArrayRefParameter<"int64_t">:$strides
+ );
+
+ let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+ return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+ }]>
+ ];
+
+ let assemblyFormat = [{
+ `{` `offsets` `=` `[` $offsets `]` `,`
+ (`sizes` `=` `[` $sizes^ `]` `,`)?
+ `strides` `=` `[` $strides `]` `}`
+ }];
+}
+
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cd19d356a6739..5f9b4f6b29b0f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
- I64ArrayAttr:$strides)>,
+ Arguments<(ins AnyVector:$source, AnyVector:$dest,
+ Vector_StridedSliceAttr:$strided_slice)>,
Results<(outs AnyVector:$res)> {
let summary = "strided_slice operation";
let description = [{
@@ -1060,13 +1060,13 @@ def Vector_InsertStridedSliceOp :
```mlir
%2 = vector.insert_strided_slice %0, %1
- {offsets = [0, 0, 2], strides = [1, 1]}:
- vector<2x4xf32> into vector<16x4x8xf32>
+ {offsets = [0, 0, 2], strides = [1, 1]}
+ : vector<2x4xf32> into vector<16x4x8xf32>
```
}];
let assemblyFormat = [{
- $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+ $source `,` $dest $strided_slice attr-dict `:` type($source) `into` type($dest)
}];
let builders = [
@@ -1081,10 +1081,13 @@ def Vector_InsertStridedSliceOp :
return ::llvm::cast<VectorType>(getDest().getType());
}
bool hasNonUnitStrides() {
- return llvm::any_of(getStrides(), [](Attribute attr) {
- return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+ return llvm::any_of(getStrides(), [](int64_t stride) {
+ return stride != 1;
});
}
+
+ ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+ ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
}];
let hasFolder = 1;
@@ -1182,8 +1185,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
- I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+ Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
Results<(outs AnyVector)> {
let summary = "extract_strided_slice operation";
let description = [{
@@ -1201,12 +1203,8 @@ def Vector_ExtractStridedSliceOp :
```mlir
%1 = vector.extract_strided_slice %0
- {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
- vector<4x8x16xf32> to vector<2x4x16xf32>
-
- // TODO: Evolve to a range form syntax similar to:
- %1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
- vector<4x8x16xf32> to vector<2x4x16xf32>
+ {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
+ : vector<4x8x16xf32> to vector<2x4x16xf32>
```
}];
let builders = [
@@ -1217,17 +1215,20 @@ def Vector_ExtractStridedSliceOp :
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
- void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
- return llvm::any_of(getStrides(), [](Attribute attr) {
- return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+ return llvm::any_of(getStrides(), [](int64_t stride) {
+ return stride != 1;
});
}
+
+ ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+ ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+ ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+ let assemblyFormat = "$vector $strided_slice attr-dict `:` type($vector) `to` type(results)";
}
// TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index e059d31ca5842..682414c63c06a 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return success();
}
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
- SmallVectorImpl<int64_t> &results) {
- for (auto attr : arrayAttr)
- results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
auto sourceVector = it->second;
// offset and sizes at warp-level of onwership.
- SmallVector<int64_t> offsets;
- populateFromInt64AttrArray(op.getOffsets(), offsets);
+ ArrayRef<int64_t> offsets = op.getOffsets();
- SmallVector<int64_t> sizes;
- populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
// Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..4d4e5ebb4f428 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
return cast<IntegerAttr>(attr[0]).getInt();
}
-static uint64_t getFirstIntValue(ArrayAttr attr) {
- return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
auto attr = foldResults[0].dyn_cast<Attribute>();
if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
if (!dstType)
return failure();
- uint64_t offset = getFirstIntValue(extractOp.getOffsets());
- uint64_t size = getFirstIntValue(extractOp.getSizes());
- uint64_t stride = getFirstIntValue(extractOp.getStrides());
+ int64_t offset = extractOp.getOffsets().front();
+ int64_t size = extractOp.getSizes().front();
+ int64_t stride = extractOp.getStrides().front();
if (stride != 1)
return failure();
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front();
Value dstVector = adaptor.getOperands().back();
- uint64_t stride = getFirstIntValue(insertOp.getStrides());
+ uint64_t stride = insertOp.getStrides().front();
if (stride != 1)
return failure();
- uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+ uint64_t offset = insertOp.getOffsets().front();
if (isa<spirv::ScalarType>(srcVector.getType())) {
assert(!isa<spirv::ScalarType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 70fd9bc0a1e68..2b0e6445dfda1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -549,11 +549,8 @@ struct ExtensionOverExtractStridedSlice final
if (failed(ext))
return failure();
- VectorType origTy = op.getType();
- VectorType extractTy =
- origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
- op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+ op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
op.getStrides());
ext->recreateAndReplace(rewriter, op, newExtract);
return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab3..3a0d30098c369 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
return success();
}
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
- return llvm::to_vector<4>(llvm::map_range(
- arrayAttr.getAsRange<IntegerAttr>(),
- [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
return Value();
// Trim offsets for dimensions fully extracted.
- auto sliceOffsets =
- extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+ SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
while (!sliceOffsets.empty()) {
size_t lastOffset = sliceOffsets.size() - 1;
if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertOp.getSourceVectorType().getRank();
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
- auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+ ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
- if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getInt() != 1;
- }))
+ if (insertOp.hasNonUnitStrides())
return Value();
bool disjoint = false;
SmallVector<int64_t, 4> offsetDiffs;
@@ -2195,12 +2185,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add(foldExtractFromFromElements);
}
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
- SmallVectorImpl<int64_t> &results) {
- for (auto attr : arrayAttr)
- results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
-}
-
//===----------------------------------------------------------------------===//
// FmaOp
//===----------------------------------------------------------------------===//
@@ -2907,26 +2891,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> strides) {
- result.addOperands({source, dest});
- auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
- auto stridesAttr = getVectorSubscriptAttr(builder, strides);
- result.addTypes(dest.getType());
- result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
- offsetsAttr);
- result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
- stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
- ArrayAttr arrayAttr,
- ArrayRef<int64_t> shape,
- StringRef attrName) {
- if (arrayAttr.size() > shape.size())
- return op.emitOpError("expected ")
- << attrName << " attribute of rank no greater than vector rank";
- return success();
+ build(builder, result, source, dest,
+ StridedSliceAttr::get(builder.getContext(), offsets, strides));
}
// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
@@ -2934,16 +2900,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
- int64_t max, StringRef attrName,
- bool halfOpen = true) {
- for (auto attr : arrayAttr) {
- auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
+ int64_t max, StringRef arrayName,
+ bool halfOpen = true) {
+ for (int64_t val : array) {
auto upper = max;
if (!halfOpen)
upper += 1;
if (val < min || val >= upper)
- return op.emitOpError("expected ") << attrName << " to be confined to ["
+ return op.emitOpError("expected ") << arrayName << " to be confined to ["
<< min << ", " << upper << ")";
}
return success();
@@ -2954,13 +2919,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
- ArrayRef<int64_t> shape, StringRef attrName,
- bool halfOpen = true, int64_t min = 0) {
- for (auto [index, attrDimPair] :
- llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
- int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
- int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+ ArrayRef<int64_t> shape, StringRef attrName,
+ bool halfOpen = true, int64_t min = 0) {
+ for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+ int64_t val, max;
+ std::tie(val, max) = dimPair;
if (!halfOpen)
max += 1;
if (val < min || val >= max)
@@ -2977,40 +2941,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
// the admissible interval is [min, max].
template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
- OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
- ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+ OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+ ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
bool halfOpen = true, int64_t min = 1) {
- assert(arrayAttr1.size() <= shape.size());
- assert(arrayAttr2.size() <= shape.size());
- for (auto [index, it] :
- llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
- auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
- auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
- int64_t max = std::get<2>(it);
+ assert(array1.size() <= shape.size());
+ assert(array2.size() <= shape.size());
+ for (auto [index, it] : llvm::enumerate(llvm::zip(array1, array2, shape))) {
+ int64_t val1, val2, max;
+ std::tie(val1, val2, max) = it;
if (!halfOpen)
max += 1;
if (val1 + val2 < 0 || val1 + val2 >= max)
return op.emitOpError("expected sum(")
- << attrName1 << ", " << attrName2 << ") dimension " << index
+ << arrayName1 << ", " << arrayName2 << ") dimension " << index
<< " to be confined to [" << min << ", " << max << ")";
}
return success();
}
-static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
- MLIRContext *context) {
- auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
- return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
- });
- return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
-}
-
LogicalResult InsertStridedSliceOp::verify() {
auto sourceVectorType = getSourceVectorType();
auto destVectorType = getDestVectorType();
- auto offsets = getOffsetsAttr();
- auto strides = getStridesAttr();
+ auto offsets = getOffsets();
+ auto strides = getStrides();
+ if (!getStridedSlice().getSizes().empty())
+ return emitOpError("slice sizes not supported");
if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected offsets of same size as destination vector rank");
@@ -3025,18 +2981,14 @@ LogicalResult InsertStridedSliceOp::verify() {
SmallVector<int64_t, 4> sourceShapeAsDestShape(
destShape.size() - sourceShape.size(), 0);
sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
- auto offName = InsertStridedSliceOp::getOffsetsAttrName();
- auto stridesName = InsertStridedSliceOp::getStridesAttrName();
- if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
- offName)) ||
- failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
- /*max=*/1, stridesName,
- /*halfOpen=*/false)) ||
- failed(isSumOfIntegerArrayAttrConfinedToShape(
- *this, offsets,
- makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
- offName, "source vector shape",
- /*halfOpen=*/false, /*min=*/1)))
+ if (failed(isIntArrayConfinedToShape(*this, offsets, destShape, "offsets")) ||
+ failed(isIntArrayConfinedToRange(*this, strides, /*min=*/1,
+ /*max=*/1, "strides",
+ /*halfOpen=*/false)) ||
+ failed(isSumOfIntArrayConfinedToShape(*this, offsets,
+ sourceShapeAsDestShape, destShape,
+ "offsets", "source vector shape",
+ /*halfOpen=*/false, /*min=*/1)))
return failure();
unsigned rankDiff = destShape.size() - sourceShape.size();
@@ -3161,7 +3113,7 @@ class InsertStridedSliceConstantFolder final
VectorType sliceVecTy = sourceValue.getType();
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
- SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
+ ArrayRef<int64_t> offsets = op.getOffsets();
SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
// Calcualte the destination element indices by enumerating all slice
@@ -3336,14 +3288,15 @@ Type OuterProductOp::getExpectedMaskType() {
// 2. Add sizes from 'vectorType' for remaining dims.
// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
- ArrayAttr offsets, ArrayAttr sizes,
- ArrayAttr strides) {
+ ArrayRef<int64_t> offsets,
+ ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides) {
assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
SmallVector<int64_t, 4> shape;
shape.reserve(vectorType.getRank());
unsigned idx = 0;
for (unsigned e = offsets.size(); idx < e; ++idx)
- shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
+ shape.push_back(sizes[idx]);
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
shape.push_back(vectorType.getShape()[idx]);
@@ -3356,51 +3309,48 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
result.addOperands(source);
- auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
- auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
- auto stridesAttr = getVectorSubscriptAttr(builder, strides);
- result.addTypes(
- inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
- offsetsAttr, sizesAttr, stridesAttr));
- result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
- offsetsAttr);
- result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
- sizesAttr);
- result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
- stridesAttr);
+ auto stridedSliceAttr =
+ StridedSliceAttr::get(builder.getContext(), offsets, sizes, strides);
+ result.addTypes(inferStridedSliceOpResultType(
+ llvm::cast<VectorType>(source.getType()), offsets, sizes, strides));
+ result.addAttribute(
+ ExtractStridedSliceOp::getStridedSliceAttrName(result.name),
+ stridedSliceAttr);
}
LogicalResult ExtractStridedSliceOp::verify() {
auto type = getSourceVectorType();
- auto offsets = getOffsetsAttr();
- auto sizes = getSizesAttr();
- auto strides = getStridesAttr();
+ auto offsets = getOffsets();
+ auto sizes = getSizes();
+ auto strides = getStrides();
if (offsets.size() != sizes.size() || offsets.size() != strides.size())
return emitOpError(
"expected offsets, sizes and strides attributes of same size");
auto shape = type.getShape();
- auto offName = getOffsetsAttrName();
- auto sizesName = getSizesAttrName();
- auto stridesName = getStridesAttrName();
- if (failed(
- isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
- failed(
- isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
- failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
- stridesName)) ||
- failed(
- isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
- failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
- /*halfOpen=*/false,
- /*min=*/1)) ||
- failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
- /*max=*/1, stridesName,
- /*halfOpen=*/false)) ||
- failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
- shape, offName, sizesName,
- /*halfOpen=*/false)))
+ auto isIntArraySmallerThanShape = [&](ArrayRef<int64_t> array,
+ StringRef arrayName) -> LogicalResult {
+ if (array.size() > shape.size())
+ return emitOpError("expected ")
+ << arrayName << " to have rank no greater than vector rank";
+ return success();
+ };
+
+ if (failed(isIntArraySmallerThanShape(offsets, "offsets")) ||
+ failed(isIntArraySmallerThanShape(sizes, "sizes")) ||
+ failed(isIntArraySmallerThanShape(strides, "strides")) ||
+ failed(isIntArrayConfinedToShape(*this, offsets, shape, "offsets")) ||
+ failed(isIntArrayConfinedToShape(*this, sizes, shape, "sizes",
+ /*halfOpen=*/false,
+ /*min=*/1)) ||
+ failed(isIntArrayConfinedToRange(*this, strides, /*min=*/1,
+ /*max=*/1, "strides",
+ /*halfOpen=*/false)) ||
+ failed(isSumOfIntArrayConfinedToShape(*this, offsets, sizes, shape,
+ "offsets", "sizes",
+ /*halfOpen=*/false))) {
return failure();
+ }
auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
offsets, sizes, strides);
@@ -3410,7 +3360,7 @@ LogicalResult ExtractStridedSliceOp::verify() {
for (unsigned idx = 0; idx < sizes.size(); ++idx) {
if (type.getScalableDims()[idx]) {
auto inputDim = type.getShape()[idx];
- auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+ auto inputSize = sizes[idx];
if (inputDim != inputSize)
return emitOpError("expected size at idx=")
<< idx
@@ -3428,20 +3378,16 @@ LogicalResult ExtractStridedSliceOp::verify() {
// extracted vector is a subset of one of the vector inserted.
static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
- // Helper to extract integer out of ArrayAttr.
- auto getElement = [](ArrayAttr array, int idx) {
- return llvm::cast<IntegerAttr>(array[idx]).getInt();
- };
- ArrayAttr extractOffsets = op.getOffsets();
- ArrayAttr extractStrides = op.getStrides();
- ArrayAttr extractSizes = op.getSizes();
+ ArrayRef<int64_t> extractOffsets = op.getOffsets();
+ ArrayRef<int64_t> extractStrides = op.getStrides();
+ ArrayRef<int64_t> extractSizes = op.getSizes();
auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
while (insertOp) {
if (op.getSourceVectorType().getRank() !=
insertOp.getSourceVectorType().getRank())
return failure();
- ArrayAttr insertOffsets = insertOp.getOffsets();
- ArrayAttr insertStrides = insertOp.getStrides();
+ ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
+ ArrayRef<int64_t> insertStrides = insertOp.getStrides();
// If the rank of extract is greater than the rank of insert, we are likely
// extracting a partial chunk of the vector inserted.
if (extractOffsets.size() > insertOffsets.size())
@@ -3450,12 +3396,12 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
bool disjoint = false;
SmallVector<int64_t, 4> offsetDiffs;
for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
- if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
+ if (extractStrides[dim] != insertStrides[dim])
return failure();
- int64_t start = getElement(insertOffsets, dim);
+ int64_t start = insertOffsets[dim];
int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
- int64_t offset = getElement(extractOffsets, dim);
- int64_t size = getElement(extractSizes, dim);
+ int64_t offset = extractOffsets[dim];
+ int64_t size = extractSizes[dim];
// Check if the start of the extract offset is in the interval inserted.
if (start <= offset && offset < end) {
// If the extract interval overlaps but is not fully included we may
@@ -3473,7 +3419,9 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
op.setOperand(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
- op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
+ auto stridedSliceAttr = StridedSliceAttr::get(
+ op.getContext(), offsetDiffs, op.getSizes(), op.getStrides());
+ op.setStridedSliceAttr(stridedSliceAttr);
return success();
}
// If the chunk extracted is disjoint from the chunk inserted, keep looking
@@ -3496,11 +3444,6 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
return getResult();
return {};
}
-
-void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(getOffsets(), results);
-}
-
namespace {
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
@@ -3524,11 +3467,8 @@ class StridedSliceConstantMaskFolder final
// Gather constant mask dimension sizes.
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
- SmallVector<int64_t, 4> sliceOffsets;
- populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
- sliceOffsets);
- SmallVector<int64_t, 4> sliceSizes;
- populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+ ArrayRef<int64_t> sliceOffsets = extractStridedSliceOp.getOffsets();
+ ArrayRef<int64_t> sliceSizes = extractStridedSliceOp.getSizes();
// Compute slice of vector mask region.
SmallVector<int64_t, 4> sliceMaskDimSizes;
@@ -3620,10 +3560,10 @@ class StridedSliceNonSplatConstantFolder final
// Expand offsets and sizes to match the vector rank.
SmallVector<int64_t, 4> offsets(sliceRank, 0);
- copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
+ copy(extractStridedSliceOp.getOffsets(), offsets.begin());
SmallVector<int64_t, 4> sizes(sourceShape);
- copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+ copy(extractStridedSliceOp.getSizes(), sizes.begin());
// Calculate the slice elements by enumerating all slice positions and
// linearizing them. The enumeration order is lexicographic which yields a
@@ -3686,10 +3626,9 @@ class StridedSliceBroadcast final
bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
if (!lowerDimMatch && !isScalarSrc) {
source = rewriter.create<ExtractStridedSliceOp>(
- op->getLoc(), source,
- getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
- getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
- getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+ op->getLoc(), source, op.getOffsets().drop_front(rankDiff),
+ op.getSizes().drop_front(rankDiff),
+ op.getStrides().drop_front(rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index a1f67bd0e9ed3..9d18fe74178c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -130,23 +130,16 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
VectorType initialValueType = scanOp.getInitialValueType();
int64_t initialValueRank = initialValueType.getRank();
- SmallVector<int64_t> reductionShape(destShape);
- reductionShape[reductionDim] = 1;
- VectorType reductionType = VectorType::get(reductionShape, elType);
SmallVector<int64_t> offsets(destRank, 0);
SmallVector<int64_t> strides(destRank, 1);
SmallVector<int64_t> sizes(destShape);
sizes[reductionDim] = 1;
- ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
- ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
Value lastOutput, lastInput;
for (int i = 0; i < destShape[reductionDim]; i++) {
offsets[reductionDim] = i;
- ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
Value input = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
- scanStrides);
+ loc, scanOp.getSource(), offsets, sizes, strides);
Value output;
if (i == 0) {
if (inclusive) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 42ac717b44c4b..0b8a2ab6b2fa0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -71,11 +71,6 @@ struct CastAwayExtractStridedSliceLeadingOneDim
int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
VectorType oldDstType = extractOp.getType();
- VectorType newDstType =
- VectorType::get(oldDstType.getShape().drop_front(dropCount),
- oldDstType.getElementType(),
- oldDstType.getScalableDims().drop_front(dropCount));
-
Location loc = extractOp.getLoc();
Value newSrcVector = rewriter.create<vector::ExtractOp>(
@@ -83,15 +78,12 @@ struct CastAwayExtractStridedSliceLeadingOneDim
// The offsets/sizes/strides attribute can have a less number of elements
// than the input vector's rank: it is meant for the leading dimensions.
- auto newOffsets = rewriter.getArrayAttr(
- extractOp.getOffsets().getValue().drop_front(dropCount));
- auto newSizes = rewriter.getArrayAttr(
- extractOp.getSizes().getValue().drop_front(dropCount));
- auto newStrides = rewriter.getArrayAttr(
- extractOp.getStrides().getValue().drop_front(dropCount));
+ auto newOffsets = extractOp.getOffsets().drop_front(dropCount);
+ auto newSizes = extractOp.getSizes().drop_front(dropCount);
+ auto newStrides = extractOp.getStrides().drop_front(dropCount);
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+ loc, newSrcVector, newOffsets, newSizes, newStrides);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
newExtractOp);
@@ -126,13 +118,11 @@ struct CastAwayInsertStridedSliceLeadingOneDim
Value newDstVector = rewriter.create<vector::ExtractOp>(
loc, insertOp.getDest(), splatZero(dstDropCount));
- auto newOffsets = rewriter.getArrayAttr(
- insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
- auto newStrides = rewriter.getArrayAttr(
- insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
+ auto newOffsets = insertOp.getOffsets().take_back(newDstType.getRank());
+ auto newStrides = insertOp.getStrides().take_back(newSrcType.getRank());
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
- loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+ loc, newSrcVector, newDstVector, newOffsets, newStrides);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
newInsertOp);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c..4de58ed7526a9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -63,7 +63,7 @@ class DecomposeDifferentRankInsertStridedSlice
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
- if (op.getOffsets().getValue().empty())
+ if (op.getOffsets().empty())
return failure();
auto loc = op.getLoc();
@@ -76,21 +76,17 @@ class DecomposeDifferentRankInsertStridedSlice
// Extract / insert the subvector of matching rank and InsertStridedSlice
// on it.
Value extracted = rewriter.create<ExtractOp>(
- loc, op.getDest(),
- getI64SubArray(op.getOffsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
+ loc, op.getDest(), op.getOffsets().drop_back(rankRest));
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
- loc, op.getSource(), extracted,
- getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
- getI64SubArray(op.getStrides(), /*dropFront=*/0));
-
- rewriter.replaceOpWithNewOp<InsertOp>(
- op, stridedSliceInnerOp.getResult(), op.getDest(),
- getI64SubArray(op.getOffsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
+ loc, op.getSource(), extracted, op.getOffsets().drop_front(rankDiff),
+ op.getStrides());
+
+ rewriter.replaceOpWithNewOp<InsertOp>(op, stridedSliceInnerOp.getResult(),
+ op.getDest(),
+ op.getOffsets().drop_back(rankRest));
return success();
}
};
@@ -119,7 +115,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
- if (op.getOffsets().getValue().empty())
+ if (op.getOffsets().empty())
return failure();
int64_t srcRank = srcType.getRank();
@@ -133,11 +129,9 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
return success();
}
- int64_t offset =
- cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t offset = op.getOffsets().front();
int64_t size = srcType.getShape().front();
- int64_t stride =
- cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+ int64_t stride = op.getStrides().front();
auto loc = op.getLoc();
Value res = op.getDest();
@@ -181,9 +175,8 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
extractedSource = rewriter.create<InsertStridedSliceOp>(
- loc, extractedSource, extractedDest,
- getI64SubArray(op.getOffsets(), /* dropFront=*/1),
- getI64SubArray(op.getStrides(), /* dropFront=*/1));
+ loc, extractedSource, extractedDest, op.getOffsets().drop_front(1),
+ op.getStrides().drop_front(1));
}
// 4. Insert the extractedSource into the res vector.
res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -205,18 +198,16 @@ class Convert1DExtractStridedSliceIntoShuffle
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
- assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
+ assert(!op.getOffsets().empty() && "Unexpected empty offsets");
- int64_t offset =
- cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
- int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
- int64_t stride =
- cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+ int64_t offset = op.getOffsets().front();
+ int64_t size = op.getSizes().front();
+ int64_t stride = op.getStrides().front();
assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
// Single offset can be more efficiently shuffled.
- if (op.getOffsets().getValue().size() != 1)
+ if (op.getOffsets().size() != 1)
return failure();
SmallVector<int64_t, 4> offsets;
@@ -248,14 +239,12 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
return failure();
// Only handle 1-D cases.
- if (op.getOffsets().getValue().size() != 1)
+ if (op.getOffsets().size() != 1)
return failure();
- int64_t offset =
- cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
- int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
- int64_t stride =
- cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+ int64_t offset = op.getOffsets().front();
+ int64_t size = op.getSizes().front();
+ int64_t stride = op.getStrides().front();
Location loc = op.getLoc();
SmallVector<Value> elements;
@@ -294,13 +283,11 @@ class DecomposeNDExtractStridedSlice
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
- assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
+ assert(!op.getOffsets().empty() && "Unexpected empty offsets");
- int64_t offset =
- cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
- int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
- int64_t stride =
- cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
+ int64_t offset = op.getOffsets().front();
+ int64_t size = op.getSizes().front();
+ int64_t stride = op.getStrides().front();
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
@@ -308,7 +295,7 @@ class DecomposeNDExtractStridedSlice
// Single offset can be more efficiently shuffled. It's handled in
// Convert1DExtractStridedSliceIntoShuffle.
- if (op.getOffsets().getValue().size() == 1)
+ if (op.getOffsets().size() == 1)
return failure();
// Extract/insert on a lower ranked extract strided slice op.
@@ -319,9 +306,8 @@ class DecomposeNDExtractStridedSlice
off += stride, ++idx) {
Value one = extractOne(rewriter, loc, op.getVector(), off);
Value extracted = rewriter.create<ExtractStridedSliceOp>(
- loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
- getI64SubArray(op.getSizes(), /* dropFront=*/1),
- getI64SubArray(op.getStrides(), /* dropFront=*/1));
+ loc, one, op.getOffsets().drop_front(), op.getSizes().drop_front(),
+ op.getStrides().drop_front());
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, res);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 868397f2daaae..dbcdc6ea8f31a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -160,10 +160,10 @@ struct LinearizeVectorExtractStridedSlice final
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
- ArrayAttr offsets = extractOp.getOffsets();
- ArrayAttr sizes = extractOp.getSizes();
- ArrayAttr strides = extractOp.getStrides();
- if (!isConstantIntValue(strides[0], 1))
+ ArrayRef<int64_t> offsets = extractOp.getOffsets();
+ ArrayRef<int64_t> sizes = extractOp.getSizes();
+ ArrayRef<int64_t> strides = extractOp.getStrides();
+ if (strides[0] != 1)
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
Value srcVector = adaptor.getVector();
@@ -185,8 +185,8 @@ struct LinearizeVectorExtractStridedSlice final
}
// Get total number of extracted slices.
int64_t nExtractedSlices = 1;
- for (Attribute size : sizes) {
- nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+ for (int64_t size : sizes) {
+ nExtractedSlices *= size;
}
// Compute the strides of the source vector considering first k dimensions.
llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
@@ -202,8 +202,7 @@ struct LinearizeVectorExtractStridedSlice final
llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
// Compute extractedStrides.
for (int i = kD - 2; i >= 0; --i) {
- extractedStrides[i] =
- extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+ extractedStrides[i] = extractedStrides[i + 1] * sizes[i + 1];
}
// Iterate over all extracted slices from 0 to nExtractedSlices - 1
// and compute the multi-dimensional index and the corresponding linearized
@@ -220,9 +219,7 @@ struct LinearizeVectorExtractStridedSlice final
// i.e. shift the multiDimIndex by the offsets.
int64_t linearizedIndex = 0;
for (int64_t j = 0; j < kD; ++j) {
- linearizedIndex +=
- (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
- sourceStrides[j];
+ linearizedIndex += (offsets[j] + multiDimIndex[j]) * sourceStrides[j];
}
// Fill the indices array form linearizedIndex to linearizedIndex +
// extractGranularitySize.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8..75820162dd9d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -548,13 +548,6 @@ struct ReorderElementwiseOpsOnTranspose final
}
};
-// Returns the values in `arrayAttr` as an integer vector.
-static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
- return llvm::to_vector<4>(
- llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
- [](IntegerAttr attr) { return attr.getInt(); }));
-}
-
// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
@@ -661,8 +654,7 @@ struct BubbleDownBitCastForStridedSliceExtract
return failure();
// Only accept all one strides for now.
- if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
- [](const APInt &val) { return !val.isOne(); }))
+ if (extractOp.hasNonUnitStrides())
return failure();
unsigned rank = extractOp.getSourceVectorType().getRank();
@@ -673,34 +665,24 @@ struct BubbleDownBitCastForStridedSliceExtract
// are selecting the full range for the last bitcasted dimension; other
// dimensions aren't affected. Otherwise, we need to scale down the last
// dimension's offset given we are extracting from less elements now.
- ArrayAttr newOffsets = extractOp.getOffsets();
+ SmallVector<int64_t> newOffsets(extractOp.getOffsets());
if (newOffsets.size() == rank) {
- SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
- if (offsets.back() % expandRatio != 0)
+ if (newOffsets.back() % expandRatio != 0)
return failure();
- offsets.back() = offsets.back() / expandRatio;
- newOffsets = rewriter.getI64ArrayAttr(offsets);
+ newOffsets.back() = newOffsets.back() / expandRatio;
}
// Similarly for sizes.
- ArrayAttr newSizes = extractOp.getSizes();
+ SmallVector<int64_t> newSizes(extractOp.getSizes());
if (newSizes.size() == rank) {
- SmallVector<int64_t> sizes = getIntValueVector(newSizes);
- if (sizes.back() % expandRatio != 0)
+ if (newSizes.back() % expandRatio != 0)
return failure();
- sizes.back() = sizes.back() / expandRatio;
- newSizes = rewriter.getI64ArrayAttr(sizes);
+ newSizes.back() = newSizes.back() / expandRatio;
}
- SmallVector<int64_t> dims =
- llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
- dims.back() = dims.back() / expandRatio;
- VectorType newExtractType =
- VectorType::get(dims, castSrcType.getElementType());
-
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
- extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
- newSizes, extractOp.getStrides());
+ extractOp.getLoc(), castOp.getSource(), newOffsets, newSizes,
+ extractOp.getStrides());
rewriter.replaceOpWithNewOp<vector::BitCastOp>(
extractOp, extractOp.getType(), newExtractOp);
@@ -818,8 +800,7 @@ struct BubbleUpBitCastForStridedSliceInsert
return failure();
// Only accept all one strides for now.
- if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
- [](const APInt &val) { return !val.isOne(); }))
+ if (insertOp.hasNonUnitStrides())
return failure();
unsigned rank = insertOp.getSourceVectorType().getRank();
@@ -836,13 +817,11 @@ struct BubbleUpBitCastForStridedSliceInsert
if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
return failure();
- ArrayAttr newOffsets = insertOp.getOffsets();
+ SmallVector<int64_t> newOffsets(insertOp.getOffsets());
assert(newOffsets.size() == rank);
- SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
- if (offsets.back() % shrinkRatio != 0)
+ if (newOffsets.back() % shrinkRatio != 0)
return failure();
- offsets.back() = offsets.back() / shrinkRatio;
- newOffsets = rewriter.getI64ArrayAttr(offsets);
+ newOffsets.back() = newOffsets.back() / shrinkRatio;
SmallVector<int64_t> srcDims =
llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
@@ -863,7 +842,7 @@ struct BubbleUpBitCastForStridedSliceInsert
bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
- bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
+ bitcastOp, newCastSrcOp, newCastDstOp, newOffsets,
insertOp.getStrides());
return success();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 10ba895a1b3a4..3d1862a9a889b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -683,7 +683,7 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// -----
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
- // expected-error at +1 {{expected offsets attribute of rank no greater than vector rank}}
+ // expected-error at +1 {{op expected offsets to have rank no greater than vector rank}}
%1 = vector.extract_strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 916e3e5fd2529..93a7836b0192e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -172,18 +172,18 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
// BW-0: return %[[RES]] : vector<2x2xf32>
- %0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
+ %0 = vector.extract_strided_slice %arg0 { offsets = [0, 4], sizes = [2, 2], strides = [1, 1] }
: vector<4x8xf32> to vector<2x2xf32>
return %0 : vector<2x2xf32>
}
// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
// ALL-NOT: vector.shuffle
// ALL-NOT: vector.shape_cast
// ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
- %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
+ %0 = vector.extract_strided_slice %arg0 { offsets = [1, 0], sizes = [2, 8], strides = [1, 1] } : vector<4x[8]xf32> to vector<2x[8]xf32>
// ALL: return %[[RES]] : vector<2x[8]xf32>
return %0 : vector<2x[8]xf32>
}
@@ -206,7 +206,7 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
// BW-0: return %[[RES]] : vector<1x4x2xf32>
- %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
+ %0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], sizes = [1, 4], strides = [1, 1] }
: vector<2x8x2xf32> to vector<1x4x2xf32>
return %0 : vector<1x4x2xf32>
}
More information about the Mlir-commits
mailing list