[Mlir-commits] [mlir] [MLIR][XeGPU] Extend op definitions to support 3D+: load_nd, store_nd, prefetch_nd (PR #199811)
llvmlistbot at llvm.org
Tue May 26 19:26:30 PDT 2026
llvmorg-github-actions[bot] wrote:
@llvm/pr-subscribers-mlir
Author: Jianhui Li (Jianhui-Li)
<details>
<summary>Changes</summary>
**Summary**
Extend xegpu.load_nd, xegpu.store_nd, and xegpu.prefetch_nd operations to support 3D and higher-dimensional tensor descriptors with batch dimensions, enabling batched memory operations for workloads such as loading or storing a [4, 8, 16] tensor.
**Changes**
- Verifiers: Removed rank > 2 checks in LoadNdOp::verify() and StoreNdOp::verify() to allow 3D+ tensor descriptors
- Documentation: Added comprehensive documentation explaining:
  - Tensor descriptors can be 1D, 2D, 3D, or higher dimensional
  - Batch dimensions (leading dimensions) are unrolled to unit dimensions during lowering
  - Operations execute at 2D granularity at the subgroup level to match the 2D block IO hardware
  - Examples of 3D operations
- Tests: Added unit tests for 3D operations (load_nd_3d, store_nd_3d, prefetch_nd_3d)
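The batch-dimension semantics can be illustrated with a small example. The following is a minimal C++ sketch of how an N-D block such as [4, 8, 16] could be split into unit-size batch slices, each covered by one 2D block access. It assumes a row-major descriptor in which the last two dimensions form the 2D block, as the PR describes. The names `Block2D` and `unrollBatchDims` are hypothetical. This is not the XeGPU lowering code, which is not part of this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical record for one 2D block access produced by unrolling.
// The block has unit size in every batch dimension, so only the last two
// sizes are stored; `offsets` keeps the full N-D start position.
struct Block2D {
  std::vector<int64_t> offsets; // N-D offsets of this block
  int64_t rows;                 // size of dimension rank-2
  int64_t cols;                 // size of dimension rank-1
};

// Hypothetical helper: split an N-D block (rank >= 2) into one 2D block per
// combination of batch (leading) indices, in row-major order.
std::vector<Block2D> unrollBatchDims(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &baseOffsets) {
  const size_t rank = shape.size();
  assert(rank >= 2 && baseOffsets.size() == rank &&
         "expects rank >= 2 and one offset per dimension");

  // Number of 2D blocks = product of the batch dimension sizes.
  int64_t numBlocks = 1;
  for (size_t d = 0; d + 2 < rank; ++d) {
    assert(shape[d] > 0 && "batch dimensions must be positive");
    numBlocks *= shape[d];
  }

  std::vector<Block2D> blocks;
  for (int64_t linear = 0; linear < numBlocks; ++linear) {
    Block2D block{baseOffsets, shape[rank - 2], shape[rank - 1]};
    // Decompose the linear index over the batch dims (last batch dim fastest).
    int64_t rem = linear;
    for (size_t d = rank - 2; d-- > 0;) {
      block.offsets[d] += rem % shape[d];
      rem /= shape[d];
    }
    blocks.push_back(std::move(block));
  }
  return blocks;
}
```

For a `4x8x16` descriptor at offsets `[0, 0, 0]`, this yields four `8x16` blocks at offsets `[0, 0, 0]` through `[3, 0, 0]`. A rank-2 descriptor yields a single block with its original offsets.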
---
Full diff: https://github.com/llvm/llvm-project/pull/199811.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+61-12)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (-6)
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (-17)
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+29)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index fce14999b4011..10baf40e329d1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -226,16 +226,24 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", [AnchorLayoutInterface]> {
It issues an instruction to prefetch a block of data from continuous
memory regions to each level of the cache based on their cache policy.
+ The tensor descriptor can be 1D, 2D, 3D, or higher dimensional. When the
+ tensor descriptor has more than 2 dimensions, the leading dimensions are
+ treated as **batch dimensions** that are unrolled to unit dimensions during
+ lowering. At the subgroup level, the prefetch_nd operation executes at 2D
+ granularity to match the 2D block IO hardware support. The number of offset
+ indices must match the rank of the tensor descriptor.
+
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
Arguments:
- `TensorDesc`: A tensor descriptor specifying the base nd-region of
- memory and tensor tile to be prefetched.
+ memory and tensor tile to be prefetched. Can be 1D, 2D, 3D, or higher
+ dimensional where leading dimensions are batch dimensions.
- `offsets`: index values representing per-dimension offsets from the
base position encoded in `TensorDesc`. It is encoded via "offsets"
- and "const_offsets".
+ and "const_offsets". The number of offsets must match the tensor descriptor rank.
- `l1_hint`, `l2_hint`, `l3_hint`: [optional] An cache-hint attribute
indicating the desired behavior at the L1, L2, and L3 cache levels.
@@ -243,7 +251,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", [AnchorLayoutInterface]> {
- `layout`: [optional] Describes the expected layout of the `tensor_desc` operand.
Only valid at the workgroup and subgroup levels.
- Example (Workgroup level):
+ Example 1 (Workgroup level, 2D):
```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -254,6 +262,13 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", [AnchorLayoutInterface]> {
: !xegpu.tensor_desc<32x256xf16>
```
+ Example 2 (3D with batch dimension):
+ ```mlir
+ // Prefetch 4 independent 8x16 blocks
+ xegpu.prefetch_nd %tdesc[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
+ : !xegpu.tensor_desc<4x8x16xf16>
+ ```
+
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -321,6 +336,14 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
hints for each level of cache, L1, L2 and L3. If hardware does not have a
correspoding cache, Corresponding cache hint attribute will be masked.
+ The tensor descriptor can be 1D, 2D, 3D, or higher dimensional. When the
+ tensor descriptor has more than 2 dimensions, the leading dimensions are
+ treated as **batch dimensions** that are unrolled to unit dimensions during
+ lowering. At the subgroup level, the load_nd operation executes at 2D
+ granularity to match the 2D block IO hardware support. The result vector
+ has the same shape as the tensor descriptor. The number of offset indices
+ must match the rank of the tensor descriptor.
+
On Intel GPUs, hardware-supported packing rearranges data elements during
the load of the B operand when the element bit-width is less than 32 bits
(for example, fp16). The transpose feature reorders data during the load
@@ -336,10 +359,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
Arguments:
- `TensorDesc`: A tensor descriptor specifying the base nd-region of memory
- and the tensor tile to be loaded.
+ and the tensor tile to be loaded. Can be 1D, 2D, 3D, or higher dimensional
+ where leading dimensions are batch dimensions.
- `offsets`: Index values representing per-dimension offsets from the base position
encoded in `TensorDesc`. They are encoded via `offsets` and `const_offsets`.
+ The number of offsets must match the tensor descriptor rank.
- `packed`: [optional] A unit attribute indicating that packing is applied
during the load when supported by the hardware. Only valid at lane level.
@@ -352,7 +377,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
- `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as the result of the load (they are identical). Only valid at workgroup and subgroup levels.
- Example 1 (Workgroup level):
+ Example 1 (Workgroup level, 2D):
```mlir
xegpu.load_nd %1 {transpose = [1, 0],
l1_hint = #xegpu.cache_hint<cached>,
@@ -361,13 +386,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]>}
: !xegpu.tensor_desc<32x256xf32> -> vector<32x256xf32>
```
- Example 2 (lane level):
+
+ Example 2 (lane level, 2D):
```mlir
xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
```
+ Example 3 (3D with batch dimension):
+ ```mlir
+ // Load 4 independent 8x16 blocks
+ %result = xegpu.load_nd %tdesc[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
+ : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+ ```
}];
@@ -436,7 +468,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>, AnchorLayoutInterface
]> {
- let summary = "stores a n-D block register region back to memory, currently only supports 2D";
+ let summary = "stores a n-D block register region back to memory";
let description = [{
StoreNdOp essentially mimics the hardware block write instruction io
@@ -444,7 +476,14 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
by the TensorDesc. It takes a set of optional cache hints for each level
of cache, L1, L2 and L3. If hardware does not have a correspoding cache,
Corresponding cache hint attribute will be masked.
- It is only available to 1D or 2D blocked tensor_desc.
+
+ The tensor descriptor can be 1D, 2D, 3D, or higher dimensional. When the
+ tensor descriptor has more than 2 dimensions, the leading dimensions are
+ treated as **batch dimensions** that are unrolled to unit dimensions during
+ lowering. At the subgroup level, the store_nd operation executes at 2D
+ granularity to match the 2D block IO hardware support. The value vector
+ must have the same shape as the tensor descriptor. The number of offset
+ indices must match the rank of the tensor descriptor.
At lane level, the input vector represents the data to be stored by each lane.
@@ -453,13 +492,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
Arguments:
- - `value`: A vector value representing the tensor tile to be stored.
+ - `value`: A vector value representing the tensor tile to be stored. Can be
+ 1D, 2D, 3D, or higher dimensional where leading dimensions are batch dimensions.
- `TensorDesc`: A tensor descriptor specifying the base nd-region of memory and
- the tensor tile to be stored.
+ the tensor tile to be stored. Can be 1D, 2D, 3D, or higher dimensional
+ where leading dimensions are batch dimensions.
- `offsets`: Index values representing per-dimension offsets from the base position
encoded in `TensorDesc`. They are encoded via `offsets` and `const_offsets`.
+ The number of offsets must match the tensor descriptor rank.
- `l1_hint`, `l2_hint`, `l3_hint`: [optional] Cache-hint attributes indicating the
desired behavior at the L1, L2, and L3 cache levels.
@@ -467,7 +509,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
- `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as
the value to be stored (they are identical). Only valid at workgroup and subgroup levels.
- Example 1 (Workgroup level):
+ Example 1 (Workgroup level, 2D):
```mlir
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
@@ -475,7 +517,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]>}
: vector<32x256xf16>, !xegpu.tensor_desc<32x256xf16>
```
- Example 2 (lane level):
+
+ Example 2 (lane level, 2D):
```mlir
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
@@ -483,6 +526,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
```
+ Example 3 (3D with batch dimension):
+ ```mlir
+ // Store 4 independent 8x16 blocks
+ xegpu.store_nd %value, %tdesc[0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>}>
+ : vector<4x8x16xf16>, !xegpu.tensor_desc<4x8x16xf16>
+ ```
}];
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 56d340141ee8f..5e726e8f9faed 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -344,9 +344,6 @@ LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
- if (tdescTy.getRank() > 2)
- return emitOpError("Expects a 1D or 2D TensorDesc.\n");
-
if (!valueTy)
return emitOpError("Invalid result, it should be a VectorType.\n");
@@ -466,9 +463,6 @@ LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
- if (dstTy.getRank() > 2)
- return emitOpError("Expects a 1D or 2D TensorDesc.\n");
-
if (!valTy)
return emitOpError("Expecting a VectorType result.\n");
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 341427e7e1231..ad216578cb8b9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -108,14 +108,6 @@ func.func @load_nd_vc_4(%src: memref<24x32xf32>) {
return
}
-// -----
-func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
- %1 = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
- // expected-error @+1 {{Expects a 1D or 2D TensorDesc}}
- %2 = xegpu.load_nd %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
- return
-}
-
// -----
func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
%1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
@@ -176,15 +168,6 @@ func.func @store_nd_vc_3(%dst: memref<24x32xf16>) {
return
}
-// -----
-func.func @store_nd_vc_4(%dst: memref<8x24x32xf16>) {
- %1 = arith.constant dense<1.0>: vector<8x24x32xf16>
- %2 = xegpu.create_nd_tdesc %dst : memref<8x24x32xf16> -> !xegpu.tensor_desc<8x24x32xf16>
- // expected-error @+1 {{Expects a 1D or 2D TensorDesc}}
- xegpu.store_nd %1, %2[0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8x24x32xf16>, !xegpu.tensor_desc<8x24x32xf16>
- return
-}
-
// -----
func.func @store_nd_simt(%dst: memref<24x32xf32>, %data: vector<3xf32>) {
%1 = xegpu.create_nd_tdesc %dst : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 198dc64d2814b..3cc6d9a90893e 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -346,6 +346,35 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK-LABEL: gpu.func @load_nd_3d
+gpu.func @load_nd_3d(%src: memref<4x8x16xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ // CHECK: %{{.*}} = xegpu.load_nd %[[R0]][0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+ %2 = xegpu.load_nd %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16> -> vector<4x8x16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @store_nd_3d
+gpu.func @store_nd_3d(%dst: memref<4x8x16xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<4x8x16xf16>
+ %val = arith.constant dense<1.0> : vector<4x8x16xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %dst : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]][0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x8x16xf16>, !xegpu.tensor_desc<4x8x16xf16>
+ xegpu.store_nd %val, %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x8x16xf16>, !xegpu.tensor_desc<4x8x16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @prefetch_nd_3d
+gpu.func @prefetch_nd_3d(%src: memref<4x8x16xf16>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<4x8x16xf16>
+ // CHECK: xegpu.prefetch_nd %[[R0]][0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16>
+ xegpu.prefetch_nd %1[0, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8x16xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @simt_load_4(%[[arg0:.*]]: memref<256xf16>, %[[arg1:.*]]: vector<1xindex>, %[[arg2:.*]]: vector<1xi1>) {
gpu.func @simt_load_4(%arg0: memref<256xf16>, %arg1: vector<1xindex>, %arg2: vector<1xi1>) {
// CHECK: %0 = xegpu.load %[[arg0]][%[[arg1]]], %[[arg2]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
``````````
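With the `rank > 2` rejection removed, the updated documentation leaves two invariants that hold regardless of rank. First, the number of offsets must equal the descriptor rank. Second, at the subgroup level the loaded or stored vector has the same shape as the descriptor.

The following is a hypothetical C++ sketch of those two checks over plain shape vectors. It assumes no transpose or packing and does not use MLIR's API. It is not the verifier in `XeGPUOps.cpp`, which handles additional cases such as lane-level shapes. The function name `checkSubgroupNdAccess` is illustrative only.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical rank-agnostic check for load_nd/store_nd/prefetch_nd at the
// subgroup level (no transpose or packing). Returns an error message, or
// std::nullopt if the access is valid. Pass std::nullopt as `valueShape`
// for prefetch_nd, which has no value operand or result.
std::optional<std::string>
checkSubgroupNdAccess(const std::vector<int64_t> &tdescShape,
                      size_t numOffsets,
                      const std::optional<std::vector<int64_t>> &valueShape) {
  if (tdescShape.empty())
    return "tensor descriptor must have rank >= 1";
  // No upper bound on rank: 3D+ descriptors are accepted.
  if (numOffsets != tdescShape.size())
    return "number of offsets must match the tensor descriptor rank";
  // The result (load_nd) or value (store_nd) mirrors the descriptor shape.
  if (valueShape && *valueShape != tdescShape)
    return "value shape must match the tensor descriptor shape";
  return std::nullopt;
}
```

Under these assumptions, the 3D cases in the new `ops.mlir` tests (a `4x8x16` descriptor, three offsets, a `vector<4x8x16xf16>` value) pass. Supplying only two offsets would be rejected.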
</details>
https://github.com/llvm/llvm-project/pull/199811