[Mlir-commits] [mlir] 65dec99 - [MLIR][XeGPU] Add support for subgroup_id_range (#148661)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 11:10:39 PDT 2025


Author: Nishant Patel
Date: 2025-07-23T11:10:35-07:00
New Revision: 65dec9956273309158f3feba6ea8f150ce995a2a

URL: https://github.com/llvm/llvm-project/commit/65dec9956273309158f3feba6ea8f150ce995a2a
DIFF: https://github.com/llvm/llvm-project/commit/65dec9956273309158f3feba6ea8f150ce995a2a.diff

LOG: [MLIR][XeGPU] Add support for subgroup_id_range (#148661)

This PR adds a new attribute to the xegpu dialect called xegpu.range.
One use case of this attribute is to attach a subgroup_id_range to an
scf.if op to drive the execution.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index d022361d1e376..64eb21cbc3c4c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -336,4 +336,33 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
   let genVerifyDecl = 1;
 }
 
+def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
+  let summary = [{Specifies a half-open range}];
+  let description = [{
+    `RangeAttr` is an attribute that defines a half-open range [start, end).
+    The range is inclusive of the start value and exclusive of the end value,
+    so `end` must be strictly greater than `start`.
+
+    One use of this attribute is to specify a subgroup id range. Attached to
+    an scf.if op, it states which subgroups execute the op's region:
+    ```mlir
+    scf.if %cond {
+      // some operations
+    } {sg_id_range = #xegpu.range<[2, 4]>}
+    ```
+    In this case, the region is meant to be executed only by subgroup IDs 2
+    and 3. The attribute does not guard execution by itself; `%cond` is
+    expected to select the same subgroups.
+  }];
+
+  let parameters = (ins
+    "IntegerAttr": $start,
+    "IntegerAttr": $end
+  );
+
+  let builders = [
+    AttrBuilder<(ins "int":$start, "int":$end)>
+  ];
+
+  let assemblyFormat = "`<` `[`$start `,` $end `]` `>`";
+  let genVerifyDecl = 1;
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ab404d52eab4..3c0ca114a62d4 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -211,6 +211,21 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_RangeAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
+  // A half-open range [start, end) is only valid when it is non-empty, i.e.
+  // when `end` lies strictly above `start`.
+  int64_t start = startOfRange.getInt();
+  int64_t end = endOfRange.getInt();
+  if (end > start)
+    return success();
+
+  return emitError() << "'end' : " << end
+                     << " must be greater than 'start' : " << start;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ef52323a9f46b..229a289838c60 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,6 +34,26 @@ using namespace mlir;
 
 namespace {
 
+// Look for a subgroup id range attached, as the "sg_id_range" attribute, to an
+// scf.if op that encloses `op`. Enclosing ops are searched from the innermost
+// outward, and the first scf.if carrying the attribute wins.
+//
+// On a match, `startOfRange` / `endOfRange` receive the range bounds;
+// otherwise both are set to -1.
+//
+// Returns true only when a range was found whose start is greater than 0 and
+// whose end is greater than its start. A range starting at 0 needs no subgroup
+// id adjustment, so it yields false even though it was found; in that case the
+// out-parameters still hold the found bounds.
+static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
+                                 int64_t &endOfRange) {
+  // Start from a well-defined "not found" state rather than relying on the
+  // caller to pre-initialize the out-parameters.
+  startOfRange = -1;
+  endOfRange = -1;
+  for (Operation *parent = op->getParentOp(); parent;
+       parent = parent->getParentOp()) {
+    if (!isa<scf::IfOp>(parent))
+      continue;
+    if (auto attr = parent->getAttrOfType<xegpu::RangeAttr>("sg_id_range")) {
+      startOfRange = attr.getStart().getInt();
+      endOfRange = attr.getEnd().getInt();
+      break;
+    }
+  }
+  return startOfRange > 0 && endOfRange > startOfRange;
+}
+
 static std::pair<SmallVector<int64_t>, int>
 getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   int count = 1;
@@ -174,8 +194,26 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
     }
 
+    int64_t startOfRange = -1, endOfRange = -1;
+    bool sgIdRangeSpecified =
+        isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+    Value adjustedSgId = linearSgId;
+    if (sgIdRangeSpecified) {
+      int64_t sgCount = endOfRange - startOfRange;
+      if (computeProduct(sgLayout) != sgCount)
+        return rewriter.notifyMatchFailure(
+            op, "sg_layout size must match the sg_id_range");
+      // Subtract startOfRange from the original subgroup id to get the adjusted
+      // sg id
+      Value startOfRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+      adjustedSgId =
+          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+    }
+
     auto deLinearizeSgId =
-        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+        affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
     if (failed(deLinearizeSgId))
       return failure();
     SmallVector<Value> sgIds = *deLinearizeSgId;

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 8a81a286da23a..d51122417fb61 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -327,5 +327,70 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     gpu.return
   }
-}
 
+  // CHECK-LABEL: @subgroup_id_range
+  gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c2 = arith.constant 2 : index
+    %c18 = arith.constant 18 : index
+    // Guard selects subgroups [0, 32), matching the sg_id_range below; a range
+    // starting at 0 requires no subgroup id adjustment.
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+      // CHECK-NOT: index.sub
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load =  xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<256x128xf32>
+    } {sg_id_range = #xegpu.range<[0, 32]>}
+    // Guard selects subgroups [2, 18), matching the sg_id_range below; the
+    // 16 subgroups match the 4x4 sg_layout.
+    %cond3 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c18 : index
+    %cond5 = arith.andi %cond3, %cond4 : i1
+    scf.if %cond5 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load =  xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    } {sg_id_range = #xegpu.range<[2, 18]>}
+    gpu.return
+  }
+
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %true = arith.constant true
+    %c3 = arith.constant 3 : index
+    %c19 = arith.constant 19 : index
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    // Guard selects subgroups [3, 19), matching the sg_id_range below.
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c19 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    // The sg_id_range attribute sits on the OUTER scf.if, so ops inside the
+    // inner scf.if must find it by walking up through their enclosing ops.
+    scf.if %true {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %ld =  xegpu.load_nd %td
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+      }
+    } {sg_id_range = #xegpu.range<[3, 19]>}
+    gpu.return
+  }
+}


        


More information about the Mlir-commits mailing list