[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)
Nishant Patel
llvmlistbot at llvm.org
Mon May 12 19:52:32 PDT 2025
================
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
}
};
+/// This pattern converts the CreateMaskOp to work on a linearized vector.
+/// It currently supports only 2D masks with a unit outer dimension.
+/// For example,
+/// vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+/// is converted to:
+/// %zero = arith.constant 0 : index
+/// %cmpi = arith.cmpi sgt, %arg0, %zero : index
+/// %index = arith.index_cast %cmpi : i1 to index
+/// %mul = arith.muli %index, %arg1 : index
+/// %mask = vector.create_mask %mul : vector<4xi1>
+/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = createMaskOp.getLoc();
+ VectorType srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Compare the first operand with 0. If it is greater than 0, the
+ // corresponding mask element is set to true, otherwise false.
+ // The result of the comparison is then multiplied with
+ // the second operand of create_mask to get the 1D mask.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+ auto isNonZero = rewriter.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
+ auto isNonZeroIndex = rewriter.create<mlir::arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), isNonZero);
+ auto secondOperand = adaptor.getOperands().back();
+ auto maskSize = rewriter.create<mlir::arith::MulIOp>(
+ loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
+
+ auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+ loc, dstTy, maskSize.getResult());
----------------
nbpatel wrote:
done
https://github.com/llvm/llvm-project/pull/138214
More information about the Mlir-commits
mailing list