[Mlir-commits] [mlir] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (PR #146745)

Mehdi Amini llvmlistbot at llvm.org
Wed Jul 2 11:49:28 PDT 2025


================
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+class StridedSliceCreateMaskFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = extractStridedSliceOp.getLoc();
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // CreateMaskOp.
+    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+    if (!createMaskOp)
+      return failure();
+    // Return if 'extractStridedSliceOp' has non-unit strides.
+    if (extractStridedSliceOp.hasNonUnitStrides())
+      return failure();
+    // Gather constant mask dimension sizes.
+    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+    // Gather strided slice offsets and sizes.
+    SmallVector<int64_t, 4> sliceOffsets;
+    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+                               sliceOffsets);
+    SmallVector<int64_t, 4> sliceSizes;
+    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+
+    // Compute slice of vector mask region.
+    SmallVector<Value> sliceMaskDimSizes;
+    sliceMaskDimSizes.reserve(maskDimSizes.size());
+    for (auto [maskDimSize, sliceOffset, sliceSize] :
+         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+      // No need to clamp on min/max values, because create_mask has clamping
+      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
+      // greater than the vector dim size.
+      IntegerAttr offsetAttr =
+          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+      Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+      Value sliceMaskDimSize =
+          rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+      sliceMaskDimSizes.push_back(sliceMaskDimSize);
+    }
+    // Add unchanged dimensions.
+    if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
----------------
joker-eph wrote:

There is a branch here, so I would think you need two tests to provide coverage of the taken/not-taken cases?

https://github.com/llvm/llvm-project/pull/146745


More information about the Mlir-commits mailing list