[Mlir-commits] [mlir] c8d3b0c - [MLIR][XeGPU] Add distribution for vector.create_mask from Wg to Sg (#169571)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 3 16:01:50 PST 2025


Author: Nishant Patel
Date: 2025-12-03T16:01:46-08:00
New Revision: c8d3b0c8e33076107fb633fdc2a7e4c734dc3f26

URL: https://github.com/llvm/llvm-project/commit/c8d3b0c8e33076107fb633fdc2a7e4c734dc3f26
DIFF: https://github.com/llvm/llvm-project/commit/c8d3b0c8e33076107fb633fdc2a7e4c734dc3f26.diff

LOG: [MLIR][XeGPU] Add distribution for vector.create_mask from Wg to Sg (#169571)

Added: (none)

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: (none)


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 48bd0662b03ff..e7182ed8d05f7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1278,15 +1278,15 @@ struct WgToSgVectorTransposeOp
   }
 };
 
-// This pattern distributes the vector.constant_mask ops to work at subgroup
-// level.
-struct WgToSgVectorConstantMaskOp
-    : public OpConversionPattern<vector::ConstantMaskOp> {
-  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+// Distribute vector mask ops to work at subgroup level.
+template <typename MaskOpType>
+struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
+  using OpConversionPattern<MaskOpType>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      MaskOpType op,
+      typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
@@ -1296,9 +1296,16 @@ struct WgToSgVectorConstantMaskOp
     VectorType type = op.getResult().getType();
     auto wgShape = type.getShape();
 
-    ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();
+    SmallVector<Value> wgMaskDimSizes;
+    if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) {
+      for (int64_t maskSize : op.getMaskDimSizes()) {
+        wgMaskDimSizes.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, maskSize));
+      }
+    } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) {
+      wgMaskDimSizes = llvm::to_vector(op.getOperands());
+    }
 
-    // Get subgroup ID.
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
     auto sgOffsets =
@@ -1310,19 +1317,17 @@ struct WgToSgVectorConstantMaskOp
     VectorType resultType = VectorType::get(sgShape, type.getElementType());
 
     // In each dimension, each subgroup computes its local mask size as:
-    // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d])
+    // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d])
     SmallVector<Value> newCreateMaskOps;
     for (auto offsetSet : *sgOffsets) {
       SmallVector<Value> maskOperands;
 
-      for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
-        Value wgMaskSizeVal =
-            arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
+      for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) {
         Value dimSizeVal =
             arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
         Value offset = offsetSet[i];
         Value adjustedMaskSize =
-            arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
+            arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset);
         Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
         Value nonNegative =
             arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
@@ -1343,6 +1348,8 @@ struct WgToSgVectorConstantMaskOp
   }
 };
 
+using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
+using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
 } // namespace
 
 namespace mlir {
@@ -1358,7 +1365,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
            WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
-           WgToSgVectorConstantMaskOp>(patterns.getContext());
+           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1485,9 +1493,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<
-      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
-      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp,
+                               vector::ConstantMaskOp, vector::CreateMaskOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index eae51a16053d8..c95c64084f3f8 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -135,4 +135,12 @@ gpu.module @test_distribution {
     %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
     gpu.return
   }
+
+  gpu.func @vector_create_mask_2D() {
+    // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1>
+    // CHECK-NOT: vector.create_mask
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
+    gpu.return
+  } 
 }

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 98920d61c4f58..ae581c24fbf20 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -580,6 +580,43 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: vector_create_mask_1D
+  gpu.func @vector_create_mask_1D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index
+    // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1>
+    %cst8 = arith.constant 8 : index
+    %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_create_mask_2D
+  gpu.func @vector_create_mask_2D() {
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]]
+    // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]]
+    // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]]
+    // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index
+    // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index
+    // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index
+    // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1>
+    %cst16 = arith.constant 16 : index
+    %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1>
+    gpu.return
+  }
+
   // CHECK-LABEL: distribute_load_slice_attr
   gpu.func @distribute_load_slice_attr() {
     %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>


        


More information about the Mlir-commits mailing list