[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for subgroup_id_range (PR #148661)

Nishant Patel llvmlistbot at llvm.org
Wed Jul 16 17:56:21 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/148661

>From 1d18b895bface3094ac4868601ebeadc0ae2758c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 11 Jul 2025 14:26:08 +0000
Subject: [PATCH 1/3] Add support for subgroup_id_range

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 40 ++++++++-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 83 +++++++++++++++++++
 2 files changed, 122 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be7b860dd1729..56dc132d8083d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -174,8 +174,46 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
+    // Check if there is warp specialization.
+    auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
+                                int64_t &endRange) -> bool {
+      Operation *parent = op->getParentOp();
+      // Find the outermost scf::IfOp with xegpu.sg_id_range.
+      while (parent) {
+        if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
+          if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
+            if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
+              auto values = denseAttr.asArrayRef();
+              if (values.size() == 2) {
+                startRange = values[0];
+                endRange = values[1];
+              }
+            }
+            break;
+          }
+        }
+        parent = parent->getParentOp();
+      }
+      // Return false if startRange is 0
+      return (startRange > 0 && endRange > startRange);
+    };
+
+    int64_t startRange = -1, endRange = -1;
+    bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
+
+    // If warp specialization is detected, adjust the subgroup id accordingly
+    Value adjustedSgId = linearSgId;
+    if (warpSpecialized) {
+      // Subtract startRange from the original subgroup id to get the adjusted
+      // sg id
+      Value startRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startRange);
+      adjustedSgId =
+          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
+    }
+
     auto deLinearizeSgId =
-        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+        affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
     if (failed(deLinearizeSgId))
       return failure();
     SmallVector<Value> sgIds = *deLinearizeSgId;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..71eb732ac4953 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: @warp_specialized
+  gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c31 = arith.constant 31 : index
+    %c3 = arith.constant 3 : index
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+        // CHECK-NOT: index.sub
+        %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %load =  xegpu.load_nd %tdesc
+          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<256x128xf32>
+    } {xegpu.sg_id_range = array<i32: 0, 1>}
+    %cond3 = arith.cmpi sge, %sg_id, %c1 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c2 : index
+    %cond5 = arith.andi %cond3, %cond4 : i1
+     scf.if %cond5 {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C1:.*]] = arith.constant 1 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
+        %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %load_a =  xegpu.load_nd %tdesc_a
+          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<256x128xf32>
+        %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
+          -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+        %load_b =  xegpu.load_nd %tdesc_b
+          : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
+          -> vector<128x256xf32>
+        %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+     }{xegpu.sg_id_range = array<i32: 1, 2>}
+    %cond6 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond7 = arith.cmpi slt, %sg_id, %c31 : index
+    %cond8 = arith.andi %cond6, %cond7 : i1
+    scf.if %cond8 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load =  xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }{xegpu.sg_id_range = array<i32: 2, 32>}
+    gpu.return
+  }
 
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c1 = arith.constant 1 : i1 
+    %c3 = arith.constant 3 : index
+    %c32 = arith.constant 32 : index
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %c1 {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %ld =  xegpu.load_nd %td
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }
+  } {xegpu.sg_id_range = array<i32: 3, 8>}
+  gpu.return
+  }
 }

>From b4e3068ca9d3e74d73ae9274834cc952d304a19f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 16 Jul 2025 17:02:23 +0000
Subject: [PATCH 2/3] Add xegpu.sg_id_range attribute

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 27 ++++++++++++++++
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 32 ++++++++-----------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  8 ++---
 3 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 84c1dc1373ee5..306b6ec1eed16 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -315,4 +315,31 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
   let genVerifyDecl = 1;
 }
 
+def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
+  let summary = [{Specifies a half-open integer range [start, end)}];
+  let description = [{
+    `RangeAttr` defines a half-open range [start, end): the start value is
+    included and the end value is excluded.
+    One use of this attribute is to annotate warp-specialized code: attach it
+    to an `scf.if` op whose condition selects a contiguous range of subgroups:
+    ```mlir
+    scf.if %cond {
+      // some operations
+    }{sg_id_range = #xegpu.range<[2, 4]>}
+    ```
+    This states that only subgroups 2 and 3 execute the region. The attribute
+    does not guard execution itself; `%cond` must perform the same check.
+  }];
+
+  let parameters = (ins  // NOTE(review): no genVerifyDecl (unlike XeGPU_LayoutAttr), so start < end is not verified.
+    "IntegerAttr": $start,
+    "IntegerAttr": $end
+  );
+
+  let builders = [
+    AttrBuilder<(ins "int":$start, "int":$end)>  // NOTE(review): no C++ body for this builder appears in this patch series — confirm it is defined before use.
+  ];
+
+  let assemblyFormat = "`<` `[`$start ```,` $end `]``>`";
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 56dc132d8083d..eb89cca0070ac 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -175,41 +175,37 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     }
 
     // Check if there is warp specialization.
-    auto isWarpSpecialized = [](Operation *op, int64_t &startRange,
-                                int64_t &endRange) -> bool {
+    auto isWarpSpecialized = [](Operation *op, int64_t &startOfRange,
+                                int64_t &endOfRange) -> bool {
       Operation *parent = op->getParentOp();
       // Find the outermost scf::IfOp with xegpu.sg_id_range.
       while (parent) {
         if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
-          if (Attribute attr = ifOp->getAttr("xegpu.sg_id_range")) {
-            if (auto denseAttr = dyn_cast<DenseI32ArrayAttr>(attr)) {
-              auto values = denseAttr.asArrayRef();
-              if (values.size() == 2) {
-                startRange = values[0];
-                endRange = values[1];
-              }
-            }
+          if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
+                  ifOp->getAttr("sg_id_range"))) {
+            startOfRange = attr.getStart().getInt();
+            endOfRange = attr.getEnd().getInt();
             break;
           }
         }
         parent = parent->getParentOp();
       }
-      // Return false if startRange is 0
-      return (startRange > 0 && endRange > startRange);
+      // Return false if startOfRange is 0
+      return (startOfRange > 0 && endOfRange > startOfRange);
     };
 
-    int64_t startRange = -1, endRange = -1;
-    bool warpSpecialized = isWarpSpecialized(op, startRange, endRange);
+    int64_t startOfRange = -1, endOfRange = -1;
+    bool warpSpecialized = isWarpSpecialized(op, startOfRange, endOfRange);
 
     // If warp specialization is detected, adjust the subgroup id accordingly
     Value adjustedSgId = linearSgId;
     if (warpSpecialized) {
-      // Subtract startRange from the original subgroup id to get the adjusted
+      // Subtract startOfRange from the original subgroup id to get the adjusted
       // sg id
-      Value startRangeVal =
-          rewriter.create<arith::ConstantIndexOp>(loc, startRange);
+      Value startOfRangeVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
       adjustedSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startRangeVal);
+          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
     }
 
     auto deLinearizeSgId =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 71eb732ac4953..39cd8c6158685 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -314,7 +314,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
         %load =  xegpu.load_nd %tdesc
           : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
           -> vector<256x128xf32>
-    } {xegpu.sg_id_range = array<i32: 0, 1>}
+    } {sg_id_range = #xegpu.range<[0, 1]>}
     %cond3 = arith.cmpi sge, %sg_id, %c1 : index
     %cond4 = arith.cmpi slt, %sg_id, %c2 : index
     %cond5 = arith.andi %cond3, %cond4 : i1
@@ -333,7 +333,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
           : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
           -> vector<128x256xf32>
         %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
-     }{xegpu.sg_id_range = array<i32: 1, 2>}
+     }{sg_id_range = #xegpu.range<[1, 2]>}
     %cond6 = arith.cmpi sge, %sg_id, %c2 : index
     %cond7 = arith.cmpi slt, %sg_id, %c31 : index
     %cond8 = arith.andi %cond6, %cond7 : i1
@@ -347,7 +347,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
         : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
         -> vector<128x64xf32>
       %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
-    }{xegpu.sg_id_range = array<i32: 2, 32>}
+    }{sg_id_range = #xegpu.range<[2, 32]>}
     gpu.return
   }
 
@@ -377,7 +377,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
           -> vector<128x64xf32>
         %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
     }
-  } {xegpu.sg_id_range = array<i32: 3, 8>}
+  } {sg_id_range = #xegpu.range<[3, 8]>}
   gpu.return
   }
 }

>From 70fe19cfb4811f4c5619a6f47affc0a5f01998eb Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 17 Jul 2025 00:54:11 +0000
Subject: [PATCH 3/3] Update tests

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 29 ++++-----------------
 1 file changed, 5 insertions(+), 24 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 39cd8c6158685..74c27a87cfb17 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -314,30 +314,11 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
         %load =  xegpu.load_nd %tdesc
           : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
           -> vector<256x128xf32>
-    } {sg_id_range = #xegpu.range<[0, 1]>}
-    %cond3 = arith.cmpi sge, %sg_id, %c1 : index
-    %cond4 = arith.cmpi slt, %sg_id, %c2 : index
+    } {sg_id_range = #xegpu.range<[0, 32]>}
+    %cond3 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c31 : index
     %cond5 = arith.andi %cond3, %cond4 : i1
-     scf.if %cond5 {
-        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-        // CHECK: %[[C1:.*]] = arith.constant 1 : index
-        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
-        %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
-        %load_a =  xegpu.load_nd %tdesc_a
-          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
-          -> vector<256x128xf32>
-        %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
-          -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
-        %load_b =  xegpu.load_nd %tdesc_b
-          : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
-          -> vector<128x256xf32>
-        %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
-     }{sg_id_range = #xegpu.range<[1, 2]>}
-    %cond6 = arith.cmpi sge, %sg_id, %c2 : index
-    %cond7 = arith.cmpi slt, %sg_id, %c31 : index
-    %cond8 = arith.andi %cond6, %cond7 : i1
-    scf.if %cond8 {
+    scf.if %cond5 {
       // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
       // CHECK: %[[C2:.*]] = arith.constant 2 : index
       // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
@@ -377,7 +358,7 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
           -> vector<128x64xf32>
         %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
     }
-  } {sg_id_range = #xegpu.range<[3, 8]>}
+  } {sg_id_range = #xegpu.range<[3, 32]>}
   gpu.return
   }
 }



More information about the Mlir-commits mailing list