[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for subgroup_id_range (PR #148661)
Nishant Patel
llvmlistbot at llvm.org
Mon Jul 14 09:26:22 PDT 2025
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/148661
None
From 1d18b895bface3094ac4868601ebeadc0ae2758c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 11 Jul 2025 14:26:08 +0000
Subject: [PATCH] Add support for subgroup_id_range
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 40 ++++++++-
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 83 +++++++++++++++++++
2 files changed, 122 insertions(+), 1 deletion(-)
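As a rough illustration (not part of the patch; %in_range, %src, and the
layout below are placeholders borrowed from the tests), the attribute is
consumed as follows: an scf.if that guards a range of subgroups carries
xegpu.sg_id_range = array<i32: start, end>, and the distribution pattern
subtracts start from the linear subgroup id before delinearizing it, unless
the range starts at 0.

  scf.if %in_range {
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
  } {xegpu.sg_id_range = array<i32: 2, 32>}

  // After distribution, offsets inside the annotated region are computed from
  // an adjusted subgroup id rather than the raw one:
  //   %sgid = gpu.subgroup_id : index
  //   %c2   = arith.constant 2 : index
  //   %adj  = index.sub %sgid, %c2
  // %adj is then delinearized over sg_layout as before.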
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be7b860dd1729..56dc132d8083d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -174,8 +174,46 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
}
+ // Check if there is warp specialization.
+ auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
+ int64_t &endRange) -> bool {
+ Operation *parent = op->getParentOp();
+ // Walk up and stop at the nearest enclosing scf::IfOp that carries xegpu.sg_id_range.
+ while (parent) {
+ if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
+ if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
+ if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
+ auto values = denseAttr.asArrayRef();
+ if (values.size() == 2) {
+ startRange = values[0];
+ endRange = values[1];
+ }
+ }
+ break;
+ }
+ }
+ parent = parent->getParentOp();
+ }
+ // No sg id adjustment is needed when the range starts at 0; also require end > start.
+ return (startRange > 0 && endRange > startRange);
+ };
+
+ int64_t startRange = -1, endRange = -1;
+ bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
+
+ // If warp specialization is detected, adjust the subgroup id accordingly
+ Value adjustedSgId = linearSgId;
+ if (warpSpecialized) {
+ // Subtract startRange from the original subgroup id to get the adjusted
+ // sg id
+ Value startRangeVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, startRange);
+ adjustedSgId =
+ rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
+ }
+
auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+ affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
if (failed(deLinearizeSgId))
return failure();
SmallVector<Value> sgIds = *deLinearizeSgId;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..71eb732ac4953 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
+ // CHECK-LABEL: @warp_specialized
+ gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+ %sg_id = gpu.subgroup_id : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c32 = arith.constant 32 : index
+ %c3 = arith.constant 3 : index
+ %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+ %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+ %cond = arith.andi %cond1, %cond2 : i1
+ scf.if %cond {
+ // CHECK-NOT: index.sub
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ } {xegpu.sg_id_range = array<i32: 0, 1>}
+ %cond3 = arith.cmpi sge, %sg_id, %c1 : index
+ %cond4 = arith.cmpi slt, %sg_id, %c2 : index
+ %cond5 = arith.andi %cond3, %cond4 : i1
+ scf.if %cond5 {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
+ %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
+ -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+ -> vector<128x256xf32>
+ %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+ } {xegpu.sg_id_range = array<i32: 1, 2>}
+ %cond6 = arith.cmpi sge, %sg_id, %c2 : index
+ %cond7 = arith.cmpi slt, %sg_id, %c32 : index
+ %cond8 = arith.andi %cond6, %cond7 : i1
+ scf.if %cond8 {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+ %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<128x64xf32>
+ %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+ } {xegpu.sg_id_range = array<i32: 2, 32>}
+ gpu.return
+ }
+ // CHECK-LABEL: @subgroup_id_range_nested_if
+ gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+ %sg_id = gpu.subgroup_id : index
+ %c1 = arith.constant 1 : i1
+ %c3 = arith.constant 3 : index
+ %c32 = arith.constant 32 : index
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+ %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+ %cond = arith.andi %cond1, %cond2 : i1
+ scf.if %c1 {
+ scf.if %cond {
+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+ %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ %ld = xegpu.load_nd %td
+ : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+ -> vector<128x64xf32>
+ %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+ }
+ } {xegpu.sg_id_range = array<i32: 3, 8>}
+ gpu.return
+ }
}