[Mlir-commits] [mlir] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (PR #146745)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Jul 4 07:36:39 PDT 2025
================
@@ -4081,6 +4081,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
+// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
+// CreateMaskOp.
+//
+// Example:
+//
+// %mask = vector.create_mask %ub : vector<16xi1>
+// %slice = vector.extract_strided_slice %mask
+//     {offsets = [4], sizes = [8], strides = [1]} : vector<16xi1> to vector<8xi1>
+//
+// to
+//
+// %c4 = arith.constant 4 : index
+// %new_ub = arith.subi %ub, %c4 : index
+// %slice = vector.create_mask %new_ub : vector<8xi1>
+class StridedSliceCreateMaskFolder final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+public:
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = extractStridedSliceOp.getLoc();
+ // Return if 'extractStridedSliceOp' operand is not defined by a
+ // CreateMaskOp.
+ auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+ auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+ if (!createMaskOp)
+ return failure();
+ // Return if 'extractStridedSliceOp' has non-unit strides.
+ if (extractStridedSliceOp.hasNonUnitStrides())
+ return failure();
+ // Gather constant mask dimension sizes.
+ SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+ // Gather strided slice offsets and sizes.
+ SmallVector<int64_t> sliceOffsets;
+ populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+ sliceOffsets);
+ SmallVector<int64_t> sliceSizes;
+ populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+
+ // Compute slice of vector mask region.
+ SmallVector<Value> sliceMaskDimSizes;
+ sliceMaskDimSizes.reserve(maskDimSizes.size());
+ for (auto [maskDimSize, sliceOffset, sliceSize] :
+ llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+ // No need to clamp on min/max values, because create_mask has clamping
+ // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
+ // greater than the vector dim size.
+ IntegerAttr offsetAttr =
+ rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+ Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+ Value sliceMaskDimSize =
+ rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+ sliceMaskDimSizes.push_back(sliceMaskDimSize);
+ }
+ // Add unchanged dimensions.
+ if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
+ for (size_t i = sliceMaskDimSizes.size(), e = maskDimSizes.size(); i < e;
+ ++i) {
+ sliceMaskDimSizes.push_back(maskDimSizes[i]);
+ }
+ }
----------------
kuhar wrote:
nit: Could this be `llvm::append_range(sliceMaskDimSizes, llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()))`?
https://github.com/llvm/llvm-project/pull/146745
More information about the Mlir-commits
mailing list