[Mlir-commits] [mlir] [mlir][vector] Update syntax and representation of insert/extract_strided_slice (PR #101850)
llvmlistbot at llvm.org
Sat Aug 3 14:10:22 PDT 2024
llvmbot wrote:
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This commit changes the representation of the offsets, sizes, and strides of both `extract_strided_slice` and `insert_strided_slice` from `ArrayAttr`s of `IntegerAttr`s to plain arrays of `int64_t`. This removes much of the boilerplate needed to convert between `IntegerAttr` and `int64_t`.
Because the offsets, sizes, and strides were previously stored in the attribute dictionary (with no dedicated syntax), even simply replacing the attribute types with `DenseI64ArrayAttr` would break the textual syntax.
Since a syntax break is largely unavoidable, this commit also addresses a long-standing TODO:
```mlir
// TODO: Evolve to a range form syntax similar to:
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
: vector<4x8x16xf32> to vector<2x4x16xf32>
```
This is done by introducing a new `StridedSliceAttr` attribute, shared by both operations, whose syntax is based on the example above (see the attribute documentation in `VectorAttributes.td` for a full syntax overview).
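As a rough illustration of how the three parameters map onto the new syntax, the following standalone C++ sketch reproduces the logic of `StridedSliceAttr::print` from the patch, using `std::vector` in place of MLIR's `ArrayRef` and `AsmPrinter`. The function name `formatStridedSlice` is hypothetical and not part of MLIR. It assumes `offsets` has at least as many entries as `strides`, and that `sizes` is either empty (insert form) or the same length as `strides` (extract form).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical standalone helper mirroring StridedSliceAttr::print.
// `sizes` is empty for the insert_strided_slice form.
std::string formatStridedSlice(const std::vector<int64_t> &offsets,
                               const std::vector<int64_t> &sizes,
                               const std::vector<int64_t> &strides) {
  assert(offsets.size() >= strides.size());
  assert(sizes.empty() || sizes.size() == strides.size());
  std::ostringstream os;
  // Leading offsets with no matching stride are printed as one plain list.
  std::size_t nonStrided = offsets.size() - strides.size();
  if (nonStrided > 0) {
    os << '[';
    for (std::size_t i = 0; i < nonStrided; ++i)
      os << (i ? ", " : "") << offsets[i];
    os << ']';
  }
  // Every remaining dimension prints as [offset:stride] or [offset:size:stride].
  for (std::size_t d = nonStrided; d < offsets.size(); ++d) {
    std::size_t s = d - nonStrided;
    os << '[' << offsets[d] << ':';
    if (!sizes.empty())
      os << sizes[s] << ':';
    os << strides[s] << ']';
  }
  return os.str();
}
```

For example, `formatStridedSlice({0, 2}, {2, 4}, {1, 1})` yields `[0:2:1][2:4:1]`, and `formatStridedSlice({0, 0, 2}, {}, {1, 1})` yields `[0][0:1][2:1]`, matching the two examples that follow.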
With this:
`extract_strided_slice` goes from:
```mlir
%1 = vector.extract_strided_slice %0
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
: vector<4x8x16xf32> to vector<2x4x16xf32>
```
To:
```mlir
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
: vector<4x8x16xf32> to vector<2x4x16xf32>
```
(matching the TODO)
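To see why the result type in this example is `vector<2x4x16xf32>`, here is a minimal C++ sketch of the shape rule. It assumes (as we understand the current verifier) that all strides are 1: each sliced dimension takes its extent from `sizes`, and trailing dimensions not covered by the slice are kept whole. The helper `inferExtractResultShape` is hypothetical and only approximates the verifier's bounds checks.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: result shape of extract_strided_slice, assuming unit
// strides (so each sliced dimension has exactly sizes[i] elements).
std::vector<int64_t> inferExtractResultShape(const std::vector<int64_t> &sourceShape,
                                             const std::vector<int64_t> &offsets,
                                             const std::vector<int64_t> &sizes) {
  assert(offsets.size() == sizes.size() && sizes.size() <= sourceShape.size());
  std::vector<int64_t> result;
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    // The slice [offsets[i], offsets[i] + sizes[i]) must fit in the source dim.
    assert(offsets[i] >= 0 && sizes[i] > 0 &&
           offsets[i] + sizes[i] <= sourceShape[i]);
    result.push_back(sizes[i]);
  }
  // Trailing dimensions not covered by the slice are kept whole.
  result.insert(result.end(),
                sourceShape.begin() + static_cast<std::ptrdiff_t>(sizes.size()),
                sourceShape.end());
  return result;
}
```

With source shape `4x8x16`, offsets `[0, 2]`, and sizes `[2, 4]`, the sketch returns `{2, 4, 16}`, which is the `vector<2x4x16xf32>` result shown above.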
---
And `insert_strided_slice` goes from:
```mlir
%2 = vector.insert_strided_slice %0, %1
{offsets = [0, 0, 2], strides = [1, 1]}
: vector<2x4xf32> into vector<16x4x8xf32>
```
To:
```mlir
%2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
: vector<2x4xf32> into vector<16x4x8xf32>
```
(inspired by the TODO)
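In the new insert syntax, the leading bracket holds offsets for destination dimensions that have no counterpart in the source, and each following `[offset:stride]` pair lines up with one source dimension. The C++ sketch below checks those rank and bounds rules for the example above (source `2x4`, destination `16x4x8`). It is a hypothetical helper (`checkInsertStridedSlice`, with made-up error strings) based on our reading of the op's constraints, including the assumption that only unit strides are accepted. It is not the upstream verifier.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sketch of the rank/bounds rules for insert_strided_slice.
// Returns an empty string when valid, otherwise a short error message.
std::string checkInsertStridedSlice(const std::vector<int64_t> &sourceShape,
                                    const std::vector<int64_t> &destShape,
                                    const std::vector<int64_t> &offsets,
                                    const std::vector<int64_t> &strides) {
  if (sourceShape.size() > destShape.size())
    return "source rank exceeds dest rank";
  if (offsets.size() != destShape.size())
    return "expected one offset per dest dimension";
  if (strides.size() != sourceShape.size())
    return "expected one stride per source dimension";
  // Leading dest dims have no source counterpart: the offset alone selects a
  // position there, so it only has to lie in [0, destDim).
  std::size_t lead = destShape.size() - sourceShape.size();
  for (std::size_t d = 0; d < destShape.size(); ++d) {
    if (offsets[d] < 0 || offsets[d] >= destShape[d])
      return "offset out of bounds";
    if (d < lead)
      continue;
    // Trailing dims pair up with source dims: the whole source extent must fit.
    std::size_t s = d - lead;
    if (strides[s] != 1)
      return "only unit strides supported";
    if (offsets[d] + sourceShape[s] > destShape[d])
      return "source does not fit";
  }
  return "";
}
```

For the example, `checkInsertStridedSlice({2, 4}, {16, 4, 8}, {0, 0, 2}, {1, 1})` returns an empty string, while moving the last offset to 6 fails because 6 + 4 exceeds 8.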
---
Almost all test changes were made automatically with `auto-upgrade-insert-extract-slice.py`, available at https://gist.github.com/MacDue/ca84d3ec19cf83ae71aab2be8f09c3c5 (use at your own risk).
This PR is split into multiple commits to make the changes more understandable.
- The first commit contains the code changes
- The second commit contains the **automatic** test changes
- The final commit contains the manual test changes
---
Patch is 354.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101850.diff
44 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td (+64)
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+19-20)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-10)
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+5-8)
- (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-4)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+184-149)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp (+1-8)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+7-17)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+29-43)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+8-11)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+14-35)
- (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+1-1)
- (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (+15-15)
- (modified) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+54-54)
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+1-1)
- (modified) mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir (+4-4)
- (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir (+12-12)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10-10)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+4-4)
- (modified) mlir/test/Dialect/Arith/emulate-wide-int.mlir (+19-19)
- (modified) mlir/test/Dialect/Arith/int-narrowing.mlir (+16-16)
- (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+83-83)
- (modified) mlir/test/Dialect/ArmNeon/roundtrip.mlir (+2-2)
- (modified) mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir (+6-6)
- (modified) mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir (+10-10)
- (modified) mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (+19-19)
- (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+96-96)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+63-63)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+17-24)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+7-9)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+8-8)
- (modified) mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir (+12-12)
- (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+8-8)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+10-10)
- (modified) mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir (+2-2)
- (modified) mlir/test/Dialect/Vector/vector-scan-transforms.mlir (+20-20)
- (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+4-4)
- (modified) mlir/test/Dialect/Vector/vector-transfer-unroll.mlir (+102-102)
- (modified) mlir/test/Dialect/Vector/vector-transforms.mlir (+36-36)
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+70-70)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/contraction.mlir (+2-2)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/extract-strided-slice.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/insert-strided-slice.mlir (+4-4)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/transpose.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
index 0f08f61d7b257..7fa20b950e7c6 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
@@ -16,6 +16,11 @@
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/IR/EnumAttr.td"
+class Vector_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
+ : AttrDef<Vector_Dialect, attrName, traits> {
+ let mnemonic = attrMnemonic;
+}
+
// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
@@ -82,4 +87,63 @@ def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctu
let assemblyFormat = "`<` $value `>`";
}
+def Vector_StridedSliceAttr : Vector_Attr<"StridedSlice", "strided_slice">
+{
+ let summary = "strided vector slice";
+
+ let description = [{
+ An attribute that represents a strided slice of a vector.
+
+ *Syntax:*
+
+ ```
+ offset = integer-literal
+ stride = integer-literal
+ size = integer-literal
+ offset-list = offset (`,` offset)*
+
+ // Without sizes (used for insert_strided_slice)
+ strided-slice-without-sizes = (`[` offset-list `]`)? (`[` offset `:` stride `]`)+
+
+ // With sizes (used for extract_strided_slice)
+ strided-slice-with-sizes = (`[` offset `:` size `:` stride `]`)+
+ ```
+
+ *Examples:*
+
+ Without sizes:
+
+ `[0:1][4:2]`
+
+ - The first dimension starts at offset 0 and is strided by 1
+ - The second dimension starts at offset 4 and is strided by 2
+
+ `[0, 1, 2][3:1][4:8]`
+
+ - The first three dimensions are indexed without striding (offsets 0, 1, 2)
+ - The fourth dimension starts at offset 3 and is strided by 1
+ - The fifth dimension starts at offset 4 and is strided by 8
+
+ With sizes (used for extract_strided_slice):
+
+ `[0:2:4][2:4:3]`
+
+ - The first dimension starts at offset 0, has size 2, and is strided by 4
+ - The second dimension starts at offset 2, has size 4, and is strided by 3
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$offsets,
+ OptionalArrayRefParameter<"int64_t">:$sizes,
+ ArrayRefParameter<"int64_t">:$strides
+ );
+
+ let builders = [AttrBuilder<(ins "ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides), [{
+ return $_get($_ctxt, offsets, ArrayRef<int64_t>{}, strides);
+ }]>
+ ];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..45edb75c1989a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1040,8 +1040,8 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
- I64ArrayAttr:$strides)>,
+ Arguments<(ins AnyVector:$source, AnyVector:$dest,
+ Vector_StridedSliceAttr:$strided_slice)>,
Results<(outs AnyVector:$res)> {
let summary = "strided_slice operation";
let description = [{
@@ -1059,14 +1059,13 @@ def Vector_InsertStridedSliceOp :
Example:
```mlir
- %2 = vector.insert_strided_slice %0, %1
- {offsets = [0, 0, 2], strides = [1, 1]}:
- vector<2x4xf32> into vector<16x4x8xf32>
+ %2 = vector.insert_strided_slice %0, %1[0][0:1][2:1]
+ : vector<2x4xf32> into vector<16x4x8xf32>
```
}];
let assemblyFormat = [{
- $source `,` $dest attr-dict `:` type($source) `into` type($dest)
+ $source `,` $dest `` $strided_slice attr-dict `:` type($source) `into` type($dest)
}];
let builders = [
@@ -1081,10 +1080,13 @@ def Vector_InsertStridedSliceOp :
return ::llvm::cast<VectorType>(getDest().getType());
}
bool hasNonUnitStrides() {
- return llvm::any_of(getStrides(), [](Attribute attr) {
- return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+ return llvm::any_of(getStrides(), [](int64_t stride) {
+ return stride != 1;
});
}
+
+ ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+ ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
}];
let hasFolder = 1;
@@ -1298,8 +1300,7 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
- I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
+ Arguments<(ins AnyVector:$vector, Vector_StridedSliceAttr:$strided_slice)>,
Results<(outs AnyVector)> {
let summary = "extract_strided_slice operation";
let description = [{
@@ -1316,13 +1317,8 @@ def Vector_ExtractStridedSliceOp :
Example:
```mlir
- %1 = vector.extract_strided_slice %0
- {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
- vector<4x8x16xf32> to vector<2x4x16xf32>
-
- // TODO: Evolve to a range form syntax similar to:
%1 = vector.extract_strided_slice %0[0:2:1][2:4:1]
- vector<4x8x16xf32> to vector<2x4x16xf32>
+ : vector<4x8x16xf32> to vector<2x4x16xf32>
```
}];
let builders = [
@@ -1333,17 +1329,20 @@ def Vector_ExtractStridedSliceOp :
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
- void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
- return llvm::any_of(getStrides(), [](Attribute attr) {
- return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
+ return llvm::any_of(getStrides(), [](int64_t stride) {
+ return stride != 1;
});
}
+
+ ArrayRef<int64_t> getOffsets() { return getStridedSlice().getOffsets(); }
+ ArrayRef<int64_t> getSizes() { return getStridedSlice().getSizes(); }
+ ArrayRef<int64_t> getStrides() { return getStridedSlice().getStrides(); }
}];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
+ let assemblyFormat = "$vector `` $strided_slice attr-dict `:` type($vector) `to` type(results)";
}
// TODO: Tighten semantics so that masks and inbounds can't be used
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 0150ff667e4ef..a2647e2b647c1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -940,12 +940,6 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
return success();
}
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
- SmallVectorImpl<int64_t> &results) {
- for (auto attr : arrayAttr)
- results.push_back(cast<IntegerAttr>(attr).getInt());
-}
-
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
vector::ExtractStridedSliceOp op,
@@ -996,11 +990,8 @@ convertExtractStridedSlice(RewriterBase &rewriter,
auto sourceVector = it->second;
// offset and sizes at warp-level of onwership.
- SmallVector<int64_t> offsets;
- populateFromInt64AttrArray(op.getOffsets(), offsets);
+ ArrayRef<int64_t> offsets = op.getOffsets();
- SmallVector<int64_t> sizes;
- populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
// Compute offset in vector registers. Note that the mma.sync vector registers
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 21b8858989839..4d4e5ebb4f428 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -46,9 +46,6 @@ static uint64_t getFirstIntValue(ValueRange values) {
static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
return cast<IntegerAttr>(attr[0]).getInt();
}
-static uint64_t getFirstIntValue(ArrayAttr attr) {
- return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
-}
static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
auto attr = foldResults[0].dyn_cast<Attribute>();
if (attr)
@@ -187,9 +184,9 @@ struct VectorExtractStridedSliceOpConvert final
if (!dstType)
return failure();
- uint64_t offset = getFirstIntValue(extractOp.getOffsets());
- uint64_t size = getFirstIntValue(extractOp.getSizes());
- uint64_t stride = getFirstIntValue(extractOp.getStrides());
+ int64_t offset = extractOp.getOffsets().front();
+ int64_t size = extractOp.getSizes().front();
+ int64_t stride = extractOp.getStrides().front();
if (stride != 1)
return failure();
@@ -323,10 +320,10 @@ struct VectorInsertStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front();
Value dstVector = adaptor.getOperands().back();
- uint64_t stride = getFirstIntValue(insertOp.getStrides());
+ uint64_t stride = insertOp.getStrides().front();
if (stride != 1)
return failure();
- uint64_t offset = getFirstIntValue(insertOp.getOffsets());
+ uint64_t offset = insertOp.getOffsets().front();
if (isa<spirv::ScalarType>(srcVector.getType())) {
assert(!isa<spirv::ScalarType>(dstVector.getType()));
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index e2d42e961c576..941644e1116fc 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -550,11 +550,8 @@ struct ExtensionOverExtractStridedSlice final
if (failed(ext))
return failure();
- VectorType origTy = op.getType();
- VectorType extractTy =
- origTy.cloneWith(origTy.getShape(), ext->getInElementType());
Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
- op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
+ op.getLoc(), ext->getIn(), op.getOffsets(), op.getSizes(),
op.getStrides());
ext->recreateAndReplace(rewriter, op, newExtract);
return success();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..dda6b916176fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1340,13 +1340,6 @@ LogicalResult vector::ExtractOp::verify() {
return success();
}
-template <typename IntType>
-static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
- return llvm::to_vector<4>(llvm::map_range(
- arrayAttr.getAsRange<IntegerAttr>(),
- [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
@@ -1770,8 +1763,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
return Value();
// Trim offsets for dimensions fully extracted.
- auto sliceOffsets =
- extractVector<int64_t>(extractStridedSliceOp.getOffsets());
+ SmallVector<int64_t> sliceOffsets(extractStridedSliceOp.getOffsets());
while (!sliceOffsets.empty()) {
size_t lastOffset = sliceOffsets.size() - 1;
if (sliceOffsets.back() != 0 ||
@@ -1825,12 +1817,10 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
insertOp.getSourceVectorType().getRank();
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
- auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
+ ArrayRef<int64_t> insertOffsets = insertOp.getOffsets();
ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
- if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getInt() != 1;
- }))
+ if (insertOp.hasNonUnitStrides())
return Value();
bool disjoint = false;
SmallVector<int64_t, 4> offsetDiffs;
@@ -2899,6 +2889,95 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
+//===----------------------------------------------------------------------===//
+// StridedSliceAttr
+//===----------------------------------------------------------------------===//
+
+Attribute StridedSliceAttr::parse(AsmParser &parser, Type attrType) {
+ SmallVector<int64_t> offsets;
+ SmallVector<int64_t> sizes;
+ SmallVector<int64_t> strides;
+ bool parsedNonStridedOffsets = false;
+ while (succeeded(parser.parseOptionalLSquare())) {
+ int64_t offset = 0;
+ if (parser.parseInteger(offset))
+ return {};
+ if (parser.parseOptionalColon()) {
+ // Case 1: [Offset, ...]
+ if (!strides.empty() || parsedNonStridedOffsets) {
+ parser.emitError(parser.getCurrentLocation(),
+ "expected slice stride or size");
+ return {};
+ }
+ offsets.push_back(offset);
+ if (succeeded(parser.parseOptionalComma())) {
+ if (parser.parseCommaSeparatedList(
+ AsmParser::Delimiter::None, [&]() -> ParseResult {
+ if (parser.parseInteger(offset))
+ return failure();
+ offsets.push_back(offset);
+ return success();
+ })) {
+ return {};
+ }
+ }
+ if (parser.parseRSquare())
+ return {};
+ parsedNonStridedOffsets = true;
+ continue;
+ }
+ int64_t sizeOrStide = 0;
+ if (parser.parseInteger(sizeOrStide)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "expected slice stride or size");
+ return {};
+ }
+ if (parser.parseOptionalColon()) {
+ // Case 2: [Offset:Stride]
+ if (!sizes.empty() || parser.parseRSquare()) {
+ parser.emitError(parser.getCurrentLocation(), "expected slice size");
+ return {};
+ }
+ offsets.push_back(offset);
+ strides.push_back(sizeOrStide);
+ continue;
+ }
+ // Case 3: [Offset:Size:Stride]
+ if (sizes.size() < strides.size()) {
+ parser.emitError(parser.getCurrentLocation(), "unexpected slice size");
+ return {};
+ }
+ int64_t stride = 0;
+ if (parser.parseInteger(stride) || parser.parseRSquare()) {
+ parser.emitError(parser.getCurrentLocation(), "expected slice stride");
+ return {};
+ }
+ offsets.push_back(offset);
+ sizes.push_back(sizeOrStide);
+ strides.push_back(stride);
+ }
+ return StridedSliceAttr::get(parser.getContext(), offsets, sizes, strides);
+}
+
+void StridedSliceAttr::print(AsmPrinter &printer) const {
+ ArrayRef<int64_t> offsets = getOffsets();
+ ArrayRef<int64_t> sizes = getSizes();
+ ArrayRef<int64_t> strides = getStrides();
+ int nonStridedOffsets = offsets.size() - strides.size();
+ if (nonStridedOffsets > 0) {
+ printer << '[';
+ llvm::interleaveComma(offsets.take_front(nonStridedOffsets), printer);
+ printer << ']';
+ }
+ for (int d = nonStridedOffsets, e = offsets.size(); d < e; ++d) {
+ int strideIdx = d - nonStridedOffsets;
+ printer << '[' << offsets[d] << ':';
+ if (!sizes.empty())
+ printer << sizes[strideIdx] << ':';
+ printer << strides[strideIdx] << ']';
+ }
+}
+
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//
@@ -2907,26 +2986,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
Value source, Value dest,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> strides) {
- result.addOperands({source, dest});
- auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
- auto stridesAttr = getVectorSubscriptAttr(builder, strides);
- result.addTypes(dest.getType());
- result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
- offsetsAttr);
- result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
- stridesAttr);
-}
-
-// TODO: Should be moved to Tablegen ConfinedAttr attributes.
-template <typename OpType>
-static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
- ArrayAttr arrayAttr,
- ArrayRef<int64_t> shape,
- StringRef attrName) {
- if (arrayAttr.size() > shape.size())
- return op.emitOpError("expected ")
- << attrName << " attribute of rank no greater than vector rank";
- return success();
+ build(builder, result, source, dest,
+ StridedSliceAttr::get(builder.getContext(), offsets, strides));
}
// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
@@ -2934,16 +2995,15 @@ static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
-isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
- int64_t max, StringRef attrName,
- bool halfOpen = true) {
- for (auto attr : arrayAttr) {
- auto val = llvm::cast<IntegerAttr>(attr).getInt();
+isIntArrayConfinedToRange(OpType op, ArrayRef<int64_t> array, int64_t min,
+ int64_t max, StringRef arrayName,
+ bool halfOpen = true) {
+ for (int64_t val : array) {
auto upper = max;
if (!halfOpen)
upper += 1;
if (val < min || val >= upper)
- return op.emitOpError("expected ") << attrName << " to be confined to ["
+ return op.emitOpError("expected ") << arrayName << " to be confined to ["
<< min << ", " << upper << ")";
}
return success();
@@ -2954,13 +3014,12 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
// Otherwise, the admissible interval is [min, max].
template <typename OpType>
static LogicalResult
-isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
- ArrayRef<int64_t> shape, StringRef attrName,
- bool halfOpen = true, int64_t min = 0) {
- for (auto [index, attrDimPair] :
- llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
- int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
- int64_t max = std::get<1>(attrDimPair);
+isIntArrayConfinedToShape(OpType op, ArrayRef<int64_t> array,
+ ArrayRef<int64_t> shape, StringRef attrName,
+ bool halfOpen = true, int64_t min = 0) {
+ for (auto [index, dimPair] : llvm::enumerate(llvm::zip_first(array, shape))) {
+ int64_t val, max;
+ std::tie(val, max) = dimPair;
if (!halfOpen)
max += 1;
if (val < min || val >= max)
@@ -2977,40 +3036,32 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
// the admissible interval is [min, max].
template <typename OpType>
-static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
- OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
- ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
+static LogicalResult isSumOfIntArrayConfinedToShape(
+ OpType op, ArrayRef<int64_t> array1, ArrayRef<int64_t> array2,
+ ArrayRef<int64_t> shape, StringRef arrayName1, StringRef arrayName2,
bool halfOpen = true, int64_t min = 1) {
- assert(arrayAttr1.size() <= shape.size());
- assert(arrayAttr2.size() <= sh...
[truncated]
``````````
</details>
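The parser added in `VectorOps.cpp` distinguishes three bracket forms: a plain offset list, `[offset:stride]`, and `[offset:size:stride]`. The following standalone C++ sketch follows the same three cases on a plain string, so the grammar can be tried without building MLIR. The names `StridedSlice`, `Cursor`, and `parseStridedSlice` are hypothetical. Unlike the patch's parser, this sketch also rejects a plain offset list before the sized form (as the documented grammar implies) and rejects trailing input, and it does not handle integer overflow (`std::stoll` throws).

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical plain-C++ stand-in for the attribute's three parameters.
struct StridedSlice {
  std::vector<int64_t> offsets, sizes, strides;
};

// Minimal cursor over the input text that skips whitespace between tokens.
class Cursor {
public:
  explicit Cursor(const std::string &text) : s(text) {}
  // Consumes `c` if it is the next non-space character.
  bool consume(char c) {
    skipSpace();
    if (pos < s.size() && s[pos] == c) { ++pos; return true; }
    return false;
  }
  // Parses an optionally negative decimal integer.
  std::optional<int64_t> integer() {
    skipSpace();
    std::size_t start = pos;
    if (pos < s.size() && s[pos] == '-') ++pos;
    std::size_t digits = pos;
    while (pos < s.size() && std::isdigit(static_cast<unsigned char>(s[pos]))) ++pos;
    if (pos == digits) { pos = start; return std::nullopt; }  // No digits found.
    return std::stoll(s.substr(start, pos - start));
  }
  bool atEnd() { skipSpace(); return pos == s.size(); }

private:
  void skipSpace() {
    while (pos < s.size() && std::isspace(static_cast<unsigned char>(s[pos]))) ++pos;
  }
  std::string s;
  std::size_t pos = 0;
};

// Returns std::nullopt on any syntax error.
std::optional<StridedSlice> parseStridedSlice(const std::string &text) {
  Cursor cur(text);
  StridedSlice r;
  bool sawPlainOffsets = false;
  while (cur.consume('[')) {
    std::optional<int64_t> offset = cur.integer();
    if (!offset) return std::nullopt;
    if (!cur.consume(':')) {
      // Case 1: [offset, offset, ...]; at most once, before any strided dim.
      if (sawPlainOffsets || !r.strides.empty()) return std::nullopt;
      r.offsets.push_back(*offset);
      while (cur.consume(',')) {
        std::optional<int64_t> next = cur.integer();
        if (!next) return std::nullopt;
        r.offsets.push_back(*next);
      }
      if (!cur.consume(']')) return std::nullopt;
      sawPlainOffsets = true;
      continue;
    }
    std::optional<int64_t> sizeOrStride = cur.integer();
    if (!sizeOrStride) return std::nullopt;
    if (!cur.consume(':')) {
      // Case 2: [offset:stride]; may not follow a sized dimension.
      if (!r.sizes.empty() || !cur.consume(']')) return std::nullopt;
      r.offsets.push_back(*offset);
      r.strides.push_back(*sizeOrStride);
      continue;
    }
    // Case 3: [offset:size:stride]; may not follow either unsized form.
    if (sawPlainOffsets || r.sizes.size() < r.strides.size()) return std::nullopt;
    std::optional<int64_t> stride = cur.integer();
    if (!stride || !cur.consume(']')) return std::nullopt;
    r.offsets.push_back(*offset);
    r.sizes.push_back(*sizeOrStride);
    r.strides.push_back(*stride);
  }
  // This sketch requires the whole string to be consumed.
  if (!cur.atEnd()) return std::nullopt;
  return r;
}
```

For instance, `parseStridedSlice("[0][0:1][2:1]")` yields offsets `{0, 0, 2}`, no sizes, and strides `{1, 1}`, whereas `"[0:1][2:4:3]"` is rejected because it mixes the unsized and sized forms.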
https://github.com/llvm/llvm-project/pull/101850