[Mlir-commits] [mlir] 1778d3b - [mlir] [vector] Add linearization pattern for vector.create_mask (#138214)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 14 15:54:01 PDT 2025
Author: Nishant Patel
Date: 2025-05-14T15:53:58-07:00
New Revision: 1778d3b8245b9a7787bbd0b00f60f879ed4689c9
URL: https://github.com/llvm/llvm-project/commit/1778d3b8245b9a7787bbd0b00f60f879ed4689c9
DIFF: https://github.com/llvm/llvm-project/commit/1778d3b8245b9a7787bbd0b00f60f879ed4689c9.diff
LOG: [mlir] [vector] Add linearization pattern for vector.create_mask (#138214)
This PR is part 3 of 4 of the breakdown of PR #136193.
The PR adds a linearization pattern for vector.create_mask.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c7169c5297d9a..90e0479a515d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -566,6 +566,64 @@ struct LinearizeVectorSplat final
}
};
+/// This pattern converts the CreateMaskOp to work on a linearized vector.
+/// It currently supports only 2D masks with a unit, non-scalable outer
+/// dimension. For example,
+///   vector.create_mask %arg0, %arg1 : vector<1x4xi1>
+/// is converted to:
+///   %zero = arith.constant 0 : index
+///   %cmpi = arith.cmpi sgt, %arg0, %zero : index
+///   %index = arith.index_cast %cmpi : i1 to index
+///   %and = arith.andi %index, %arg1 : index
+///   %mask = vector.create_mask %and : vector<4xi1>
+///   %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = createMaskOp.getLoc();
+    VectorType srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    // A scalable outer dimension `[1]` holds `vscale` rows at runtime, so a
+    // positive first operand may activate more than row 0. The rewrite below
+    // only ever activates a prefix of the first row, so it would be wrong.
+    if (srcTy.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "scalable outer dimension is not supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // With a single (unit) row, the 1-D mask size is `arg1` when `arg0 > 0`
+    // and 0 otherwise. This is computed with a bitwise AND rather than a
+    // multiply or select. `arith.index_cast` sign-extends the i1 comparison
+    // result, so `true` becomes an all-ones index (-1) and `false` becomes 0.
+    // Hence `(-1 & arg1) == arg1` and `(0 & arg1) == 0`.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
+    auto isPositive = rewriter.createOrFold<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
+    auto isPositiveIndex = rewriter.createOrFold<mlir::arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), isPositive);
+    auto secondOperand = adaptor.getOperands().back();
+    auto maskSize = rewriter.createOrFold<mlir::arith::AndIOp>(
+        loc, rewriter.getIndexType(), isPositiveIndex, secondOperand);
+
+    // The 1-D mask replaces the 2-D op. A vector.shape_cast back to the
+    // original 2-D type is expected to be materialized by the conversion
+    // driver for any remaining 2-D users.
+    auto newMask =
+        rewriter.create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
+    rewriter.replaceOp(createMaskOp, newMask);
+    return success();
+  }
+};
+
} // namespace
/// Return true if the operation `op` does not support scalable vectors and
@@ -651,9 +709,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
- patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast, LinearizeVectorSplat>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 3cdbef8db604b..40445d3781228 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -416,3 +416,28 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
return %0 : vector<4x[2]xi32>
}
+// -----
+
+// CHECK-LABEL: linearize_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
+func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+  // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
+  // CHECK: %[[AND:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[AND]] : vector<16xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<16xi1> to vector<1x16xi1>
+  // CHECK: return %[[CAST]] : vector<1x16xi1>
+  %0 = vector.create_mask %arg0, %arg1 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+
+// A scalable inner dimension is linearized the same way as a fixed-size one.
+// CHECK-LABEL: linearize_scalable_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x[16]xi1>
+func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vector<1x[16]xi1> {
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+  // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
+  // CHECK: %[[AND:.*]] = arith.andi %[[INDEXCAST]], %[[ARG1]] : index
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[AND]] : vector<[16]xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK_1D]] : vector<[16]xi1> to vector<1x[16]xi1>
+  // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+  %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
+  return %0 : vector<1x[16]xi1>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..54defd949c264 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,7 @@ struct TestVectorLinearize final
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<vector::VectorDialect>();
+ registry.insert<vector::VectorDialect, arith::ArithDialect>();
}
void runOnOperation() override {
More information about the Mlir-commits
mailing list