[Mlir-commits] [mlir] [mlir][vector] Update syntax and representation of insert/extract_strided_slice (PR #101850)

Mehdi Amini llvmlistbot at llvm.org
Mon Aug 5 13:48:36 PDT 2024


================
@@ -2899,6 +2883,95 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// StridedSliceAttr
+//===----------------------------------------------------------------------===//
+
+/// Parses a StridedSliceAttr from a sequence of bracketed groups:
+///   - `[o0, o1, ...]`: leading offsets with no stride or size. This group
+///     must come before any strided group and may appear at most once.
+///   - `[offset:stride]`: a strided offset without a size.
+///   - `[offset:size:stride]`: a strided offset with a size.
+/// The two strided forms cannot be mixed within one attribute. On malformed
+/// input a diagnostic is emitted and a null Attribute is returned.
+/// `attrType` is not used.
+///
+/// On success the parsed arrays satisfy the invariant that the printer
+/// indexes by: `strides.size() <= offsets.size()` (the difference is the
+/// number of leading non-strided offsets), and `sizes` is either empty or the
+/// same length as `strides`. Attributes created directly through `get` do not
+/// pass through this function and are not checked here.
+Attribute StridedSliceAttr::parse(AsmParser &parser, Type attrType) {
+  SmallVector<int64_t> offsets;
+  SmallVector<int64_t> sizes;
+  SmallVector<int64_t> strides;
+  // Set once the single `[o0, o1, ...]` group has been consumed.
+  bool parsedNonStridedOffsets = false;
+  while (succeeded(parser.parseOptionalLSquare())) {
+    int64_t offset = 0;
+    // NOTE(review): parseInteger/parseRSquare presumably emit their own
+    // diagnostic on failure, so the branches below that also call emitError
+    // may report two errors for one mistake -- confirm against the tests.
+    if (parser.parseInteger(offset))
+      return {};
+    // ParseResult is truthy on failure, so this branch is taken when there is
+    // NO ':' after the first integer, i.e. a non-strided offset group.
+    if (parser.parseOptionalColon()) {
+      // Case 1: [Offset, ...]
+      // Reject a second non-strided group, or one that follows strided groups.
+      if (!strides.empty() || parsedNonStridedOffsets) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "expected slice stride or size");
+        return {};
+      }
+      offsets.push_back(offset);
+      if (succeeded(parser.parseOptionalComma())) {
+        // Remaining comma-separated offsets of the same group (no delimiters:
+        // the '[' was already consumed and ']' is parsed below).
+        if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None,
+                                           [&]() -> ParseResult {
+                                             if (parser.parseInteger(offset))
+                                               return failure();
+                                             offsets.push_back(offset);
+                                             return success();
+                                           })) {
+          return {};
+        }
+      }
+      if (parser.parseRSquare())
+        return {};
+      parsedNonStridedOffsets = true;
+      continue;
+    }
+    // The second integer is the stride for `[o:s]` or the size for
+    // `[o:sz:st]`; which one is decided by whether another ':' follows.
+    // NOTE(review): `sizeOrStide` is a typo for `sizeOrStride`.
+    int64_t sizeOrStide = 0;
+    if (parser.parseInteger(sizeOrStide)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "expected slice stride or size");
+      return {};
+    }
+    if (parser.parseOptionalColon()) {
+      // Case 2: [Offset:Stride]
+      // Rejects mixing with earlier `[o:sz:st]` groups. Note `parseRSquare()`
+      // only runs when `sizes` is empty (short-circuit evaluation).
+      if (!sizes.empty() || parser.parseRSquare()) {
+        parser.emitError(parser.getCurrentLocation(), "expected slice size");
+        return {};
+      }
+      offsets.push_back(offset);
+      strides.push_back(sizeOrStide);
+      continue;
+    }
+    // Case 3: [Offset:Size:Stride]
+    // Earlier `[o:s]` groups grew `strides` without `sizes`; mixing is rejected
+    // so that `sizes` stays either empty or the same length as `strides`.
+    if (sizes.size() < strides.size()) {
+      parser.emitError(parser.getCurrentLocation(), "unexpected slice size");
+      return {};
+    }
+    int64_t stride = 0;
+    if (parser.parseInteger(stride) || parser.parseRSquare()) {
+      parser.emitError(parser.getCurrentLocation(), "expected slice stride");
+      return {};
+    }
+    offsets.push_back(offset);
+    sizes.push_back(sizeOrStide);
+    strides.push_back(stride);
+  }
+  return StridedSliceAttr::get(parser.getContext(), offsets, sizes, strides);
+}
+
+/// Prints the attribute in the syntax accepted by `StridedSliceAttr::parse`:
+/// an optional leading `[o0, o1, ...]` group holding the offsets that have no
+/// stride, then one `[offset:size:stride]` group per strided offset, or
+/// `[offset:stride]` when the attribute carries no sizes.
+///
+/// This relies on the invariant that the parser establishes:
+///   - `strides.size() <= offsets.size()`; the difference is the number of
+///     leading non-strided offsets, and
+///   - `sizes` is either empty or has exactly one entry per stride.
+/// Nothing visible here enforces that for attributes built directly via
+/// `get`, so the asserts below make violations fail loudly in debug builds
+/// instead of indexing out of bounds.
+/// TODO: enforce the invariant in an attribute verifier so that malformed
+/// attributes are rejected in release builds as well.
+void StridedSliceAttr::print(AsmPrinter &printer) const {
+  ArrayRef<int64_t> offsets = getOffsets();
+  ArrayRef<int64_t> sizes = getSizes();
+  ArrayRef<int64_t> strides = getStrides();
+  assert(strides.size() <= offsets.size() &&
+         "expected no more strides than offsets");
+  assert((sizes.empty() || sizes.size() == strides.size()) &&
+         "expected sizes to be empty or to match the number of strides");
+
+  // Kept unsigned and computed only after the size check above; the previous
+  // `int = size_t - size_t` form silently wrapped when strides outnumbered
+  // offsets and produced a bogus (negative) start index.
+  size_t numNonStridedOffsets = offsets.size() - strides.size();
+  if (numNonStridedOffsets > 0) {
+    printer << '[';
+    llvm::interleaveComma(offsets.take_front(numNonStridedOffsets), printer);
+    printer << ']';
+  }
+
+  // The trailing offsets pair up index-for-index with `strides` (and with
+  // `sizes` when present).
+  ArrayRef<int64_t> stridedOffsets = offsets.drop_front(numNonStridedOffsets);
+  for (size_t i = 0, e = strides.size(); i < e; ++i) {
+    printer << '[' << stridedOffsets[i] << ':';
+    if (!sizes.empty())
+      printer << sizes[i] << ':';
+    printer << strides[i] << ']';
+  }
+}
----------------
joker-eph wrote:

There seems to be an assumption about the relative sizes of the offsets, sizes, and strides arrays, but I don't see that invariant enforced in a verifier?

https://github.com/llvm/llvm-project/pull/101850


More information about the Mlir-commits mailing list