[Mlir-commits] [mlir] [mlir][vector] Linearize vector.create_mask (PR #138760)

James Newling llvmlistbot at llvm.org
Tue May 6 13:52:51 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/138760

Alternative to https://github.com/llvm/llvm-project/pull/138214 

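Rather than using an arith.select, this PR computes the size operand of a new rank-1 mask directly with scalar arithmetic, and then shape_casts that rank-1 mask back to the original type. This PR also supports more mask shapes (any mask with at most one non-unit dimension).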

FYI @nbpatel: I started trying to suggest this approach in https://github.com/llvm/llvm-project/pull/138214 but thought it'd be clearer to just implement it. Please let me know what you think.

>From 9bc615fa528e617e9ecb9f8f237bad5db44109f3 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 11:23:16 -0700
Subject: [PATCH 1/3] linearize create_mask

---
 .../Vector/Transforms/VectorLinearize.cpp     | 59 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 14 +++++
 2 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..eec846ff717a1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -20,6 +21,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
+#include <algorithm>
 #include <cstdint>
 #include <numeric>
 #include <optional>
@@ -469,6 +471,11 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
   return containsScalableResult;
 }
 
+static bool isLinearizableCreateMaskOp(vector::CreateMaskOp createMaskOp) {
+  auto shape = createMaskOp.getType().getShape();
+  return llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) <= 1;
+}
+
 static bool isNotLinearizable(Operation *op) {
 
   // Only ops that are in the vector dialect, are ConstantLike, or
@@ -485,6 +492,12 @@ static bool isNotLinearizable(Operation *op) {
   if (isNotLinearizableBecauseScalable(op))
     return true;
 
+  if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
+    if (!isLinearizableCreateMaskOp(createMaskOp)) {
+      return true;
+    }
+  }
+
   return false;
 }
 
@@ -527,12 +540,54 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
       });
 }
 
+/// Linearize vector.create_mask with at most 1 non-unit dimension. Example:
+///
+/// ```
+/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
+/// ```
+///
+/// becomes
+///
+/// ```
+/// %0 = arith.muli %arg0, %arg1 : index
+/// %1 = arith.muli %0, %arg2 : index
+/// %2 = vector.create_mask %1: vector<16xi1>
+/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
+/// ```
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType maskType = createMaskOp.getType();
+    assert(isLinearizableCreateMaskOp(createMaskOp));
+
+    Value product = adaptor.getOperands().front();
+    for (unsigned i = 1; i < maskType.getRank(); ++i) {
+      product = rewriter.create<mlir::arith::MulIOp>(
+          createMaskOp.getLoc(), product, adaptor.getOperands()[i]);
+    }
+    Type flatMaskType = getTypeConverter()->convertType(maskType);
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        createMaskOp.getLoc(), flatMaskType, product);
+    rewriter.replaceOp(createMaskOp, newMask);
+    return success();
+  }
+};
+
 void mlir::vector::populateVectorLinearizeBasePatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorSplat>(
-      typeConverter, patterns.getContext());
+               LinearizeVectorCreateMask, LinearizeVectorBitCast,
+               LinearizeVectorSplat>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..e1e38cff5c733 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -345,3 +345,17 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: linearize_create_mask
+//  CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1>
+//       CHECK: %[[MULI1:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : index
+//       CHECK: %[[MULI2:.*]] = arith.muli %[[MULI1]], %[[ARG2]] : index
+//       CHECK: %[[MASK:.*]]  = vector.create_mask %[[MULI2]] : vector<16xi1>
+//       CHECK: %[[CAST:.*]]  = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
+//       CHECK: return %[[CAST]] : vector<1x16x1xi1>
+func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> {
+  %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
+  return %0 : vector<1x16x1xi1>
+}

>From b4ae361146f42fe0ab8ed40be630109b0477ea9f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 13:39:04 -0700
Subject: [PATCH 2/3] updates

---
 .../Vector/Transforms/VectorLinearize.cpp     | 85 ++++++++++++++-----
 mlir/test/Dialect/Vector/linearize.mlir       | 16 ++--
 2 files changed, 77 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index eec846ff717a1..ec089ed0b58c9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include <algorithm>
 #include <cstdint>
 #include <numeric>
@@ -471,9 +472,12 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
   return containsScalableResult;
 }
 
-static bool isLinearizableCreateMaskOp(vector::CreateMaskOp createMaskOp) {
-  auto shape = createMaskOp.getType().getShape();
-  return llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) <= 1;
+static bool
+isCreateMaskWithAtMostOneNonUnit(vector::CreateMaskOp createMaskOp) {
+  ArrayRef<int64_t> shape = createMaskOp.getType().getShape();
+  bool multipleNonUnitDim =
+      llvm::count_if(shape, [](int64_t dim) { return dim > 1; }) > 1;
+  return !multipleNonUnitDim;
 }
 
 static bool isNotLinearizable(Operation *op) {
@@ -493,7 +497,7 @@ static bool isNotLinearizable(Operation *op) {
     return true;
 
   if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(op)) {
-    if (!isLinearizableCreateMaskOp(createMaskOp)) {
+    if (!isCreateMaskWithAtMostOneNonUnit(createMaskOp)) {
       return true;
     }
   }
@@ -540,7 +544,8 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
       });
 }
 
-/// Linearize vector.create_mask with at most 1 non-unit dimension. Example:
+/// Linearize a vector.create_mask that has at most 1 non-unit dimension.
+/// Example:
 ///
 /// ```
 /// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
@@ -549,11 +554,30 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
 /// becomes
 ///
 /// ```
-/// %0 = arith.muli %arg0, %arg1 : index
-/// %1 = arith.muli %0, %arg2 : index
-/// %2 = vector.create_mask %1: vector<16xi1>
+/// [...]
+/// %2 = vector.create_mask %prod: vector<16xi1>
 /// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
 /// ```
+///
+/// where %prod above the product of the (clamped) dimension-wise masking ranges
+/// %arg0, %arg1, and %arg2.
+///
+/// This is equivalent to choosing the rank-1 masking range as:
+/// 1) %arg1 if %arg0 and %arg2 are strictly positive
+/// 2) 0     if either %arg0 or %arg2 is 0 or negative.
+///
+/// Specifically, %prod is obtained as
+///
+/// ```
+/// %true = arith.constant true
+/// %zero = arith.constant 0 : index
+/// %0    = arith.cmpi sgt, %arg0, %zero : index
+/// %1    = arith.muli %true, %0 : i1
+/// %2    = arith.cmpi sgt, %arg2, %zero : index
+/// %3    = arith.muli %1, %2 : i1
+/// %4    = arith.index_castui %3 : i1 to index
+/// %prod = arith.muli %4, %arg1 : index
+/// ```
 struct LinearizeVectorCreateMask final
     : OpConversionPattern<vector::CreateMaskOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -563,21 +587,44 @@ struct LinearizeVectorCreateMask final
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
-  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::CreateMaskOp maskOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    VectorType maskType = createMaskOp.getType();
-    assert(isLinearizableCreateMaskOp(createMaskOp));
+    VectorType type = maskOp.getType();
+    assert(isCreateMaskWithAtMostOneNonUnit(maskOp) &&
+           "expected linearizable create_mask");
+
+    Location loc = maskOp.getLoc();
+
+    // First, get the product of (clamped) mask sizes in the unit-dimensions.
+    Value prod = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    int nonUnitDim = -1;
+    for (unsigned i = 0; i < type.getRank(); ++i) {
+      auto v = adaptor.getOperands()[i];
+      auto dimSize = type.getDimSize(i);
+      if (dimSize <= 1) {
+        Value nxt = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sgt, v, zero);
+        prod = rewriter.create<arith::MulIOp>(loc, prod, nxt);
+      } else {
+        assert(nonUnitDim == -1 && "at most 1 non-unit expected");
+        nonUnitDim = i;
+      }
+    }
+    prod = rewriter.create<arith::IndexCastUIOp>(loc, rewriter.getIndexType(),
+                                                 prod);
 
-    Value product = adaptor.getOperands().front();
-    for (unsigned i = 1; i < maskType.getRank(); ++i) {
-      product = rewriter.create<mlir::arith::MulIOp>(
-          createMaskOp.getLoc(), product, adaptor.getOperands()[i]);
+    // Finally, multiply the zero-extended 0/1 flag by the non-unit dim size.
+    if (nonUnitDim != -1) {
+      Value v = adaptor.getOperands()[nonUnitDim];
+      prod = rewriter.create<arith::MulIOp>(loc, prod, v);
     }
-    Type flatMaskType = getTypeConverter()->convertType(maskType);
-    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
-        createMaskOp.getLoc(), flatMaskType, product);
-    rewriter.replaceOp(createMaskOp, newMask);
+
+    Type flatType = getTypeConverter()->convertType(type);
+    auto newMask =
+        rewriter.create<mlir::vector::CreateMaskOp>(loc, flatType, prod);
+    rewriter.replaceOp(maskOp, newMask);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e1e38cff5c733..6437b5eefa9bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -350,11 +350,17 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
 
 // CHECK-LABEL: linearize_create_mask
 //  CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> vector<1x16x1xi1>
-//       CHECK: %[[MULI1:.*]] = arith.muli %[[ARG0]], %[[ARG1]] : index
-//       CHECK: %[[MULI2:.*]] = arith.muli %[[MULI1]], %[[ARG2]] : index
-//       CHECK: %[[MASK:.*]]  = vector.create_mask %[[MULI2]] : vector<16xi1>
-//       CHECK: %[[CAST:.*]]  = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
-//       CHECK: return %[[CAST]] : vector<1x16x1xi1>
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+//       CHECK: %[[CMP0:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
+//       CHECK: %[[MUL0:.*]] = arith.muli %[[TRUE]], %[[CMP0]] : i1
+//       CHECK: %[[CMP1:.*]] = arith.cmpi sgt, %[[ARG2]], %[[C0]] : index
+//       CHECK: %[[MUL1:.*]] = arith.muli %[[MUL0]], %[[CMP1]] : i1
+//       CHECK: %[[CAST:.*]] = arith.index_castui %[[MUL1]] : i1 to index
+//       CHECK: %[[MUL2:.*]] = arith.muli %[[CAST]], %[[ARG1]] : index
+//       CHECK: %[[MASK:.*]] = vector.create_mask %[[MUL2]] : vector<16xi1>
+//       CHECK: %[[CAST2:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16x1xi1>
+//       CHECK: return %[[CAST2]] : vector<1x16x1xi1>
 func.func @linearize_create_mask(%arg0 : index, %arg1 : index, %arg2 : index) -> vector<1x16x1xi1> {
   %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
   return %0 : vector<1x16x1xi1>

>From 692e4f2c309d16e09084ebc67f59e2de8f25e731 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 13:43:12 -0700
Subject: [PATCH 3/3] tidy

---
 .../Vector/Transforms/VectorLinearize.cpp      | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index ec089ed0b58c9..86bbbc2196a8b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -545,18 +545,16 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
 }
 
 /// Linearize a vector.create_mask that has at most 1 non-unit dimension.
-/// Example:
-///
+/// For example:
 /// ```
-/// %0 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
+/// %mask3 = vector.create_mask %arg0, %arg1, %arg2: vector<1x16x1xi1>
 /// ```
 ///
-/// becomes
-///
+/// becomes:
 /// ```
 /// [...]
-/// %2 = vector.create_mask %prod: vector<16xi1>
-/// %3 = vector.shape_cast %2: vector<16xi1> to vector<1x16x1xi1>
+/// %mask1 = vector.create_mask %prod: vector<16xi1>
+/// %mask3 = vector.shape_cast %mask1: vector<16xi1> to vector<1x16x1xi1>
 /// ```
 ///
 /// where %prod above the product of the (clamped) dimension-wise masking ranges
@@ -601,11 +599,11 @@ struct LinearizeVectorCreateMask final
     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     int nonUnitDim = -1;
     for (unsigned i = 0; i < type.getRank(); ++i) {
-      auto v = adaptor.getOperands()[i];
-      auto dimSize = type.getDimSize(i);
+      Value dimRange = adaptor.getOperands()[i];
+      int64_t dimSize = type.getDimSize(i);
       if (dimSize <= 1) {
         Value nxt = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::sgt, v, zero);
+            loc, arith::CmpIPredicate::sgt, dimRange, zero);
         prod = rewriter.create<arith::MulIOp>(loc, prod, nxt);
       } else {
         assert(nonUnitDim == -1 && "at most 1 non-unit expected");



More information about the Mlir-commits mailing list