[Mlir-commits] [mlir] [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (PR #111541)

Benoit Jacob llvmlistbot at llvm.org
Tue Oct 8 08:32:15 PDT 2024


================
@@ -329,12 +329,76 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
+/// Appends the integer value of each element of `arrayAttr` to `results`, in
+/// order. Every element is cast to IntegerAttr with llvm::cast. Values already
+/// present in `results` are left in place.
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  for (Attribute element : arrayAttr) {
+    auto integerAttr = llvm::cast<IntegerAttr>(element);
+    results.push_back(integerAttr.getInt());
+  }
+}
+
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides()) {
+      return failure();
+    }
+    SmallVector<int64_t> sizes;
+    populateFromInt64AttrArray(op.getSizes(), sizes);
+    Value source = op.getOperand();
+    ShapedType sourceType = cast<ShapedType>(source.getType());
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    int numOffsets;
+    for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+        break;
+      }
+    }
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank()) {
+      return failure();
+    }
+
+    // Avoid generating slices that have unit outer dimensions. The shape_cast
+    // op that we create below would take bad generic fallback patterns
+    // (ShapeCastOpRewritePattern).
+    while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
+      ++numOffsets;
+    }
+
+    SmallVector<int64_t> offsets;
+    populateFromInt64AttrArray(op.getOffsets(), offsets);
+    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+                                                       extractOffsets);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op->getResultTypes()[0], extract);
----------------
bjacob wrote:

TIL

https://github.com/llvm/llvm-project/pull/111541


More information about the Mlir-commits mailing list