[Mlir-commits] [mlir] [mlir][vector] Update syntax and representation of insert/extract_strided_slice (PR #101850)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Aug 6 11:15:29 PDT 2024
================
@@ -2899,6 +2883,95 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
+//===----------------------------------------------------------------------===//
+// StridedSliceAttr
+//===----------------------------------------------------------------------===//
+
+/// Parses the body of a strided-slice attribute as a sequence of bracketed
+/// groups:
+///   [o0, o1, ...]  -- at most one leading group of non-strided offsets,
+///   [o:st]         -- an offset with a stride (no size), or
+///   [o:sz:st]      -- an offset with a size and a stride.
+/// All strided groups must use the same form: either every group carries a
+/// size or none does. Parsing stops at the first token that is not '['.
+Attribute StridedSliceAttr::parse(AsmParser &parser, Type attrType) {
+  SmallVector<int64_t> offsets;
+  SmallVector<int64_t> sizes;
+  SmallVector<int64_t> strides;
+  bool parsedNonStridedOffsets = false;
+
+  // Parses an integer, reporting `message` (rather than the parser's generic
+  // "expected integer value") when no integer is present. Malformed integers
+  // (e.g. overflow) are still reported by the parser itself, so exactly one
+  // diagnostic is emitted per failure.
+  auto parseSliceInteger = [&](int64_t &value,
+                               StringRef message) -> ParseResult {
+    OptionalParseResult result = parser.parseOptionalInteger(value);
+    if (!result.has_value())
+      return parser.emitError(parser.getCurrentLocation(), message);
+    return *result;
+  };
+
+  while (succeeded(parser.parseOptionalLSquare())) {
+    int64_t offset = 0;
+    if (parser.parseInteger(offset))
+      return {};
+
+    if (failed(parser.parseOptionalColon())) {
+      // Case 1: [Offset, ...]
+      // Only a single group of non-strided offsets is allowed, and it must
+      // come before any strided group.
+      if (!strides.empty() || parsedNonStridedOffsets) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "expected slice stride or size");
+        return {};
+      }
+      offsets.push_back(offset);
+      while (succeeded(parser.parseOptionalComma())) {
+        if (parser.parseInteger(offset))
+          return {};
+        offsets.push_back(offset);
+      }
+      if (parser.parseRSquare())
+        return {};
+      parsedNonStridedOffsets = true;
+      continue;
+    }
+
+    // Either the size (case 3) or the stride (case 2); which one is decided by
+    // whether another ':' follows.
+    int64_t sizeOrStride = 0;
+    if (parseSliceInteger(sizeOrStride, "expected slice stride or size"))
+      return {};
+
+    if (failed(parser.parseOptionalColon())) {
+      // Case 2: [Offset:Stride]
+      // Earlier strided groups carried a size, so this one must too.
+      if (!sizes.empty()) {
+        parser.emitError(parser.getCurrentLocation(), "expected slice size");
+        return {};
+      }
+      // The parser reports a missing ']' itself.
+      if (parser.parseRSquare())
+        return {};
+      offsets.push_back(offset);
+      strides.push_back(sizeOrStride);
+      continue;
+    }
+
+    // Case 3: [Offset:Size:Stride]
+    // Earlier strided groups had no size, so this one must not have one.
+    if (sizes.size() < strides.size()) {
+      parser.emitError(parser.getCurrentLocation(), "unexpected slice size");
+      return {};
+    }
+    int64_t stride = 0;
+    if (parseSliceInteger(stride, "expected slice stride") ||
+        parser.parseRSquare())
+      return {};
+    offsets.push_back(offset);
+    sizes.push_back(sizeOrStride);
+    strides.push_back(stride);
+  }
+  return StridedSliceAttr::get(parser.getContext(), offsets, sizes, strides);
+}
+
+/// Prints the attribute in the form accepted by StridedSliceAttr::parse.
+/// The leading offsets that have no stride are printed as one comma-separated
+/// group. Each remaining (strided) offset is then printed as
+/// `[offset:size:stride]`, or `[offset:stride]` when the attribute has no
+/// sizes. The trailing `strides.size()` offsets are the strided ones.
+void StridedSliceAttr::print(AsmPrinter &printer) const {
+  ArrayRef<int64_t> offsets = getOffsets();
+  ArrayRef<int64_t> sizes = getSizes();
+  ArrayRef<int64_t> strides = getStrides();
+  // Invariants the printing below relies on; they hold for anything produced
+  // by StridedSliceAttr::parse.
+  assert(offsets.size() >= strides.size() &&
+         "expected at least as many offsets as strides");
+  assert((sizes.empty() || sizes.size() == strides.size()) &&
+         "expected sizes to be empty or to match the number of strides");
+
+  size_t numNonStridedOffsets = offsets.size() - strides.size();
+  if (numNonStridedOffsets > 0) {
+    printer << '[';
+    llvm::interleaveComma(offsets.take_front(numNonStridedOffsets), printer);
+    printer << ']';
+  }
+
+  ArrayRef<int64_t> stridedOffsets = offsets.drop_front(numNonStridedOffsets);
+  for (size_t i = 0, e = strides.size(); i < e; ++i) {
+    printer << '[' << stridedOffsets[i] << ':';
+    if (!sizes.empty())
+      printer << sizes[i] << ':';
+    printer << strides[i] << ']';
+  }
+}
----------------
MacDue wrote:
Going to leave this for a future patch, now that it's not needed for the printer/parser (which just uses the current syntax).
https://github.com/llvm/llvm-project/pull/101850
More information about the Mlir-commits
mailing list