[Mlir-commits] [mlir] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (PR #146745)

Jakub Kuderski llvmlistbot at llvm.org
Fri Jul 4 07:36:39 PDT 2025


================
@@ -4081,6 +4081,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
+// CreateMaskOp.
+//
+// Example:
+//
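+// %mask = vector.create_mask %ub : vector<16xi1>
+// %slice = vector.extract_strided_slice %mask
+//     {offsets = [4], sizes = [8], strides = [1]}
+//     : vector<16xi1> to vector<8xi1>
+//
+// to
+//
+// %c4 = arith.constant 4 : index
+// %new_ub = arith.subi %ub, %c4 : index
+// %slice = vector.create_mask %new_ub : vector<8xi1>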
+class StridedSliceCreateMaskFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = extractStridedSliceOp.getLoc();
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // CreateMaskOp.
+    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+    if (!createMaskOp)
+      return failure();
----------------
kuhar wrote:

Isn't there something like `getDefiningOp<T>()`?

https://github.com/llvm/llvm-project/pull/146745


More information about the Mlir-commits mailing list