[Mlir-commits] [mlir] [MLIR][XeGPU] Extend op definitions to support 3D+: load_nd, store_nd, prefetch_nd (PR #199811)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 26 19:26:30 PDT 2026


llvmorg-github-actions[bot] wrote:



@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

  **Summary**
  Extend `xegpu.load_nd`, `xegpu.store_nd`, and `xegpu.prefetch_nd` to support 3D and higher-dimensional tensor descriptors, whose leading dimensions are treated as batch dimensions. This enables batched memory operations, such as loading or storing a `4x8x16` tile.

  **Changes**
  - Verifiers: Removed the `rank > 2` checks in `LoadNdOp::verify()` and `StoreNdOp::verify()` so that 3D+ tensor descriptors are accepted.
  - Documentation: Extended the op descriptions with the following points, and added a 3D example for each op:
    - Tensor descriptors can be 1D, 2D, 3D, or higher dimensional.
    - Batch (leading) dimensions are unrolled to unit dimensions during lowering.
    - At the subgroup level, the operations execute at 2D granularity to match the 2D block IO hardware.
  - Tests: Added unit tests for 3D operations (`load_nd_3d`, `store_nd_3d`, `prefetch_nd_3d`) to `ops.mlir`. Removed the now-obsolete negative tests (`subgroup_load_nd_9`, `store_nd_vc_4`) from `invalid.mlir`.
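
The batch-dimension unrolling described above can be pictured with a small C++ sketch. `Block2D` and `unrollBatchDims` are hypothetical names and are not part of MLIR or the XeGPU lowering. The sketch enumerates the 2D blocks that a rank-n access would be split into when every dimension except the last two is unrolled to a unit dimension. It assumes the batch blocks are visited in row-major order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical type (not an MLIR API): one 2D block produced by unrolling
// the batch (leading) dimensions of an n-D tensor descriptor.
struct Block2D {
  std::vector<int64_t> offsets; // full-rank offsets of this block
  std::vector<int64_t> shape;   // full-rank shape; batch dims are 1
};

// Splits an n-D access into 2D blocks by unrolling every dimension except
// the last two to a unit dimension. Blocks are enumerated in row-major
// order of the batch indices (an assumption made for this sketch).
std::vector<Block2D> unrollBatchDims(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &offsets) {
  // Documented rule: one offset per tensor descriptor dimension.
  assert(shape.size() == offsets.size() && "offset count must match rank");
  const size_t rank = shape.size();
  if (rank <= 2)
    return {Block2D{offsets, shape}}; // 1D/2D: already a single block

  const size_t numBatch = rank - 2;
  std::vector<int64_t> blockShape(rank, 1);
  blockShape[rank - 2] = shape[rank - 2];
  blockShape[rank - 1] = shape[rank - 1];

  int64_t numBlocks = 1;
  for (size_t d = 0; d < numBatch; ++d)
    numBlocks *= shape[d];

  std::vector<Block2D> blocks;
  blocks.reserve(static_cast<size_t>(numBlocks));
  for (int64_t linear = 0; linear < numBlocks; ++linear) {
    std::vector<int64_t> blockOffsets = offsets;
    // Decompose the linear block index into one index per batch dimension,
    // with the innermost batch dimension varying fastest.
    int64_t rem = linear;
    for (size_t d = numBatch; d-- > 0;) {
      blockOffsets[d] += rem % shape[d];
      rem /= shape[d];
    }
    blocks.push_back(Block2D{blockOffsets, blockShape});
  }
  return blocks;
}
```

For the `4x8x16` descriptor used in `load_nd_3d`, this sketch yields four `1x8x16` blocks at batch offsets 0 through 3. Each block has the 2D granularity described in the op documentation.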

---
Full diff: https://github.com/llvm/llvm-project/pull/199811.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+61-12) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (-6) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (-17) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+29) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index fce14999b4011..10baf40e329d1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -226,16 +226,24 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", [AnchorLayoutInterface]> {
     It issues an instruction to prefetch a block of data from continuous
     memory regions to each level of the cache based on their cache policy.
 
+    The tensor descriptor can be 1D, 2D, 3D, or higher dimensional. When the
+    tensor descriptor has more than 2 dimensions, the leading dimensions are
+    treated as **batch dimensions** that are unrolled to unit dimensions during
+    lowering. At the subgroup level, the prefetch_nd operation executes at 2D
+    granularity to match the 2D block IO hardware support. The number of offset
+    indices must match the rank of the tensor descriptor.
+
     This operation serves as an anchor through which users assign a layout attribute
     to govern computation distribution.
 
     Arguments:
     - `TensorDesc`: A tensor descriptor specifying the base nd-region of
-      memory and tensor tile to be prefetched.
+      memory and tensor tile to be prefetched. Can be 1D, 2D, 3D, or higher
+      dimensional where leading dimensions are batch dimensions.
 
     - `offsets`: index values representing per-dimension offsets from the
       base position encoded in `TensorDesc`. It is encoded via "offsets"
-      and "const_offsets".
+      and "const_offsets". The number of offsets must match the tensor descriptor rank.
 
     - `l1_hint`, `l2_hint`, `l3_hint`: [optional] An cache-hint attribute
       indicating the desired behavior at the L1, L2, and L3 cache levels.
@@ -243,7 +251,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", [AnchorLayoutInterface]> {
     - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand.
        Only valid at the workgroup and subgroup levels.
 
-    Example (Workgroup level):
+    Example 1 (Workgroup level, 2D):
     ```mlir
       %c0 = arith.constant 0 : index
       %c1 = arith.constant 1 : index
@@ -254,6 +262,13 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", [AnchorLayoutInterface]> {
         : !xegpu.tensor_desc<32x256xf16>
     ```
 
+    Example 2 (3D with batch dimension):
+    ```mlir
+      // Prefetch 4 independent 8x16 blocks
+      xegpu.prefetch_nd %tdesc[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
+        : !xegpu.tensor_desc<4x8x16xf16>
+    ```
+
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -321,6 +336,14 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     hints for each level of cache, L1, L2 and L3. If hardware does not have a
     correspoding cache, Corresponding cache hint attribute will be masked.
 
+    The tensor descriptor can be 1D, 2D, 3D, or higher dimensional. When the
+    tensor descriptor has more than 2 dimensions, the leading dimensions are
+    treated as **batch dimensions** that are unrolled to unit dimensions during
+    lowering. At the subgroup level, the load_nd operation executes at 2D
+    granularity to match the 2D block IO hardware support. The result vector
+    has the same shape as the tensor descriptor. The number of offset indices
+    must match the rank of the tensor descriptor.
+
     On Intel GPUs, hardware-supported packing rearranges data elements during
     the load of the B operand when the element bit-width is less than 32 bits
     (for example, fp16). The transpose feature reorders data during the load
@@ -336,10 +359,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     Arguments:
 
     - `TensorDesc`: A tensor descriptor specifying the base nd-region of memory
-      and the tensor tile to be loaded.
+      and the tensor tile to be loaded. Can be 1D, 2D, 3D, or higher dimensional
+      where leading dimensions are batch dimensions.
 
     - `offsets`: Index values representing per-dimension offsets from the base position
       encoded in `TensorDesc`. They are encoded via `offsets` and `const_offsets`.
+      The number of offsets must match the tensor descriptor rank.
 
     - `packed`: [optional] A unit attribute indicating that packing is applied
       during the load when supported by the hardware. Only valid at lane level.
@@ -352,7 +377,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
     - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as the result of the load (they are identical). Only valid at workgroup and subgroup levels.
 
-    Example 1 (Workgroup level):
+    Example 1 (Workgroup level, 2D):
     ```mlir
       xegpu.load_nd %1 {transpose = [1, 0],
                         l1_hint = #xegpu.cache_hint<cached>,
@@ -361,13 +386,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                         layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]>}
               : !xegpu.tensor_desc<32x256xf32> -> vector<32x256xf32>
     ```
-    Example 2 (lane level):
+
+    Example 2 (lane level, 2D):
     ```mlir
       xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<uncached>}>
         : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
     ```
 
+    Example 3 (3D with batch dimension):
+    ```mlir
+      // Load 4 independent 8x16 blocks
+      %result = xegpu.load_nd %tdesc[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
+        : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+    ```
 
   }];
 
@@ -436,7 +468,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>, AnchorLayoutInterface
   ]> {
-  let summary = "stores a n-D block register region back to memory, currently only supports 2D";
+  let summary = "stores a n-D block register region back to memory";
 
   let description = [{
     StoreNdOp essentially mimics the hardware block write instruction io
@@ -444,7 +476,14 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     by the TensorDesc. It takes a set of optional cache hints for each level
     of cache, L1, L2 and L3. If hardware does not have a correspoding cache,
     Corresponding cache hint attribute will be masked.
-    It is only available to 1D or 2D blocked tensor_desc.
+
+    The tensor descriptor can be 1D, 2D, 3D, or higher dimensional. When the
+    tensor descriptor has more than 2 dimensions, the leading dimensions are
+    treated as **batch dimensions** that are unrolled to unit dimensions during
+    lowering. At the subgroup level, the store_nd operation executes at 2D
+    granularity to match the 2D block IO hardware support. The value vector
+    must have the same shape as the tensor descriptor. The number of offset
+    indices must match the rank of the tensor descriptor.
 
     At lane level, the input vector represents the data to be stored by each lane.
 
@@ -453,13 +492,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 
     Arguments:
 
-    - `value`: A vector value representing the tensor tile to be stored.
+    - `value`: A vector value representing the tensor tile to be stored. Can be
+      1D, 2D, 3D, or higher dimensional where leading dimensions are batch dimensions.
 
     - `TensorDesc`: A tensor descriptor specifying the base nd-region of memory and
-      the tensor tile to be stored.
+      the tensor tile to be stored. Can be 1D, 2D, 3D, or higher dimensional
+      where leading dimensions are batch dimensions.
 
     - `offsets`: Index values representing per-dimension offsets from the base position
       encoded in `TensorDesc`. They are encoded via `offsets` and `const_offsets`.
+      The number of offsets must match the tensor descriptor rank.
 
     - `l1_hint`, `l2_hint`, `l3_hint`: [optional] Cache-hint attributes indicating the
       desired behavior at the L1, L2, and L3 cache levels.
@@ -467,7 +509,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as
       the value to be stored (they are identical). Only valid at workgroup and subgroup levels.
 
-    Example 1 (Workgroup level):
+    Example 1 (Workgroup level, 2D):
     ```mlir
       xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                              l2_hint = #xegpu.cache_hint<write_back>,
@@ -475,7 +517,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                              layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]>}
                              : vector<32x256xf16>, !xegpu.tensor_desc<32x256xf16>
     ```
-    Example 2 (lane level):
+
+    Example 2 (lane level, 2D):
     ```mlir
       xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                              l2_hint = #xegpu.cache_hint<write_back>,
@@ -483,6 +526,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                              : vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
 
+    Example 3 (3D with batch dimension):
+    ```mlir
+      // Store 4 independent 8x16 blocks
+      xegpu.store_nd %value, %tdesc[0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>}>
+        : vector<4x8x16xf16>, !xegpu.tensor_desc<4x8x16xf16>
+    ```
 
   }];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 56d340141ee8f..5e726e8f9faed 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -344,9 +344,6 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
 
-  if (tdescTy.getRank() > 2)
-    return emitOpError("Expects a 1D or 2D TensorDesc.\n");
-
   if (!valueTy)
     return emitOpError("Invalid result, it should be a VectorType.\n");
 
@@ -466,9 +463,6 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
 
-  if (dstTy.getRank() > 2)
-    return emitOpError("Expects a 1D or 2D TensorDesc.\n");
-
   if (!valTy)
     return emitOpError("Expecting a VectorType result.\n");
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 341427e7e1231..ad216578cb8b9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -108,14 +108,6 @@ func.func @load_nd_vc_4(%src: memref<24x32xf32>) {
   return
 }
 
-// -----
-func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
-  %1 = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
-  // expected-error @+1 {{Expects a 1D or 2D TensorDesc}}
-  %2 = xegpu.load_nd %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
-  return
-}
-
 // -----
 func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
   %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
@@ -176,15 +168,6 @@ func.func @store_nd_vc_3(%dst: memref<24x32xf16>) {
   return
 }
 
-// -----
-func.func @store_nd_vc_4(%dst: memref<8x24x32xf16>) {
-  %1 = arith.constant dense<1.0>: vector<8x24x32xf16>
-  %2 = xegpu.create_nd_tdesc %dst : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16>
-  // expected-error @+1 {{Expects a 1D or 2D TensorDesc}}
-  xegpu.store_nd %1, %2[0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16>
-  return
-}
-
 // -----
 func.func @store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) {
   %1 = xegpu.create_nd_tdesc %dst : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 198dc64d2814b..3cc6d9a90893e 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -346,6 +346,35 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @load_nd_3d
+gpu.func @load_nd_3d(%src: memref<4x8x16xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+  // CHECK: %{{.*}} = xegpu.load_nd %[[R0]][0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+  %2 = xegpu.load_nd %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @store_nd_3d
+gpu.func @store_nd_3d(%dst: memref<4x8x16xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<4x8x16xf16>
+  %val = arith.constant dense<1.0> : vector<4x8x16xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %dst : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x8x16xf16>, !xegpu.tensor_desc<4x8x16xf16>
+  xegpu.store_nd %val, %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x8x16xf16>, !xegpu.tensor_desc<4x8x16xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @prefetch_nd_3d
+gpu.func @prefetch_nd_3d(%src: memref<4x8x16xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16>
+  xegpu.prefetch_nd %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @simt_load_4(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: vector<1xindex>, %[[arg2:.*]]: vector<1xi1>) {
 gpu.func @simt_load_4(%arg0: memref<256xf16>, %arg1: vector<1xindex>, %arg2: vector<1xi1>) {
   // CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>

``````````
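
The updated documentation states two shape rules for `load_nd` and `store_nd`. The number of offsets must equal the descriptor rank, and the result or value vector must have the descriptor's shape. These rules can be summarized in a small C++ model. `checkNdAccess` is a hypothetical helper, not the actual `LoadNdOp::verify()` or `StoreNdOp::verify()` code. It assumes workgroup/subgroup-level values, because at lane level the value shape differs (e.g. `vector<8xf32>` for an `8x16` descriptor).

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of the documented load_nd/store_nd shape rules at the
// workgroup and subgroup levels; not the real XeGPU verifier. With the
// rank <= 2 check removed, any descriptor rank >= 1 is accepted.
// Returns an empty string on success, otherwise an error message.
std::string checkNdAccess(const std::vector<int64_t> &tdescShape,
                          size_t numOffsets,
                          const std::vector<int64_t> &valueShape) {
  if (tdescShape.empty())
    return "expects a tensor descriptor of rank >= 1";
  if (numOffsets != tdescShape.size())
    return "number of offsets must match the tensor descriptor rank";
  if (valueShape != tdescShape)
    return "value shape must match the tensor descriptor shape";
  return "";
}
```

Under this model, the `4x8x16` accesses in `ops.mlir` pass, while a 3D descriptor addressed with only two offsets would be rejected.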

</details>


https://github.com/llvm/llvm-project/pull/199811

