[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Feb 7 06:36:51 PST 2025


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/124548

>From d23c97d170efbc1cf9ec5a2e4b10f057d88edb23 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 23 Jan 2025 10:24:11 +0100
Subject: [PATCH 1/8] [mlir][xegpu] TensorDesc verifier

Adds an XeGPU tensor descriptor type verifier.

The checks focus on ensuring that the provided subgroup map is valid
with respect to the underlying data.
---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 56 ++++++++++++++++++-
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506f..494f11f041b71ff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   }];
 
   let hasCustomAssemblyFormat = true;
-
+  let genVerifyDecl = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c606..42d59da2f7a922a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};
 
-  return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()),
-                             sg_map.value_or(mlir::Attribute()));
+  return TensorDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), shape, elementType,
+      encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank > 2)
+    return emitError() << "desc shape rank exceeds 2";
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError() << "outer layout and data mapping must be 1 "
+                              "for 1D tensor";
+    }
+
+    // For 1D tensor, pad the shape with an outer unit dimension to allow common
+    // validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot map " << tensorShape[i]
+                           << " elements into " << wiLayout[i] << " by "
+                           << wiData[i] << " tiles";
+    }
+
+    if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+      auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered elements";
+
+      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+      if (wiData[1] > chunkSize)
+        return emitError()
+               << "too few contiguous elements for work item mapping";
+    }
+  }
+
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
 

>From 5e3e7a80e907c0e4160cc00b76cd3f3f6d6c5e56 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 24 Jan 2025 17:25:29 +0100
Subject: [PATCH 2/8] Test cases

---
 mlir/test/Dialect/XeGPU/invalid.mlir | 79 +++++++++++++++++++++++++++-
 1 file changed, 78 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 201f72120cf2c5d..975b4aea84fe250 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -315,4 +315,81 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
   // expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{desc shape rank exceeds 2}}
+      !xegpu.tensor_desc<16x2x2xf32>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+      // expected-error@+1 {{cannot map over non-contiguous scattered elements}}
+      !xegpu.tensor_desc<4x2xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+        #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{too few contiguous elements for work item mapping}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+  return
+}

>From 7aba5a69eca63bd4f0924f48df4bbe0f3c5a8439 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 27 Jan 2025 14:40:36 +0100
Subject: [PATCH 3/8] Update load/store verifier + tests

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 26 +++++++------
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir  | 22 +++++++++++
 mlir/test/Dialect/XeGPU/invalid.mlir   | 52 ++++++++++++++++++++++++--
 3 files changed, 86 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index cd883baa986b85e..996fb3638203311 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 //   each dimension.
 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  if (descShape == valShape) {
-    if (!sgMap)
-      return true;
-
-    // this can be relaxed if necessary by supporting non-2d shapes distribution
-    // until the constraints are defined this lives here instead of the tensor
-    // descriptor type.
-    return valShape.size() == sgMap.getWiLayout().size();
-  }
+  // Equal shapes with no distribution - no further verification needed.
+  if (descShape == valShape && !sgMap)
+    return true;
 
+  // Unknown distribution - cannot perform operation on partial shape.
   if (!sgMap)
     return false;
 
-  if (valShape.size() != descShape.size())
+  // Invalid rank or mixed rank usage.
+  size_t descRank = descShape.size();
+  if (descRank > 2 || valShape.size() != descRank)
     return false;
 
+  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+  // Only take the distribution over the innermost dimension for validation.
+  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+  if (descRank == 1)
+    mapLayout = {wiLayout.back()};
+
   for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+       llvm::zip_equal(mapLayout, valShape, descShape)) {
     if (factor * dim != expected)
       return false;
   }
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index dcd6b01974cf306..8af1b600ad0a4e2 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 975b4aea84fe250..5ed93db2be502dd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8x2xf32>
   return
 }
 
 // -----
 func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8xf32>
+  return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
   return
 }
 
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
   return
 }
 
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>

>From fa82d893c6b8ee3ab9ebaa5fd97bf00553f10504 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 27 Jan 2025 15:16:33 +0100
Subject: [PATCH 4/8] Refactor

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 42d59da2f7a922a..ef0ea38027c4503 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -257,7 +257,7 @@ LogicalResult TensorDescType::verify(
                            << wiData[i] << " tiles";
     }
 
-    if (llvm::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+    if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
       auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
       if (wiData[0] != 1)
         return emitError()

>From 5df9cf0089676d2650effacba226ecfdf4fdd523 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 13:53:20 +0100
Subject: [PATCH 5/8] Improve scattered verification

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 56 ++++++++++++++--------
 mlir/test/Dialect/XeGPU/invalid.mlir       | 34 ++++++++-----
 2 files changed, 58 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ef0ea38027c4503..077a924dfad266a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -229,8 +229,24 @@ LogicalResult TensorDescType::verify(
     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
     mlir::Attribute encoding, mlir::Attribute sg_map) {
   size_t rank = shape.size();
-  if (rank > 2)
-    return emitError() << "desc shape rank exceeds 2";
+  if (rank != 1 && rank != 2)
+    return emitError() << "expected 1D or 2D tensor";
+
+  // Scattered attribute imposes extra restriction on tensor descriptor.
+  // Block attribute can only be validated further against data transfer
+  // operations.
+  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
+  if (scatterAttr) {
+    // Expected tensor ranks for scattered data:
+    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
+    //   - 2D tensor for scattered blocks (chunk size > 1)
+    IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+    unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+    if (rank == 1 && chunkSize != 1)
+      return emitError() << "expected non-contiguous elements for 1D tensor";
+    if (rank == 2 && chunkSize < 2)
+      return emitError() << "expected chunk blocks for 2D tensor";
+  }
 
   if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
     ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
@@ -238,8 +254,22 @@ LogicalResult TensorDescType::verify(
 
     if (rank == 1) {
       if (wiLayout[0] != 1 || wiData[0] != 1)
-        return emitError() << "outer layout and data mapping must be 1 "
-                              "for 1D tensor";
+        return emitError()
+               << "outer layout distribution and data mapping must be 1 "
+                  "for 1D tensor";
+    }
+
+    if (scatterAttr) {
+      // Validate subgroup mapping rules for scattered tensors.
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered row elements";
+
+      IntegerAttr chunkAttr = scatterAttr.getChunkSize();
+      unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1;
+      if (wiData[1] != chunkSize)
+        return emitError() << "work item data mapping must match the number of "
+                              "contiguous elements";
     }
 
     // For 1D tensor, pad the shape with an outer unit dimension to allow common
@@ -252,21 +282,9 @@ LogicalResult TensorDescType::verify(
     for (size_t i = 0; i < dims; ++i) {
       uint32_t numElemPerWi = wiLayout[i] * wiData[i];
       if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
-        return emitError() << "cannot map " << tensorShape[i]
-                           << " elements into " << wiLayout[i] << " by "
-                           << wiData[i] << " tiles";
-    }
-
-    if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
-      auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
-      if (wiData[0] != 1)
-        return emitError()
-               << "cannot map over non-contiguous scattered elements";
-
-      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
-      if (wiData[1] > chunkSize)
-        return emitError()
-               << "too few contiguous elements for work item mapping";
+        return emitError() << "cannot distribute " << tensorShape[i] << " over "
+                           << wiLayout[i] << " work items with " << wiData[i]
+                           << " elements each";
     }
   }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 5ed93db2be502dd..733eb1559d6fb14 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -183,8 +183,8 @@ func.func @test_create_tdesc_vc_1(%src: ui64) {
 // -----
 func.func @test_create_tdesc_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
-  // expected-error@+1 {{Incorrect TensorDesc shape}}
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
+  // expected-error@+1 {{expected chunk blocks for 2D tensor}}
           -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
   return
 }
@@ -219,7 +219,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
+  // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
   return
 }
@@ -227,7 +227,7 @@ func.func @test_create_tdesc_sg_map_1(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_2(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
+  // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
   return
 }
@@ -235,7 +235,7 @@ func.func @test_create_tdesc_sg_map_2(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_3(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}}
+  // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
   return
 }
@@ -366,15 +366,23 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
 // -----
 func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{desc shape rank exceeds 2}}
+      // expected-error@+1 {{expected 1D or 2D tensor}}
       !xegpu.tensor_desc<16x2x2xf32>
   return
 }
 
+// -----
+func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{expected 1D or 2D tensor}}
+      !xegpu.tensor_desc<f32>
+  return
+}
+
 // -----
 func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+      // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
       !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
   return
 }
@@ -382,7 +390,7 @@ func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+      // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
       !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
   return
 }
@@ -390,7 +398,7 @@ func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
+      // expected-error@+1 {{cannot distribute 8 over 16 work items with 1 elements each}}
       !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   return
 }
@@ -398,7 +406,7 @@ func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+      // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
       !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
   return
 }
@@ -406,7 +414,7 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
+      // expected-error@+1 {{cannot distribute 4 over 2 work items with 4 elements each}}
       !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
   return
 }
@@ -414,7 +422,7 @@ func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+      // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
       !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
   return
 }
@@ -423,7 +431,7 @@ func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
 func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
-      // expected-error@+1 {{cannot map over non-contiguous scattered elements}}
+      // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
       !xegpu.tensor_desc<4x2xf32,
         #xegpu.scatter_tdesc_attr<chunk_size = 2>,
         #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
@@ -433,7 +441,7 @@ func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
 // -----
 func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
   %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
-      // expected-error@+1 {{too few contiguous elements for work item mapping}}
+      // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
       !xegpu.tensor_desc<16xf32,
         #xegpu.scatter_tdesc_attr<chunk_size = 1>,
         #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>

>From f502c66de3e230a364ead78f3d3404f7368ca10d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 14:01:52 +0100
Subject: [PATCH 6/8] Remove TensorDesc invariant checks from op verifier

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp |  3 +++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp     | 17 +----------------
 2 files changed, 4 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 077a924dfad266a..0f17c6a8fa98d63 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -261,6 +261,9 @@ LogicalResult TensorDescType::verify(
 
     if (scatterAttr) {
       // Validate subgroup mapping rules for scattered tensors.
+      // A work-item's slice of the tensor with shape [sg_size] or
+      // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively;
+      // the mapping should reflect that.
       if (wiData[0] != 1)
         return emitError()
                << "cannot map over non-contiguous scattered row elements";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 996fb3638203311..476689fae4e254f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -458,22 +458,7 @@ LogicalResult CreateDescOp::verify() {
   if (shape != tdescShape)
     return emitOpError("Incorrect TensorDesc shape. ")
            << "Expected is " << makeString(shape) << "\n";
-  if (auto sgMap = tdescTy.getSGMapAttr()) {
-    // A work-item's slice of the TensorDesc with shape [sg_size] or
-    // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
-    // the mapping should reflect that.
-    if (sgMap.getWiData()[0] > 1)
-      return emitOpError("TensorDesc's SG map only supports multiple elements "
-                         "contiguous along rows.");
-    if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
-      return emitOpError(
-          "TensorDesc's chunkSize must match WI's data mapping.");
-    if (int rank = tdescTy.getRank();
-        (sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
-      return emitOpError("Detected a conflict between SG map's work-item "
-                         "layout and TensorDesc shape. Check the index of "
-                         "`subgroup_size` in WI layout map.");
-  }
+
   return success();
 }
 

>From 50c628323506d2af53cfa8ea62259485cadc348c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 14:20:45 +0100
Subject: [PATCH 7/8] Add more chunk_size test cases

---
 mlir/test/Dialect/XeGPU/invalid.mlir | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 733eb1559d6fb14..48e8c2808abdae6 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -447,3 +447,23 @@ func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<1
         #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
   return
 }
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{expected non-contiguous elements for 1D tensor}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+        #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{expected chunk blocks for 2D tensor}}
+      !xegpu.tensor_desc<16x2xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.sg_map<wi_layout = [8, 1], wi_data = [1, 2]>>
+  return
+}

>From a72c12f8f05f8f7456e442f40882f809ac10cdd1 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 7 Feb 2025 14:59:55 +0100
Subject: [PATCH 8/8] Move memory space check to type verifier

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 11 ++++++++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp     |  4 ----
 mlir/test/Dialect/XeGPU/invalid.mlir       |  2 +-
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0f17c6a8fa98d63..becc32d1226973d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -232,9 +232,6 @@ LogicalResult TensorDescType::verify(
   if (rank != 1 && rank != 2)
     return emitError() << "expected 1D or 2D tensor";
 
-  // Scattered attribute imposes extra restriction on tensor descriptor.
-  // Block attribute can only be validated further against data transfer
-  // operations.
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
     // Expected tensor ranks for scattered data:
@@ -248,6 +245,14 @@ LogicalResult TensorDescType::verify(
       return emitError() << "expected chunk blocks for 2D tensor";
   }
 
+  if (auto blockAttr =
+          mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }
+
   if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
     ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
     ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 476689fae4e254f..e06d99ac20bb736 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -231,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() {
   if (getType().isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
-  if (getType().getRank() == 2 &&
-      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
-    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
-
   return success();
 }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 48e8c2808abdae6..9162e0012f6d56d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
 
 // -----
 func.func @test_create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
-  // expected-error@+1 {{SLM is not supported for 2D Block TensorDesc}}
+  // expected-error@+1 {{SLM is not supported for 2D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
   return
 }



More information about the Mlir-commits mailing list