[Mlir-commits] [mlir] [MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (PR #111614)

Mehdi Amini llvmlistbot at llvm.org
Wed Oct 9 01:00:58 PDT 2024


================
@@ -3772,6 +3772,92 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
   }
 };
 
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+///     Before:
+///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
+///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
+///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+///     After:
+///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
+///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
+///         vector<1x1x1x1x8xi8>
+///
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides()) {
+      return failure();
+    }
+    Value source = op.getOperand();
+    auto sourceType = cast<VectorType>(source.getType());
+    if (sourceType.isScalable() || sourceType.getRank() == 0) {
+      return failure();
+    }
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+    int numOffsets;
+    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+        break;
+      }
+    }
+
+    // If the created extract op would have no offsets, then this whole
+    // extract_strided_slice is the identity and should have been handled by
+    // other canonicalizations.
+    if (numOffsets == 0) {
+      return failure();
+    }
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank() &&
+        static_cast<int>(sizes.size()) == sourceType.getRank()) {
+      return failure();
+    }
+
+    // The outer dimensions must have unit size.
+    for (int i = 0; i < numOffsets; ++i) {
+      if (sizes[i] != 1) {
+        return failure();
+      }
----------------
joker-eph wrote:

Nit: extra braces on this if (and the 3 others above)

https://github.com/llvm/llvm-project/pull/111614


More information about the Mlir-commits mailing list