[Mlir-commits] [mlir] [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (PR #111541)
Benoit Jacob
llvmlistbot at llvm.org
Tue Oct 8 13:24:27 PDT 2024
================
@@ -235,6 +235,11 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
PatternBenefit benefit = 1);
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+void populateVectorContiguousExtractStridedSliceToExtractPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
----------------
bjacob wrote:
Two things (both came up in the test case I looked at, where these ops were extracting parts of a matrix tile to feed into GPU matrix-multiplication intrinsics):
1. `extract` is more constrained than `extract_strided_slice`, so it is more likely to have a good lowering.
2. My use case was, like the test added in this PR, an `extract_strided_slice` producing a vector with leading unit dims, followed by a `shape_cast` dropping those unit dims. That `shape_cast` was hitting the fallback lowering pattern, `ShapeCastOpRewritePattern`. Now that the `extract_strided_slice` is rewritten into a pair (`extract`, `shape_cast`), the two `shape_cast` ops fold together.
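To make the rewrite in point 2 concrete, here is a minimal C++ sketch that models vector shapes as plain integer vectors instead of using MLIR's C++ API. It assumes unit strides and one offset and size per source dimension. The names `Shape`, `ExtractPlusShapeCast`, `rewriteContiguousSlice` and `foldShapeCastPair` are hypothetical, and the acceptance rule is an assumed reading of "contiguous" for the purpose of a single `extract`, not necessarily the exact condition this PR implements. It is also stricter than memory contiguity: a `2x8` slice of a `4x8` vector is contiguous in row-major order but is rejected here, because one `extract` returns the sub-vector at a single position.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical plain-data model of vector shapes and static offsets/sizes.
// None of these names are MLIR C++ APIs.
using Shape = std::vector<int64_t>;

struct ExtractPlusShapeCast {
  Shape position;        // static position operand of the modeled extract
  Shape extractedShape;  // shape of the vector the modeled extract returns
  Shape castShape;       // shape_cast target: the original slice shape
};

// Models an extract_strided_slice with all strides equal to 1. Succeeds only
// when the slice covers every dimension from some index k onward in full and
// has size 1 in every dimension before k, i.e. a single extract produces it.
std::optional<ExtractPlusShapeCast> rewriteContiguousSlice(
    const Shape &sourceShape, const Shape &offsets, const Shape &sizes) {
  const std::size_t rank = sourceShape.size();
  if (offsets.size() != rank || sizes.size() != rank) return std::nullopt;
  for (std::size_t i = 0; i < rank; ++i)
    assert(offsets[i] >= 0 && offsets[i] + sizes[i] <= sourceShape[i] &&
           "slice must lie inside the source vector");

  // Walk from the innermost dimension outwards while the slice covers it fully.
  std::size_t k = rank;
  while (k > 0 && sizes[k - 1] == sourceShape[k - 1]) --k;

  // If not even the innermost dimension is full, no single extract fits.
  if (k == rank) return std::nullopt;

  // Every outer dimension must be a unit slice so its offset can serve as an
  // extract position.
  for (std::size_t i = 0; i < k; ++i)
    if (sizes[i] != 1) return std::nullopt;

  const auto split = static_cast<std::ptrdiff_t>(k);
  ExtractPlusShapeCast result;
  result.position.assign(offsets.begin(), offsets.begin() + split);
  result.extractedShape.assign(sourceShape.begin() + split, sourceShape.end());
  result.castShape = sizes;
  return result;
}

// Hypothetical model of folding shape_cast(shape_cast(x)): the pair composes
// into one cast from x's shape to the final shape. Returns std::nullopt when
// the composed cast is the identity and disappears entirely.
std::optional<Shape> foldShapeCastPair(const Shape &inputShape,
                                       const Shape &finalShape) {
  if (inputShape == finalShape) return std::nullopt;
  return finalShape;
}
```

For example, slicing sizes `{1, 1, 8}` at offsets `{1, 3, 0}` out of a `2x4x8` source gives position `{1, 3}`, an extracted `{8}` vector, and a cast back to `{1, 1, 8}`. If the consumer then casts to `{8}`, the two casts compose into `{8}` to `{8}`, which folds away, so no `shape_cast` is left for the fallback pattern to lower.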
https://github.com/llvm/llvm-project/pull/111541