[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)

Diego Caballero llvmlistbot at llvm.org
Tue May 6 16:29:29 PDT 2025


================
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
   }
 };
 
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// For example,
+///   vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %dims : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // Compare the first operand with 0. If it's less than or equal to 0,
+    // create a zero mask, else strip the first operand and create a mask
+    // using the second operand.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero =
+        rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+    auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+        createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+        zero);
+    auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+        createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+    // Use a select operation to choose between the masks.
+    auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+        createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+    auto result = rewriter.create<mlir::arith::SelectOp>(
+        createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+    rewriter.replaceOp(createMaskOp, result.getResult());
----------------
dcaballe wrote:

Could you please illustrate this example in the doc above? The current doc only shows a create mask + shape cast.

https://github.com/llvm/llvm-project/pull/138214


More information about the Mlir-commits mailing list