[Mlir-commits] [mlir] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (PR #146745)
Mehdi Amini
llvmlistbot at llvm.org
Wed Jul 2 11:50:14 PDT 2025
================
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
+class StridedSliceCreateMaskFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = extractStridedSliceOp.getLoc();
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // CreateMaskOp.
+    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+    if (!createMaskOp)
+      return failure();
+    // Return if 'extractStridedSliceOp' has non-unit strides.
+    if (extractStridedSliceOp.hasNonUnitStrides())
+      return failure();
+    // Gather constant mask dimension sizes.
+    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+    // Gather strided slice offsets and sizes.
+    SmallVector<int64_t, 4> sliceOffsets;
----------------
joker-eph wrote:
```suggestion
    SmallVector<int64_t> sliceOffsets;
```
Nit: the default inline size is good enough in general.
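For context on what this canonicalization has to preserve, here is a minimal sketch of the per-dimension arithmetic, assuming every mask bound, slice offset and slice size is a compile-time integer (the pattern above works on `Value` operands, and its remaining body is not shown in this hunk). `foldSliceOfMask` is a hypothetical helper, not an MLIR API. With unit strides, source lanes `[offset, offset + size)` become lanes `[0, size)` of the slice, so a set prefix `[0, mask)` turns into `[0, clamp(mask - offset, 0, size)]`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): returns the bounds of the
// create_mask equivalent to a unit-stride extract_strided_slice applied to
// create_mask(maskDimSizes). Dimensions past sliceOffsets.size() are assumed
// to be taken whole by the slice, so their bounds are carried over unchanged.
std::vector<int64_t> foldSliceOfMask(const std::vector<int64_t> &maskDimSizes,
                                     const std::vector<int64_t> &sliceOffsets,
                                     const std::vector<int64_t> &sliceSizes) {
  assert(sliceOffsets.size() == sliceSizes.size());
  assert(sliceOffsets.size() <= maskDimSizes.size());
  std::vector<int64_t> result(maskDimSizes);
  for (std::size_t i = 0; i < sliceOffsets.size(); ++i) {
    // The set prefix shrinks by the slice offset and is clamped to the
    // slice extent [0, size]; sizes are assumed non-negative.
    result[i] = std::clamp<int64_t>(maskDimSizes[i] - sliceOffsets[i], 0,
                                    sliceSizes[i]);
  }
  return result;
}
```

For example, a mask set on `[0, 5) x [0, 3)` sliced with offsets `[2, 1]` and sizes `[4, 2]` maps to bounds `[3, 2]`. An offset at or past the mask bound yields 0, i.e. an all-false dimension in the result.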
https://github.com/llvm/llvm-project/pull/146745