[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)
James Newling
llvmlistbot at llvm.org
Wed May 7 10:10:51 PDT 2025
================
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
}
};
+/// This pattern converts a vector.create_mask op to operate on a
+/// linearized vector. It currently supports only 2D masks with a unit
+/// outer dimension. For example,
+/// vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+/// %out_1d = vector.create_mask %dims : vector<4xi1>
+/// %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Compare the first operand with 0. If it's less than or equal to 0,
+ // create a zero mask, else strip the first operand and create a mask
+ // using the second operand.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero =
+ rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+ auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+ createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+ zero);
+ auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+ createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+ // Use a select operation to choose between the masks.
+ auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
----------------
newling wrote:
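Rather than creating a zeroMask and using a SelectOp, it would be simpler to multiply isZeroOrNegative by adaptor.getOperands().back() and use that as the argument to the new mask. (More precisely, multiply by its complement — the result of an `sgt 0` comparison, cast to index — so that a non-positive first operand yields a zero-size mask.)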
https://github.com/llvm/llvm-project/pull/138214
More information about the Mlir-commits
mailing list