[Mlir-commits] [mlir] [mlir] [vector] Add linearization pattern for vector.create_mask (PR #138214)

Nishant Patel llvmlistbot at llvm.org
Mon May 12 19:52:32 PDT 2025


================
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
   }
 };
 
+/// This pattern converts the CreateMaskOp to work on a linearized vector.
+/// It currently supports only 2D masks with a unit outer dimension.
+/// For example, the following:
+///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+/// is converted to:
+///   %zero = arith.constant 0 : index
+///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
+///   %index = arith.index_cast %cmpi : i1 to index
+///   %mul = arith.muli %index, %arg1 : index
+///   %mask = vector.create_mask %mul : vector<4xi1>
+///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = createMaskOp.getLoc();
+    VectorType srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // Compare the first operand with 0. If it is greater than 0, the
+    // corresponding mask element is set to true, otherwise false.
+    // The result of the comparison is then multiplied with
+    // the second operand of create_mask to get the 1D mask.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto isNonZero = rewriter.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isNonZeroIndex = rewriter.create<mlir::arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), isNonZero);
+    auto secondOperand = adaptor.getOperands().back();
+    auto maskSize = rewriter.create<mlir::arith::MulIOp>(
+        loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand);
+
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        loc, dstTy, maskSize.getResult());
----------------
nbpatel wrote:

done

https://github.com/llvm/llvm-project/pull/138214


More information about the Mlir-commits mailing list