[Mlir-commits] [mlir] [MLIR] [XeGPU] Add distribution pattern for vector.constant_mask from Wg To Sg (PR #168118)

Nishant Patel llvmlistbot at llvm.org
Tue Nov 18 11:37:17 PST 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/168118

From 4d54fb13f7707b1d50ac8ed4548581ad55f2272e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 13 Nov 2025 22:37:16 +0000
Subject: [PATCH 1/2] Add distribution for vector mask operations

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 57 ++++++++++++++++++-
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 13 +++++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 20 +++++++
 3 files changed, 88 insertions(+), 2 deletions(-)
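
For orientation, a minimal sketch (not part of the patch) of what this first
commit does, using the shapes from the round-robin test below: the mask op is
replicated once per tile a subgroup owns, with the result type shrunk to
sg_data.

  // Workgroup level (input):
  %m = vector.constant_mask [16, 16]
    {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>}
    : vector<256x128xi1>

  // Subgroup level (output): a 256x128 shape tiled by an 8x4 subgroup grid
  // with 16x16 tiles leaves a 2x2 round-robin grid per subgroup, so the
  // pattern materializes count = 4 copies:
  %m0 = vector.constant_mask [16, 16] : vector<16x16xi1>
  // ... plus three more identical ops, one per owned tile.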

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..afab880d173c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,6 +1283,57 @@ struct WgToSgVectorTransposeOp
   }
 };
 
+/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
+/// subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape;
+    int count;
+    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newMaskOps;
+    for (int i = 0; i < count; ++i) {
+      Value newMaskOp;
+      if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+        newMaskOp = vector::CreateMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getOperands());
+      } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+        newMaskOp = vector::ConstantMaskOp::create(
+            rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
+      } else {
+        return rewriter.notifyMatchFailure(op,
+                                           "Unsupported mask operation type");
+      }
+      xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
+                                     layout.dropSgLayoutAndData());
+
+      newMaskOps.push_back(newMaskOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newMaskOps});
+    return success();
+  }
+};
+
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+
 } // namespace
 
 namespace mlir {
@@ -1297,7 +1348,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
+           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
           patterns.getContext());
 }
 } // namespace xegpu
@@ -1427,7 +1479,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
                                vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp>(
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 84ce80f477a55..b587ecc726f4d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -130,5 +130,18 @@ gpu.module @test_distribution {
     %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
       gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: %[[CST16:.*]] = arith.constant 16 : index
+    // CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    // CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
+    // CHECK-NOT: vector.constant_mask
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..f254b82c6401f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,24 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_mask_1D
+  gpu.func @vector_mask_1D() {
+    %cst8 = arith.constant 8 : index
+    // CHECK: vector.create_mask {{.*}} : vector<16xi1>
+    %create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<16xi1>
+    // CHECK: vector.constant_mask [8] : vector<16xi1>
+    %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_mask_2D
+  gpu.func @vector_mask_2D() {
+    %cst16 = arith.constant 16 : index
+    // CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
+    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    // CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
+    %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
 }
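
Note that for vector.constant_mask this plain replication is not semantically
correct: the mask sizes are workgroup-relative, so a subgroup whose tile lies
past the masked region must get an all-false mask rather than a copy of the
original sizes. A sketch of the problem, using the 1D test above
(sg_layout = [2], sg_data = [16]); the second commit below reworks the
pattern to account for this:

  // Workgroup level: the first 8 of 32 elements are set.
  %m = vector.constant_mask [8]
    {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>}
    : vector<32xi1>

  // Plain replication emits the same op for every subgroup:
  %sg = vector.constant_mask [8] : vector<16xi1>
  // Correct for subgroup 0 (elements 0..15), wrong for subgroup 1, which
  // owns elements 16..31 and should see an all-false mask.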

From be59f467b838735a035ab846f2d32cf74b8405dd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 14 Nov 2025 05:07:02 +0000
Subject: [PATCH 2/2] Add pattern for constant mask

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 99 +++++++++++--------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  6 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 29 ++++--
 3 files changed, 81 insertions(+), 53 deletions(-)
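
In outline, the reworked pattern derives each subgroup's coordinates from
gpu.subgroup_id and, for every dimension, clamps the original mask size
against the subgroup's offset, emitting one vector.create_mask per subgroup.
A minimal sketch of the per-dimension arithmetic, using the 1D test's shapes
(mask size 8, sg_data = [16]); the function name and SSA names are
illustrative, and %offset stands for the value computeDistributedCoords
produces:

  func.func @clamp_dim(%offset: index) -> vector<16xi1> {
    // size = min(max(maskSize - offset, 0), sgDimSize)
    %size   = arith.constant 8 : index   // original mask size in this dim
    %sg_dim = arith.constant 16 : index  // sg_data in this dim
    %zero   = arith.constant 0 : index
    %adj    = arith.subi %size, %offset : index
    %low    = arith.maxsi %adj, %zero : index
    %clamp  = arith.minsi %low, %sg_dim : index
    %mask   = vector.create_mask %clamp : vector<16xi1>
    return %mask : vector<16xi1>
  }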

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index afab880d173c7..81fd25a155129 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1283,57 +1283,74 @@ struct WgToSgVectorTransposeOp
   }
 };
 
-/// Pattern for lowering vector.create_mask and vector.constant_mask ops to
-/// subgroup level.
-template <typename MaskOpType>
-struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
-  using OpConversionPattern<MaskOpType>::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      MaskOpType op,
-      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = op.getResult().getType();
-    ArrayRef<int64_t> wgShape = resultType.getShape();
+// This pattern distributes vector.constant_mask ops from workgroup level
+// to subgroup level.
+struct WgToSgVectorConstantMaskOp
+    : public OpConversionPattern<vector::ConstantMaskOp> {
+  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
 
+  LogicalResult
+  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> sgShape;
-    int count;
-    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
-    VectorType newResultType =
-        VectorType::get(sgShape, resultType.getElementType());
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+
+    ArrayRef<int64_t> originalMaskDimSizes = op.getMaskDimSizes();
 
-    SmallVector<Value> newMaskOps;
-    for (int i = 0; i < count; ++i) {
-      Value newMaskOp;
-      if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
-        newMaskOp = vector::CreateMaskOp::create(
-            rewriter, op.getLoc(), newResultType, op.getOperands());
-      } else if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
-        newMaskOp = vector::ConstantMaskOp::create(
-            rewriter, op.getLoc(), newResultType, op.getMaskDimSizes());
-      } else {
-        return rewriter.notifyMatchFailure(op,
-                                           "Unsupported mask operation type");
+    // Get subgroup ID.
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto sgOffsets =
+        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType resultType = VectorType::get(sgShape, type.getElementType());
+
+    SmallVector<Value> newCreateMaskOps;
+    for (auto offsetSet : *sgOffsets) {
+      SmallVector<Value> maskOperands;
+
+      for (auto [i, originalMaskSize] : llvm::enumerate(originalMaskDimSizes)) {
+        Value originalMaskSizeVal =
+            arith::ConstantIndexOp::create(rewriter, loc, originalMaskSize);
+        Value dimSizeVal =
+            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
+        Value offset = offsetSet[i];
+        // Compute: originalMaskSize - offset.
+        Value adjustedMaskSize =
+            arith::SubIOp::create(rewriter, loc, originalMaskSizeVal, offset);
+        // Clamp to [0, dimSize]:
+        // min(max(adjustedMaskSize, 0), dimSize)
+        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value clampedLow =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value clampedHigh =
+            arith::MinSIOp::create(rewriter, loc, clampedLow, dimSizeVal);
+        maskOperands.push_back(clampedHigh);
       }
-      xegpu::setDistributeLayoutAttr(cast<OpResult>(newMaskOp),
-                                     layout.dropSgLayoutAndData());
 
-      newMaskOps.push_back(newMaskOp);
+      auto newCreateMaskOp =
+          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
     }
 
-    rewriter.replaceOpWithMultiple(op, {newMaskOps});
+    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
     return success();
   }
 };
 
-using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
-using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
-
 } // namespace
 
 namespace mlir {
@@ -1349,8 +1366,7 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
            WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
-           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
-          patterns.getContext());
+           WgToSgVectorConstantMaskOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1477,10 +1493,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
-                               vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp,
-                               vector::ConstantMaskOp, vector::CreateMaskOp>(
+  target.addDynamicallyLegalOp<
+      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
+      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index b587ecc726f4d..a752d0aa5c541 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -134,12 +134,8 @@ gpu.module @test_distribution {
   // CHECK-LABEL: vector_mask_2D
   gpu.func @vector_mask_2D() {
     %cst16 = arith.constant 16 : index
-    // CHECK: %[[CST16:.*]] = arith.constant 16 : index
-    // CHECK-COUNT-4: vector.create_mask %[[CST16:.*]], %[[CST16]] : vector<16x16xi1>
+    // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
     // CHECK-NOT: vector.create_mask
-    // CHECK-COUNT-4: vector.constant_mask [16, 16] : vector<16x16xi1>
-    // CHECK-NOT: vector.constant_mask
-    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
     %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
     gpu.return
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index f254b82c6401f..fa08ed1623501 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -550,20 +550,37 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: vector_mask_1D
   gpu.func @vector_mask_1D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+    // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
     %cst8 = arith.constant 8 : index
-    // CHECK: vector.create_mask {{.*}} : vector<16xi1>
-    %create_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<16xi1>
-    // CHECK: vector.constant_mask [8] : vector<16xi1>
     %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
     gpu.return
   }
 
   // CHECK-LABEL: vector_mask_2D
   gpu.func @vector_mask_2D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+    // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+    // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+    // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+    // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index
+    // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+    // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index
+    // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
     %cst16 = arith.constant 16 : index
-    // CHECK: vector.create_mask {{.*}}, {{.*}} : vector<32x32xi1>
-    %create_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
-    // CHECK: vector.constant_mask [16, 16] : vector<32x32xi1>
     %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
     gpu.return
   }


