[Mlir-commits] [mlir] [MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (PR #111614)
Mehdi Amini
llvmlistbot at llvm.org
Wed Oct 9 01:00:58 PDT 2024
================
@@ -3772,6 +3772,92 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
}
};
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+/// Before:
+/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
+/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
+/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+/// After:
+/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
+/// %1 = vector.shape_cast %0 : vector<8xi8> to vector<1x1x1x1x8xi8>
+///
+class ContiguousExtractStridedSliceToExtract final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.hasNonUnitStrides()) {
+ return failure();
+ }
+ Value source = op.getOperand();
+ auto sourceType = cast<VectorType>(source.getType());
+ if (sourceType.isScalable() || sourceType.getRank() == 0) {
+ return failure();
+ }
+
+ // Compute the number of offsets to pass to ExtractOp::build. That is the
+ // difference between the source rank and the desired slice rank. We walk
+ // the dimensions from innermost out, and stop when the next slice dimension
+ // is not full-size.
+ SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+ int numOffsets;
+ for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
+ if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+ break;
+ }
+ }
+
+ // If the created extract op would have no offsets, then this whole
+ // extract_strided_slice is the identity and should have been handled by
+ // other canonicalizations.
+ if (numOffsets == 0) {
+ return failure();
+ }
+
+ // If not even the inner-most dimension is full-size, this op can't be
+ // rewritten as an ExtractOp.
+ if (numOffsets == sourceType.getRank() &&
+ static_cast<int>(sizes.size()) == sourceType.getRank()) {
+ return failure();
+ }
+
+ // The outer dimensions must have unit size.
+ for (int i = 0; i < numOffsets; ++i) {
+ if (sizes[i] != 1) {
+ return failure();
+ }
----------------
joker-eph wrote:
Nit: extra braces on this if (and the 3 others above)
https://github.com/llvm/llvm-project/pull/111614
More information about the Mlir-commits
mailing list