[Mlir-commits] [mlir] [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (PR #111541)
Benoit Jacob
llvmlistbot at llvm.org
Tue Oct 8 08:32:15 PDT 2024
================
@@ -329,12 +329,76 @@ class DecomposeNDExtractStridedSlice
}
};
+/// Appends the integer value of every element of `arrayAttr` to `results`,
+/// in order. Existing contents of `results` are kept; this only appends.
+/// NOTE(review): every element is assumed to be an IntegerAttr (the code uses
+/// llvm::cast, not dyn_cast). Values are read with IntegerAttr::getInt().
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+ SmallVectorImpl<int64_t> &results) {
+ for (auto attr : arrayAttr)
+ results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+}
+
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+class ContiguousExtractStridedSliceToExtract final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.hasNonUnitStrides()) {
+ return failure();
+ }
+ SmallVector<int64_t> sizes;
+ populateFromInt64AttrArray(op.getSizes(), sizes);
+ Value source = op.getOperand();
+ ShapedType sourceType = cast<ShapedType>(source.getType());
+
+ // Compute the number of offsets to pass to ExtractOp::build. That is the
+ // difference between the source rank and the desired slice rank. We walk
+ // the dimensions from innermost out, and stop when the next slice dimension
+ // is not full-size.
+ int numOffsets;
+ for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
+ if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+ break;
+ }
+ }
+
+ // If not even the inner-most dimension is full-size, this op can't be
+ // rewritten as an ExtractOp.
+ if (numOffsets == sourceType.getRank()) {
+ return failure();
+ }
+
+ // Avoid generating slices that have unit outer dimensions. The shape_cast
+ // op that we create below would take bad generic fallback patterns
+ // (ShapeCastOpRewritePattern).
+ while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
+ ++numOffsets;
+ }
+
+ SmallVector<int64_t> offsets;
+ populateFromInt64AttrArray(op.getOffsets(), offsets);
+ auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+ Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+ extractOffsets);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ op, op->getResultTypes()[0], extract);
----------------
bjacob wrote:
TIL
https://github.com/llvm/llvm-project/pull/111541
More information about the Mlir-commits
mailing list