[Mlir-commits] [mlir] [mlir][vector] Linearize vector.create_mask (PR #138760)
James Newling
llvmlistbot at llvm.org
Tue May 6 13:52:51 PDT 2025
https://github.com/newling created https://github.com/llvm/llvm-project/pull/138760
Alternative to https://github.com/llvm/llvm-project/pull/138214
Rather than using an arith.select, this PR computes a new rank-1 mask argument directly with scalars before splatting. This PR also supports more mask shapes.
FYI @nbpatel: I started trying to suggest this approach in https://github.com/llvm/llvm-project/pull/138214 but thought it'd be clearer to just implement it. Please let me know what you think.
>From 9bc615fa528e617e9ecb9f8f237bad5db44109f3 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 11:23:16 -0700
Subject: [PATCH 1/3] linearize create_mask
---
.../Vector/Transforms/VectorLinearize.cpp | 59 ++++++++++++++++++-
mlir/test/Dialect/Vector/linearize.mlir | 14 +++++
2 files changed, 71 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..eec846ff717a1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -20,6 +21,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
+#include <algorithm>
#include <cstdint>
#include <numeric>
#include <optional>
@@ -469,6 +471,11 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
return containsScalableResult;
}
+static bool isLinearizableCreateMaskOp(vector::CreateMaskOp createMaskOp) {
+ auto shape = createMaskOp.getType().getShape();
+ return llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) <= 1;
+}
+
static bool isNotLinearizable(Operation *op) {
// Only ops that are in the vector dialect, are ConstantLike, or
@@ -485,6 +492,12 @@ static bool isNotLinearizable(Operation *op) {
if (isNotLinearizableBecauseScalable(op))
return true;
+ if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
+ if (!isLinearizableCreateMaskOp(createMaskOp)) {
+ return true;
+ }
+ }
+
return false;
}
@@ -527,12 +540,54 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
});
}
+/// Linearize vector.create_mask with at most 1 non-unit dimension. Example:
+///
+/// ```
+/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
+/// ```
+///
+/// becomes
+///
+/// ```
+/// %0 = arith.muli %arg0, %arg1 : index
+/// %1 = arith.muli %0, %arg2 : index
+/// %2 = vector.create_mask %1: vector<16xi1>
+/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
+/// ```
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType maskType = createMaskOp.getType();
+ assert(isLinearizableCreateMaskOp(createMaskOp));
+
+ Value product = adaptor.getOperands().front();
+ for (unsigned i = 1; i < maskType.getRank(); ++i) {
+ product = rewriter.create<mlir::arith::MulIOp>(
+ createMaskOp.getLoc(), product, adaptor.getOperands()[i]);
+ }
+ Type flatMaskType = getTypeConverter()->convertType(maskType);
+ auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+ createMaskOp.getLoc(), flatMaskType, product);
+ rewriter.replaceOp(createMaskOp, newMask);
+ return success();
+ }
+};
+
void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast, LinearizeVectorSplat>(
- typeConverter, patterns.getContext());
+ LinearizeVectorCreateMask, LinearizeVectorBitCast,
+ LinearizeVectorSplat>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..e1e38cff5c733 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -345,3 +345,17 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
+
+// -----
+
+// CHECK-LABEL: linearize_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1>
+// CHECK: %[[MULI1:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : index
+// CHECK: %[[MULI2:.*]] = arith.muli %[[MULI1]], %[[ARG2]] : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[MULI2]] : vector<16xi1>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
+// CHECK: return %[[CAST]] : vector<1x16x1xi1>
+func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> {
+ %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
+ return %0 : vector<1x16x1xi1>
+}
>From b4ae361146f42fe0ab8ed40be630109b0477ea9f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 13:39:04 -0700
Subject: [PATCH 2/3] updates
---
.../Vector/Transforms/VectorLinearize.cpp | 85 ++++++++++++++-----
mlir/test/Dialect/Vector/linearize.mlir | 16 ++--
2 files changed, 77 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index eec846ff717a1..ec089ed0b58c9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <cstdint>
#include <numeric>
@@ -471,9 +472,12 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
return containsScalableResult;
}
-static bool isLinearizableCreateMaskOp(vector::CreateMaskOp createMaskOp) {
- auto shape = createMaskOp.getType().getShape();
- return llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) <= 1;
+static bool
+isCreateMaskWithAtMostOneNonUnit(vector::CreateMaskOp createMaskOp) {
+ ArrayRef<int64_t> shape = createMaskOp.getType().getShape();
+ bool multipleNonUnitDim =
+ llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) > 1;
+ return !multipleNonUnitDim;
}
static bool isNotLinearizable(Operation *op) {
@@ -493,7 +497,7 @@ static bool isNotLinearizable(Operation *op) {
return true;
if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
- if (!isLinearizableCreateMaskOp(createMaskOp)) {
+ if (!isCreateMaskWithAtMostOneNonUnit(createMaskOp)) {
return true;
}
}
@@ -540,7 +544,8 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
});
}
-/// Linearize vector.create_mask with at most 1 non-unit dimension. Example:
+/// Linearize a vector.create_mask that has at most 1 non-unit dimension.
+/// Example:
///
/// ```
/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
@@ -549,11 +554,30 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
/// becomes
///
/// ```
-/// %0 = arith.muli %arg0, %arg1 : index
-/// %1 = arith.muli %0, %arg2 : index
-/// %2 = vector.create_mask %1: vector<16xi1>
+/// [...]
+/// %2 = vector.create_mask %prod: vector<16xi1>
/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
/// ```
+///
+/// where %prod is the product of the (clamped) dimension-wise masking ranges
+/// %arg0, %arg1, and %arg2.
+///
+/// This is equivalent to choosing the rank-1 masking range as:
+/// 1) %arg1 if %arg0 and %arg2 are strictly positive
+/// 2) 0 if either %arg0 or %arg2 is 0 or negative.
+///
+/// Specifically, %prod is obtained as
+///
+/// ```
+/// %true = arith.constant true
+/// %zero = arith.constant 0 : index
+/// %0 = arith.cmpi sgt, %arg0, %zero : index
+/// %1 = arith.muli %true, %0 : i1
+/// %2 = arith.cmpi sgt, %arg2, %zero : index
+/// %3 = arith.muli %1, %2 : i1
+/// %4 = arith.index_cast %3 : i1 to index
+/// %prod = arith.muli %4, %arg1 : index
+/// ```
struct LinearizeVectorCreateMask final
: OpConversionPattern<vector::CreateMaskOp> {
using OpConversionPattern::OpConversionPattern;
@@ -563,21 +587,44 @@ struct LinearizeVectorCreateMask final
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
- matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::CreateMaskOp maskOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType maskType = createMaskOp.getType();
- assert(isLinearizableCreateMaskOp(createMaskOp));
+ VectorType type = maskOp.getType();
+ assert(isCreateMaskWithAtMostOneNonUnit(maskOp) &&
+ "expected linearizable create_mask");
+
+ Location loc = maskOp.getLoc();
+
+ // First, get the product of (clamped) mask sizes in the unit-dimensions.
+ Value prod = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ int nonUnitDim = -1;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ auto v = adaptor.getOperands()[i];
+ auto dimSize = type.getDimSize(i);
+ if (dimSize <= 1) {
+ Value nxt = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, v, zero);
+ prod = rewriter.create<arith::MulIOp>(loc, prod, nxt);
+ } else {
+ assert(nonUnitDim == -1 && "at most 1 non-unit expected");
+ nonUnitDim = i;
+ }
+ }
+ prod =
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), prod);
- Value product = adaptor.getOperands().front();
- for (unsigned i = 1; i < maskType.getRank(); ++i) {
- product = rewriter.create<mlir::arith::MulIOp>(
- createMaskOp.getLoc(), product, adaptor.getOperands()[i]);
+ // Finally, multiply by the size in the dimension that is not unit.
+ if (nonUnitDim != -1) {
+ Value v = adaptor.getOperands()[nonUnitDim];
+ prod = rewriter.create<arith::MulIOp>(loc, prod, v);
}
- Type flatMaskType = getTypeConverter()->convertType(maskType);
- auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
- createMaskOp.getLoc(), flatMaskType, product);
- rewriter.replaceOp(createMaskOp, newMask);
+
+ Type flatType = getTypeConverter()->convertType(type);
+ auto newMask =
+ rewriter.create<mlir::vector::CreateMaskOp>(loc, flatType, prod);
+ rewriter.replaceOp(maskOp, newMask);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e1e38cff5c733..6437b5eefa9bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -350,11 +350,17 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
// CHECK-LABEL: linearize_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1>
-// CHECK: %[[MULI1:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : index
-// CHECK: %[[MULI2:.*]] = arith.muli %[[MULI1]], %[[ARG2]] : index
-// CHECK: %[[MASK:.*]] = vector.create_mask %[[MULI2]] : vector<16xi1>
-// CHECK: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
-// CHECK: return %[[CAST]] : vector<1x16x1xi1>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[CMP0:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MUL0:.*]] = arith.muli %[[TRUE]], %[[CMP0]] : i1
+// CHECK: %[[CMP1:.*]] = arith.cmpi sgt, %[[ARG2]], %[[C0]] : index
+// CHECK: %[[MUL1:.*]] = arith.muli %[[MUL0]], %[[CMP1]] : i1
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[MUL1]] : i1 to index
+// CHECK: %[[MUL2:.*]] = arith.muli %[[CAST]], %[[ARG1]] : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[MUL2]] : vector<16xi1>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
+// CHECK: return %[[CAST2]] : vector<1x16x1xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> {
%0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
return %0 : vector<1x16x1xi1>
>From 692e4f2c309d16e09084ebc67f59e2de8f25e731 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 13:43:12 -0700
Subject: [PATCH 3/3] tidy
---
.../Vector/Transforms/VectorLinearize.cpp | 18 ++++++++----------
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index ec089ed0b58c9..86bbbc2196a8b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -545,18 +545,16 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
}
/// Linearize a vector.create_mask that has at most 1 non-unit dimension.
-/// Example:
-///
+/// For example,
/// ```
-/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
+/// %mask3 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
/// ```
///
-/// becomes
-///
+/// becomes,
/// ```
/// [...]
-/// %2 = vector.create_mask %prod: vector<16xi1>
-/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
+/// %mask1 = vector.create_mask %prod: vector<16xi1>
+/// %mask3 = vector.shape_cast %mask1: vector<16xi1> to vector<1x16x1xi1>
/// ```
///
/// where %prod is the product of the (clamped) dimension-wise masking ranges
@@ -601,11 +599,11 @@ struct LinearizeVectorCreateMask final
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
int nonUnitDim = -1;
for (unsigned i = 0; i < type.getRank(); ++i) {
- auto v = adaptor.getOperands()[i];
- auto dimSize = type.getDimSize(i);
+ Value dimRange = adaptor.getOperands()[i];
+ int64_t dimSize = type.getDimSize(i);
if (dimSize <= 1) {
Value nxt = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, v, zero);
+ loc, arith::CmpIPredicate::sgt, dimRange, zero);
prod = rewriter.create<arith::MulIOp>(loc, prod, nxt);
} else {
assert(nonUnitDim == -1 && "at most 1 non-unit expected");
More information about the Mlir-commits
mailing list