[Mlir-commits] [mlir] [MLIR][XeGPU] Refine XeGPU definitions (PR #100763)

Chao Chen llvmlistbot at llvm.org
Fri Jul 26 12:51:25 PDT 2024


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/100763

>From b2e3328ce6cb26755e4a9882f8c10698fdff0372 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 25 Jul 2024 22:41:33 +0000
Subject: [PATCH 1/4] remove 2D support for MaskType

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
 mlir/test/Dialect/XeGPU/invalid.mlir             | 8 ++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index bab0e4afb1e5e..111a270a28b27 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -19,7 +19,7 @@ def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64, UI32, I64, I32]>;
 def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1,2], [I1]>, I1]>;
+def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
 def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
 def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7819ad60b97d9..5cad1afb20c06 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -184,4 +184,12 @@ func.func @test_dpas_vc_4(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>, %c : v
   // expected-error at +1 {{Accumulator and Result for dpas op should have the same type}}
   %1 = xegpu.dpas %a, %b, %c : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf16> -> vector<8x16xf32>
   return
+}
+
+// -----
+func.func @test_atomic_rmw(%src: ui64, %value : vector<16x8xf32>, %mask : vector<16xi1>) {
+  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] {chunk_size = 8}: ui64 -> !xegpu.tensor_desc<16x8xf32, #xegpu.tdesc_attr<scattered = true>>
+  // expected-error at +1 {{failed to verify that all of {tensorDesc, mask, value, result} have same shape}}
+  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16x8xf32> -> vector<16x8xf32>
+  gpu.return
 }
\ No newline at end of file

>From 4d55ae17c1a8abd2b360fa84b08f7254a31dce87 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 26 Jul 2024 14:53:12 +0000
Subject: [PATCH 2/4] - Fix type print format for atomic_rmw - Update LoadNd
 definition   - Support 1D   - Renamed vnni_axis attr to packed - Update Dpas
 definition, limiting A to be 2D and B to be 2D/3D.

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 21 ++++-----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 44 +++++++++++--------
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir         | 22 +++++++---
 mlir/test/Dialect/XeGPU/invalid.mlir          | 26 +++--------
 4 files changed, 55 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e477d9a0ca3f1..c50c55060a319 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -245,8 +245,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
 }
 
 
-def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>,
-                                         AllElementCountsMatch<["value", "TensorDesc"]>]> {
+def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "loads a n-D block from memory (represented by TensorDesc)"
                 "to registers (represented by vector)";
   let description = [{
@@ -275,7 +274,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
-                       OptionalAttr<I64Attr>: $vnni_axis,
+                       OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -668,14 +667,12 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
     and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
     also requires A and B to be loaded with the required data layout. Specially,
-    VNNI layout is required for B operand. It is achieved via setting `vnni_axis = 0`
-    of the corresponding `load_nd` operator. To keep both operands as 3D vector,
-    operand A is loaded via setting `vnni_axis = 1` without impacting the
-    physical layouts change in register. Due to the VNNI transformation, A and B operands
-    are represented as 3D vector, with the last dimension representing the VNNI factor,
-    which is computed as `32/bit_width_of_elem_type`. Therefore, `A: vector<8x16xf16>`
-    is represented as `A: vector<8x8x2xf16>`, and `B: vector<16x16xf16>` is
-    represented as `B: vector<8x16x2xf16>`.
+
+    VNNI layout is required for B operand. It is achieved via adding `packed`
+    attribute to the `load_nd` operator.  Due to the VNNI transformation, B operands
+    can be represented as a 3D vector, with the last dimension representing the VNNI
+    factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
+    can be represented as `B: vector<8x16x2xf16>`.
 
     Note: on PVC, the hardware can perform load with VNNI transformation when data
           element type is 16-bit or lower precision, taking 2 or 4 elements from
@@ -739,7 +736,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,
 
   let assemblyFormat = [{
     $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
-    type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)
+    qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
   }];
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5ef47fbbe1ce0..84ccd7f6b4326 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -182,8 +182,8 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
 
-  if (tdescTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc.\n");
+  if (tdescTy.getRank() > 2)
+    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
 
   if (tdescTy.getScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -206,17 +206,27 @@ LogicalResult LoadNdOp::verify() {
 
   if (getTranspose()) {
     auto trans = getTranspose().value();
-    if (tdescShape.size() >= trans.size())
+
+    // Make sure the transpose value is valid.
+    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
+      return t >= 0 && t < tdescTy.getRank();
+    });
+
+    if (valid)
       transpose(trans, tdescShape);
     else
       emitWarning("Invalid transpose attr. It is ignored.");
   }
 
-  if (getVnniAxis()) {
-    auto axis = getVnniAxis().value();
-    auto vnni_factor = valueShape.back();
-    tdescShape[axis] /= vnni_factor;
-    tdescShape.push_back(vnni_factor);
+  if (getPacked()) {
+    if (tdescTy.getRank() == 2) {
+      const int axis = 0;
+      auto vnni_factor = valueShape.back();
+      tdescShape[axis] /= vnni_factor;
+      tdescShape.push_back(vnni_factor);
+    } else {
+      return emitWarning("Invalid Packed Attr. It is ignored (available for 2D TensorDesc only).");
+    }
   }
 
   if (array_len > 1) {
@@ -239,8 +249,8 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
 
-  if (dstTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc.\n");
+  if (dstTy.getRank() > 2)
+    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
 
   if (dstTy.getScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -413,18 +423,14 @@ LogicalResult DpasOp::verify() {
   int64_t lhsRank = getLhsType().getRank();
   int64_t rhsRank = getRhsType().getRank();
 
-  if (lhsRank != rhsRank || lhsRank != 3)
-    return emitOpError(
-        "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
-
-  if (getAcc() && getAccType() != getResultType())
-    return emitOpError("Accumulator and Result for dpas op should have the "
-                       "same type (both shape and element type).");
+  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
+    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector.");
 
   auto lhsShape = getLhsType().getShape();
   auto rhsShape = getRhsType().getShape();
-  if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
-    return emitOpError("K-dimension or vnni-factor mismatch.");
+  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
+  if (bK != lhsShape[1])
+    return emitOpError("K-dimension mismatch.");
 
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 00d32d2a2ee94..4f73e41e55370 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -42,8 +42,8 @@ gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
 gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, vnni_axis = 0 : i64}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
-  %2 = xegpu.load_nd %1 <{vnni_axis = 0, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
        : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>
   gpu.return
 }
@@ -121,10 +121,18 @@ gpu.func @test_create_update_tdesc_vc(%src: ui64) {
   gpu.return
 }
 
-// CHECK: gpu.func @test_dpas_vc(%[[arg0:.*]]: vector<8x8x2xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
-gpu.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
-  // CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
-  %1 = xegpu.dpas %a, %b: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+// CHECK: gpu.func @test_dpas_vc(%[[arg0:.*]]: vector<8x16xf16>, %[[arg1:.*]]: vector<16x16xf16>)
+gpu.func @test_dpas_vc(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
+  // CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  gpu.return
+}
+
+
+// CHECK: gpu.func @test_dpas_vc_with_packed_b(%[[arg0:.*]]: vector<8x16xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
+gpu.func @test_dpas_vc_with_packed_b(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {
+  // CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+  %1 = xegpu.dpas %a, %b: vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   gpu.return
 }
 
@@ -132,7 +140,7 @@ gpu.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
 gpu.func @test_atomic_rmw(%src: ui64, %value : vector<16xf32>, %mask : vector<16xi1>) {
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
   %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
-  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : <16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   gpu.return
 }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 5cad1afb20c06..ff37f5e1cca17 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -159,30 +159,16 @@ func.func @test_store_scatter_vc_2(%src: ui64) {
 }
 
 // -----
-func.func @test_dpas_vc_1(%a : vector<8x4x2xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error at +1 {{K-dimension or vnni-factor mismatch}}
-  %1 = xegpu.dpas %a, %b : vector<8x4x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+func.func @test_dpas_vc_1(%a : vector<8x8xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error at +1 {{K-dimension mismatch}}
+  %1 = xegpu.dpas %a, %b : vector<8x8xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
 
 // -----
-func.func @test_dpas_vc_2(%a : vector<8x16xf16>, %b: vector<8x16x2xf16>) {
-  // expected-error at +1 {{lhs and rhs rank does not match for dpas op, or their rank is not 3}}
-  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
-  return
-}
-
-// -----
-func.func @test_dpas_vc_3(%a : vector<8x16xf16>, %b: vector<16x16xf16>) {
-  // expected-error at +1 {{lhs and rhs rank does not match for dpas op, or their rank is not 3}}
-  %1 = xegpu.dpas %a, %b : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  return
-}
-
-// -----
-func.func @test_dpas_vc_4(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>, %c : vector<8x16xf16>) {
-  // expected-error at +1 {{Accumulator and Result for dpas op should have the same type}}
-  %1 = xegpu.dpas %a, %b, %c : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf16> -> vector<8x16xf32>
+func.func @test_dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
+  // expected-error at +1 {{expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector}}
+  %1 = xegpu.dpas %a, %b : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
   return
 }
 

>From c294f67e51388b403a2c47540ebc527d76bdc9e8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 26 Jul 2024 15:45:16 +0000
Subject: [PATCH 3/4] format code

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 84ccd7f6b4326..6de42c6992b82 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -225,7 +225,8 @@ LogicalResult LoadNdOp::verify() {
       tdescShape[axis] /= vnni_factor;
       tdescShape.push_back(vnni_factor);
     } else {
-      return emitWarning("Invalid Packed Attr. It is ignored (available for 2D TensorDesc only).");
+      return emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
+                         "TensorDesc only).");
     }
   }
 
@@ -424,7 +425,8 @@ LogicalResult DpasOp::verify() {
   int64_t rhsRank = getRhsType().getRank();
 
   if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
-    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either 2D or 3D (packed) vector.");
+    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
+                       "2D or 3D (packed) vector.");
 
   auto lhsShape = getLhsType().getShape();
   auto rhsShape = getRhsType().getShape();

>From 9775c5930eec38e82d17fd5392fc6d57d903e9ab Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 26 Jul 2024 19:51:11 +0000
Subject: [PATCH 4/4] Update CreateNdDesc Op

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 59 +++++++++++--------
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 16 ++---
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir         | 27 +++++++++
 mlir/test/Dialect/XeGPU/invalid.mlir          | 11 +++-
 5 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c50c55060a319..e82bb04203994 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -53,24 +53,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let summary = "Create nd-tensor descriptor operation";
   let description = [{
     The "create_nd_tdesc" operation creates a TensorDescType which represents
-    a sub-view of a 2D memory region (It can be extended to support n-D memory
-    region if needed in future). Elements in the subview continuous in each
-    dimension. It encodes the following important information for supporting
-    Intel hardware features:
-
-    * source: an object representing (starting address/pointer of) a 2D memory region.
-        It can be either a 2D memref object, or simply a pointer represented by uint64_t type.
-        for the later case, the shape and layout information of the 2D memory region should
-        be explicitly passed via `shape` and `strides` parameters.
-    * offsets: two index values represents offsets from the "source" at the each dimension
-        at which the subview of the target memory will be created. It is encoded via two
-        variables, including "offsets" and "const_offsets", such that it can
-        accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
-    * shape: the shape information of the memory region pointed by the "source".  It is
-        typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
+    a sub-view of a 1D/2D memory region inside the one or two innermost dimensions
+    of the source. (It can be extended to support n-D memory region if needed in
+    future). Elements in the subview are contiguous in each dimension. It encodes the
+    following important information for supporting Intel hardware features:
+
+    * source: an object representing (starting address/pointer of) a memory region.
+       It can be either a memref object, or simply a pointer represented by uint64_t type.
+       For the latter case, the shape and layout information of the memory region should
+       be explicitly passed via `shape` and `strides` parameters.
+
+    * offsets: index values representing offsets from the "source" in each dimension
+        at which the subview of the target memory will be created. It is encoded via
+        "offsets" and "const_offsets", such that it can accept various forms, such as,
+        operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
+
+    * shape: the shape information of the memory region pointed by the "source". It is
+         typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
         But if "source" is simply a pointer represented as uint64_t type, or a memref
         type without shape information e.g., memref<?x?xf16>, the shape information has
         to be explicitly passed via the "shape" and "const_shape" arguments.
+
     * strides: the strides of the memory region pointed by the "source". Similar to shape,
         it is typically encoded via the MemRefType of the source too. But if "source" is
         simply a pointer represented as uint64_t type, or a memref type without shape
@@ -78,22 +81,28 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
         passed via the "strides" and "const_strides" argument.
 
     Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
+    ```mlir
     %0 = memref.alloc() : memref<1024x1024xf32>
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0]: memref<1024x1024xf32> -> TensorDesc<8x16xf32>
+    ```
 
     Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
+    ```mlir
     %0 = memref.alloc(%h, %w) : memref<?x?xf32>
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>
+    ```
 
     Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
+    ```mlir
     %0 = ... : ui64
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
+    ```
   }];
 
   let arguments = (ins
@@ -219,7 +228,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     memory regions to each level of the cache based on their cache policy.
 
     Example:
-    ```
+    ```mlir
       xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                                 l2_hint = #xegpu.cache_hint<cached>,
                                 l3_hint = #xegpu.cache_hint<cached>}
@@ -262,7 +271,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
     same time.
 
     Example:
-    ```
+    ```mlir
       xegpu.load_nd %1 {transpose = [1, 0],
                         l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<uncached>,
@@ -308,7 +317,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc
     Corresponding cache hint attribute will be masked.
 
     Example:
-    ```
+    ```mlir
       xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                              l2_hint = #xegpu.cache_hint<write_back>,
                              l3_hint = #xegpu.cache_hint<write_through>}
@@ -406,21 +415,21 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
       elements accessed for each offset, default is 1.
 
     Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
-    ```
+    ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
     ```
 
     Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
                It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
-    ```
+    ```mlir
     %0 = memref.alloc() : memref<1024xf32>
     %1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
     ```
 
     Example 3. It is similar to Example 2, but there is some overlaps among workitems.
                It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
-    ```
+    ```mlir
     %0 = memref.alloc() : memref<1024xf32>
     %1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
     ```
@@ -479,7 +488,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     it works on scattered TensorDesc instead.
 
     Example:
-    ```
+    ```mlir
       xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                              l2_hint = #xegpu.cache_hint<cached>,
                              l3_hint = #xegpu.cache_hint<cached>}
@@ -519,7 +528,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
   Example:
-  ```
+  ```mlir
     %2 = xegpu.load %1, %0 {transpose = [1, 0],
                             l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
@@ -571,7 +580,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDe
   It has similar semantic to `load_gather`.
 
   Example:
-  ```
+  ```mlir
     %3 = xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                                  l2_hint = #xegpu.cache_hint<write_back>,
                                  l3_hint = #xegpu.cache_hint<write_through>}
@@ -620,7 +629,7 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
     shifts for each work-item.
 
     Example:
-    ```
+    ```mlir
       %2 = xegpu.update_offset %1, [32, 32, 32, 32]
             : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
     ```
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 111a270a28b27..9f101a71697b5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -16,7 +16,7 @@ include "mlir/IR/BuiltinTypes.td"
 def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
-def XeGPU_BaseAddrType: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1, 2]>, UI64, UI32, I64, I32]>;
+def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
 def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
 def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 6de42c6992b82..4a2e73dbae6f8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -122,7 +122,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 
 LogicalResult CreateNdDescOp::verify() {
   auto rank = (int64_t)getMixedOffsets().size();
-  bool invalidRank = (rank != 2);
+  bool invalidRank = false;
   bool invalidElemTy = false;
 
   // check source type matches the rank if it is a memref.
@@ -133,17 +133,17 @@ LogicalResult CreateNdDescOp::verify() {
     invalidElemTy |= memrefTy.getElementType() != getElementType();
   }
 
-  // check result type matches the rank
-  invalidRank = (getType().getRank() != rank);
-
   // mismatches among shape, strides, and offsets are
   // already handeled by OffsetSizeAndStrideOpInterface.
   // So they are not check here.
   if (invalidRank)
-    return emitOpError(
-        "Expecting the rank of shape, strides, offsets, "
-        "source memref type (if source is a memref) and TensorDesc "
-        "should match with each other. They currenlty are 2D.");
+    return emitOpError("Expecting the rank of shape, strides, offsets, and source (if source is a memref) should match with each other.");
+
+  // check result TensorDesc rank
+  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
+
+  if (invalidRank)
+    return emitOpError("Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source.");
 
   if (invalidElemTy)
     return emitOpError("TensorDesc should have the same element "
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 4f73e41e55370..35d44cf56a239 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -29,6 +29,13 @@ gpu.func @test_create_nd_tdesc_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_vc_4(%[[arg0:.*]]: memref<2x24x32xf32>) {
+gpu.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -48,6 +55,15 @@ gpu.func @test_load_nd_vc(%src: memref<8x16xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_2(%[[arg0:.*]]: memref<8x16xf16>) {
+gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -59,6 +75,17 @@ gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_vc_2(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ff37f5e1cca17..7ef50bb2b5fad 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -2,7 +2,7 @@
 
 // -----
 func.func @test_create_nd_tdesc_vc_1(%src: memref<24xf32>) {
-  // expected-error at +1 {{Expecting the rank of shape, strides, offsets, source memref type (if source is a memref) and TensorDesc should match with each other. They currenlty are 2D.}}
+  // expected-error at +1 {{Expecting the TensorDesc rank is up to 2 and not greater than the ranks of shape, strides, offsets or the memref source}}
   %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
 }
@@ -52,6 +52,15 @@ func.func @test_load_nd_vc_2(%src: memref<16xf16>) {
   return
 }
 
+// -----
+func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
+  // expected-warning at +1 {{Invalid Packed Attr.}}
+  %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
+  return
+}
+
 // -----
 func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
   %1 = arith.constant dense<1.0>: vector<24x32xf16>



More information about the Mlir-commits mailing list