[Mlir-commits] [mlir] 8a03658 - [mlir][xegpu] Tensor descriptor type verifier (#124548)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 7 11:43:08 PST 2025


Author: Adam Siemieniuk
Date: 2025-02-07T20:43:05+01:00
New Revision: 8a03658d575b5cfd65abb5cd4e80d0ee4163fc11

URL: https://github.com/llvm/llvm-project/commit/8a03658d575b5cfd65abb5cd4e80d0ee4163fc11
DIFF: https://github.com/llvm/llvm-project/commit/8a03658d575b5cfd65abb5cd4e80d0ee4163fc11.diff

LOG: [mlir][xegpu] Tensor descriptor type verifier (#124548)

Adds XeGPU tensor descriptor type verifier.

The type verifier covers general tensor descriptor invariants w.r.t. Xe
ISA semantics.
Related operation verifiers are updated to account for the new
descriptor checks and avoid duplication.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/test/Dialect/XeGPU/XeGPUOps.mlir
    mlir/test/Dialect/XeGPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506f..494f11f041b71ff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   }];
 
   let hasCustomAssemblyFormat = true;
-
+  let genVerifyDecl = 1;
 }
 
 

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c606..becc32d1226973d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};
 
-  return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()),
-                             sg_map.value_or(mlir::Attribute()));
+  return TensorDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), shape, elementType,
+      encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,81 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+// Verifies rank, encoding-vs-rank and sg_map rules of a tensor descriptor.
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank != 1 && rank != 2)
+    return emitError() << "expected 1D or 2D tensor";
+
+  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
+  // Elements per scattered row (1 if unset); signed to avoid wrap-around.
+  int64_t chunkSize = 1;
+  if (scatterAttr) {
+    if (IntegerAttr chunkAttr = scatterAttr.getChunkSize())
+      chunkSize = chunkAttr.getInt();
+    // Expected tensor ranks for scattered data:
+    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
+    //   - 2D tensor for scattered blocks (chunk size > 1)
+    if (rank == 1 && chunkSize != 1)
+      return emitError() << "expected non-contiguous elements for 1D tensor";
+    if (rank == 2 && chunkSize < 2)
+      return emitError() << "expected chunk blocks for 2D tensor";
+  }
+
+  if (auto blockAttr =
+          mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError()
+               << "outer layout distribution and data mapping must be 1 "
+                  "for 1D tensor";
+    }
+
+    if (scatterAttr) {
+      // Validate subgroup mapping rules for scattered tensors.
+      // A work-item's slice of the tensor with shape [sg_size] or
+      // [sg_size, chunk_size] will be [1] or [1, chunk_size] respectively,
+      // the mapping should reflect that.
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered row elements";
+      if (wiData[1] != chunkSize)
+        return emitError() << "work item data mapping must match the number of "
+                              "contiguous elements";
+    }
+
+    // For 1D tensor, pad the shape with an outer unit dimension to allow common
+    // validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    for (size_t i = 0, dims = tensorShape.size(); i < dims; ++i) {
+      // 64-bit product cannot wrap; a zero factor must not reach the modulo.
+      uint64_t numElemPerWi = static_cast<uint64_t>(wiLayout[i]) * wiData[i];
+      if (numElemPerWi == 0 || tensorShape[i] <= 0 ||
+          static_cast<uint64_t>(tensorShape[i]) % numElemPerWi != 0)
+        return emitError() << "cannot distribute " << tensorShape[i] << " over "
+                           << wiLayout[i] << " work items with " << wiData[i]
+                           << " elements each";
+    }
+  }
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index cd883baa986b85e..e06d99ac20bb736 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 //   each dimension.
 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  if (descShape == valShape) {
-    if (!sgMap)
-      return true;
-
-    // this can be relaxed if necessary by supporting non-2d shapes distribution
-    // until the constraints are defined this lives here instead of the tensor
-    // descriptor type.
-    return valShape.size() == sgMap.getWiLayout().size();
-  }
+  // Equal shapes with no distribution - no further verification needed.
+  if (descShape == valShape && !sgMap)
+    return true;
 
+  // Unknown distribution - cannot perform operation on partial shape.
   if (!sgMap)
     return false;
 
-  if (valShape.size() != descShape.size())
+  // Invalid rank or mixed rank usage.
+  size_t descRank = descShape.size();
+  if (descRank > 2 || valShape.size() != descRank)
     return false;
 
+  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+  // Only take the distribution over the innermost dimension for validation.
+  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+  if (descRank == 1)
+    mapLayout = {wiLayout.back()};
+
   for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+       llvm::zip_equal(mapLayout, valShape, descShape)) {
     if (factor * dim != expected)
       return false;
   }
@@ -227,10 +231,6 @@ LogicalResult CreateNdDescOp::verify() {
   if (getType().isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
-  if (getType().getRank() == 2 &&
-      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
-    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");
-
   return success();
 }
 
@@ -454,22 +454,7 @@ LogicalResult CreateDescOp::verify() {
   if (shape != tdescShape)
     return emitOpError("Incorrect TensorDesc shape. ")
            << "Expected is " << makeString(shape) << "\n";
-  if (auto sgMap = tdescTy.getSGMapAttr()) {
-    // A work-item's slice of the TensorDesc with shape [sg_size] or
-    // [sg_size, chunk_size] will be [1] or [1, chunks_size] respectively,
-    // the mapping should reflect that.
-    if (sgMap.getWiData()[0] > 1)
-      return emitOpError("TensorDesc's SG map only supports multiple elements "
-                         "contiguous along rows.");
-    if (chunkSize != static_cast<int>(sgMap.getWiData()[1]))
-      return emitOpError(
-          "TensorDesc's chunkSize must match WI's data mapping.");
-    if (int rank = tdescTy.getRank();
-        (sgMap.getWiLayout()[2 - rank] != tdescShape[0]))
-      return emitOpError("Detected a conflict between SG map's work-item "
-                         "layout and TensorDesc shape. Check the index of "
-                         "`subgroup_size` in WI layout map.");
-  }
+
   return success();
 }
 

diff  --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index dcd6b01974cf306..8af1b600ad0a4e2 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) { // 1D tdesc: 32 elements over 16 work items -> vector<2xf32> each
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) { // 1D tdesc: 32 elements over 16 work items -> vector<2xf16> each
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

diff  --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 201f72120cf2c5d..9162e0012f6d56d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -17,7 +17,7 @@ func.func @test_create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
 
 // -----
 func.func @test_create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
-  // expected-error@+1 {{SLM is not supported for 2D Block TensorDesc}}
+  // expected-error@+1 {{SLM is not supported for 2D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
   return
 }
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error at +1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8x2xf32>
   return
 }
 
 // -----
 func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8xf32>
+  return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
   return
 }
 
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
   return
 }
 
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
@@ -137,8 +183,8 @@ func.func @test_create_tdesc_vc_1(%src: ui64) {
 // -----
 func.func @test_create_tdesc_vc_2(%src: ui64) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
-  // expected-error@+1 {{Incorrect TensorDesc shape}}
   %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
+  // expected-error@+1 {{expected chunk blocks for 2D tensor}}
           -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
   return
 }
@@ -173,7 +219,7 @@ func.func @test_prefetch_vc_2(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_1(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{Detected a conflict between SG map's work-item layout and TensorDesc shape. Check the index of `subgroup_size` in WI layout map}}
+  // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
   return
 }
@@ -181,7 +227,7 @@ func.func @test_create_tdesc_sg_map_1(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_2(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{TensorDesc's SG map only supports multiple elements contiguous along rows}}
+  // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [1, 4], wi_data = [2, 1]>>
   return
 }
@@ -189,7 +235,7 @@ func.func @test_create_tdesc_sg_map_2(%src: ui64) {
 // -----
 func.func @test_create_tdesc_sg_map_3(%src: ui64) {
   %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{TensorDesc's chunkSize must match WI's data mapping}}
+  // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
   %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x3xf32, #xegpu.scatter_tdesc_attr<chunk_size = 3>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>
   return
 }
@@ -315,4 +361,109 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
   // expected-error at +1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{expected 1D or 2D tensor}}
+      !xegpu.tensor_desc<16x2x2xf32>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_rank_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{expected 1D or 2D tensor}}
+      !xegpu.tensor_desc<f32>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout distribution and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 8 over 16 work items with 1 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 4 over 2 work items with 4 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+      // expected-error@+1 {{cannot map over non-contiguous scattered row elements}}
+      !xegpu.tensor_desc<4x2xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+        #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{work item data mapping must match the number of contiguous elements}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_1D(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{expected non-contiguous elements for 1D tensor}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+        #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_chunk_size_2D(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{expected chunk blocks for 2D tensor}}
+      !xegpu.tensor_desc<16x2xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.sg_map<wi_layout = [8, 1], wi_data = [1, 2]>>
+  return
+}


        


More information about the Mlir-commits mailing list