[Mlir-commits] [mlir] [MLIR][Vector] Add distribution pattern for `vector::ConstantMaskOp` (PR #172268)

Artem Kroviakov llvmlistbot at llvm.org
Tue Dec 16 00:48:01 PST 2025


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/172268

>From 6d81029e652b6f4771dc06bc21b385c029592520 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 15 Dec 2025 09:41:57 +0000
Subject: [PATCH 1/2] [MLIR][Vector] Add distribution pattern for
 `vector::ConstantMaskOp`

---
 .../Vector/Transforms/VectorDistribute.cpp    | 28 ++++++++++++----
 .../Vector/vector-warp-distribute.mlir        | 33 +++++++++++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b5e950733a22..90d6901089525 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1123,26 +1123,42 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
+        getWarpResult(warpOp, (llvm::IsaPred<vector::CreateMaskOp>));
+    if (!yieldOperand)
+      yieldOperand =
+          getWarpResult(warpOp, (llvm::IsaPred<vector::ConstantMaskOp>));
     if (!yieldOperand)
       return failure();
 
-    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    Operation *mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
+    if (!mask)
+      mask = yieldOperand->get().getDefiningOp<vector::ConstantMaskOp>();
 
     // Early exit if any values needed for calculating the new mask indices
     // are defined inside the warp op.
-    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
+    if (mask->getOperands().size() &&
+        !llvm::all_of(mask->getOperands(), [&](Value value) {
           return warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
 
-    Location loc = mask.getLoc();
+    Location loc = mask->getLoc();
     unsigned operandIndex = yieldOperand->getOperandNumber();
 
     auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
-    VectorType seqType = mask.getVectorType();
+    VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
     ArrayRef<int64_t> seqShape = seqType.getShape();
     ArrayRef<int64_t> distShape = distType.getShape();
+    SmallVector<Value> materializedOperands;
+    if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(mask)) {
+      materializedOperands.append(createMaskOp.getOperands().begin(),
+                                  createMaskOp.getOperands().end());
+    } else if (auto constantMaskOp = dyn_cast<vector::ConstantMaskOp>(mask)) {
+      auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
+      for (auto dimSize : dimSizes)
+        materializedOperands.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+    }
 
     rewriter.setInsertionPointAfter(warpOp);
 
@@ -1170,7 +1186,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
       // mask sizes are always in the range [0, mask_vector_size[i]).
       Value maskDimIdx = affine::makeComposedAffineApply(
           rewriter, loc, s1 - s0 * distShape[i],
-          {delinearizedIds[i], mask.getOperand(i)});
+          {delinearizedIds[i], materializedOperands[i]});
       newOperands.push_back(maskDimIdx);
     }
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 0cf6dd151e16c..135db02d543ef 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1779,6 +1779,21 @@ func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref
 //       CHECK-DIST-AND-PROP:   %[[CST:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_read {{.*}} : memref<128xf32>, vector<4xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[CST]], {{.*}} : vector<4xf32>, memref<128xf32>
+// -----
+
+func.func @warp_propagate_constant_mask(%laneid: index) -> vector<1xi1> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xi1>) {
+    %1 = vector.constant_mask [1] : vector<32xi1>
+    gpu.yield %1 : vector<32xi1>
+  }
+  return %r : vector<1xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUB:.*]] = affine_map<()[s0] -> (-s0 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_constant_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index
+//       CHECK-PROP:   %[[MDIST:.+]] = affine.apply #[[$SUB]]()[%[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[MDIST]] : vector<1xi1>
 
 // -----
 
@@ -1813,6 +1828,24 @@ func.func @warp_propagate_multi_dim_create_mask(%laneid: index, %m0: index, %m1:
 //       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[M0]], %[[LANEID]]]
 //       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[M1]], %[[LANEID]]]
 //       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[M2]] : vector<1x2x4xi1>
+// -----
+
+func.func @warp_propagate_multi_dim_constant_mask(%laneid: index) -> vector<1x2x4xi1> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2x4xi1>) {
+    %1 = vector.constant_mask [1, 1, 2]: vector<16x4x4xi1>
+    gpu.yield %1 : vector<16x4x4xi1>
+  }
+  return %r : vector<1x2x4xi1>
+}
+
+//   CHECK-PROP-DAG: #[[$SUBM0:.*]] = affine_map<()[s0] -> (-(s0 floordiv 2) + 1)>
+//   CHECK-PROP-DAG: #[[$SUBM1:.*]] = affine_map<()[s0] -> (s0 * -2 + (s0 floordiv 2) * 4 + 1)>
+// CHECK-PROP-LABEL: func @warp_propagate_multi_dim_constant_mask
+//  CHECK-PROP-SAME: %[[LANEID:.+]]: index
+//       CHECK-PROP:   %[[CST2:.+]] = arith.constant 2 : index
+//       CHECK-PROP:   %[[DISTM0:.+]] = affine.apply #[[$SUBM0]]()[%[[LANEID]]]
+//       CHECK-PROP:   %[[DISTM1:.+]] = affine.apply #[[$SUBM1]]()[%[[LANEID]]]
+//       CHECK-PROP:   vector.create_mask %[[DISTM0]], %[[DISTM1]], %[[CST2]] : vector<1x2x4xi1>
 
 // -----
 

>From cbf6a9412a9868963fb2d7d57d0fbb3f8d2fb22b Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 16 Dec 2025 08:47:36 +0000
Subject: [PATCH 2/2] Use templates

---
 .../Vector/Transforms/VectorDistribute.cpp    | 35 +++++++++----------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 90d6901089525..1ed6907023c21 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1118,21 +1118,18 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
 /// %cmp = arith.cmpi ult, %laneid, %0
 /// %ub = arith.select %cmp, %c0, %c1
 /// %1 = vector.create_mask %ub : vector<1xi1>
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
 struct WarpOpCreateMask : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand =
-        getWarpResult(warpOp, (llvm::IsaPred<vector::CreateMaskOp>));
-    if (!yieldOperand)
-      yieldOperand =
-          getWarpResult(warpOp, (llvm::IsaPred<vector::ConstantMaskOp>));
+    OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
     if (!yieldOperand)
       return failure();
 
-    Operation *mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
-    if (!mask)
-      mask = yieldOperand->get().getDefiningOp<vector::ConstantMaskOp>();
+    Operation *mask = yieldOperand->get().getDefiningOp<OpType>();
 
     // Early exit if any values needed for calculating the new mask indices
     // are defined inside the warp op.
@@ -1150,10 +1147,11 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
     ArrayRef<int64_t> seqShape = seqType.getShape();
     ArrayRef<int64_t> distShape = distType.getShape();
     SmallVector<Value> materializedOperands;
-    if (auto createMaskOp = dyn_cast<vector::CreateMaskOp>(mask)) {
-      materializedOperands.append(createMaskOp.getOperands().begin(),
-                                  createMaskOp.getOperands().end());
-    } else if (auto constantMaskOp = dyn_cast<vector::ConstantMaskOp>(mask)) {
+    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
+      materializedOperands.append(mask->getOperands().begin(),
+                                  mask->getOperands().end());
+    } else {
+      auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
       auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
       for (auto dimSize : dimSizes)
         materializedOperands.push_back(
@@ -2298,12 +2296,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
-          patterns.getContext(), benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
+               WarpOpCreateMask<vector::CreateMaskOp>,
+               WarpOpCreateMask<vector::ConstantMaskOp>,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,



More information about the Mlir-commits mailing list