[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)

Artem Kroviakov llvmlistbot at llvm.org
Wed Jan 29 03:33:41 PST 2025


================
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank > 2)
+    return emitError() << "desc shape rank exceeds 2";
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError() << "outer layout and data mapping must be 1 "
+                              "for 1D tensor";
+    }
+
+    // For 1D tensor, pad the shape with an outer unit dimension to allow common
+    // validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot map " << tensorShape[i]
+                           << " elements into " << wiLayout[i] << " by "
+                           << wiData[i] << " tiles";
+    }
+
+    if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+      auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered elements";
+
+      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+      if (wiData[1] > chunkSize)
----------------
akroviakov wrote:

`wi_data` describes the shape of the tensor slice that a work-item (WI) loads as a single memory chunk.
For `_nd` ops, the number of chunks can indeed be implicit. For example,
`tdesc<<16x16xf16>, #sg_map_a_f16 = xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>`
means that each work-item loads 8 chunks, each of shape [2, 1]: the 16x16 = 256 elements are split over 16 work-items (16 elements each), and each chunk holds 2 elements.
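
The arithmetic can be sketched in C++ as follows. This is only an illustration, assuming a 2D shape and two-element `wi_layout`/`wi_data` arrays; `chunksPerWorkItem` is a hypothetical helper, not an XeGPU API. It reuses the divisibility condition from the verifier in the diff above and then counts the `wi_data`-shaped chunks each work-item receives.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper (not an XeGPU API): given a 2D tensor descriptor shape
// and an sg_map's wi_layout / wi_data, return how many wi_data-shaped chunks
// each work-item loads, or std::nullopt if the shape cannot be distributed.
std::optional<int64_t> chunksPerWorkItem(std::array<int64_t, 2> shape,
                                         std::array<uint32_t, 2> wiLayout,
                                         std::array<uint32_t, 2> wiData) {
  int64_t total = 1;
  for (size_t i = 0; i < 2; ++i) {
    int64_t perStep = int64_t(wiLayout[i]) * wiData[i];
    // Same condition as the verifier: each dimension must be a positive
    // multiple of wi_layout[i] * wi_data[i].
    if (perStep == 0 || shape[i] < perStep || shape[i] % perStep != 0)
      return std::nullopt;
    // Number of chunk repetitions along this dimension for one work-item.
    total *= shape[i] / perStep;
  }
  return total;
}
```

For the example above, `chunksPerWorkItem({16, 16}, {1, 16}, {2, 1})` yields 8 (8 repetitions along the first dimension, 1 along the second).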

For a scattered tdesc, however, this implicit number of chunks is always 1, according to the general tensor shape requirements, which do not change in the presence of `sg_map`:

> The first dimension of the result TensorDesc corresponds to work-items, so it should match the dimension of offsets. It may also has a second dimension corresponding to the chunk_size if the chunk size is larger than 1.
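
A minimal sketch of the quoted shape rule, assuming the number of offsets and `chunk_size` are already known; `expectedScatterShape` is hypothetical and not taken from the dialect's code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an XeGPU API) encoding the quoted rule: the first
// dimension of a scattered TensorDesc equals the number of offsets (one per
// work-item); a second dimension equal to chunk_size appears only when
// chunk_size > 1.
std::vector<int64_t> expectedScatterShape(int64_t numOffsets,
                                          int64_t chunkSize) {
  assert(numOffsets > 0 && chunkSize > 0 && "sizes must be positive");
  if (chunkSize > 1)
    return {numOffsets, chunkSize};
  return {numOffsets};
}
```

Under this reading, each work-item owns exactly one row of the descriptor, so nothing is left to split into further implicit chunks.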

Also, `chunk_size` specifies the number of *contiguous* elements one WI loads (i.e., it implies the WI's slice shape `[1, chunk_size]`). Hence, `wi_data` for a scattered tdesc is required to be `[1, chunk_size]`.
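
To make the consequence concrete, here is a hedged C++ sketch, assuming a scattered tdesc of shape `[num_offsets, chunk_size]` distributed one row per work-item (`wi_layout = [num_offsets, 1]` is an assumption made for illustration). The helpers `scatterChunksPerWorkItem` and `isValidScatterWiData` are hypothetical. With `chunk_size = 8`, `wi_data = [1, 8]` gives one chunk per work-item, whereas `wi_data = [1, 2]` would imply 4 chunks, which the quoted shape rule does not allow. The condition visible in the diff (`wiData[1] > chunkSize`) appears to reject only a `wi_data[1]` larger than `chunk_size`; the hunk is cut off there, so the exact behavior is not shown.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch, assuming a scattered tdesc of shape
// [numOffsets, chunkSize] distributed with wi_layout = [numOffsets, 1],
// i.e. one row per work-item. Returns how many wi_data-shaped chunks each
// work-item would load (0 if wi_data does not tile the row).
int64_t scatterChunksPerWorkItem(int64_t chunkSize, int64_t wiData0,
                                 int64_t wiData1) {
  assert(chunkSize > 0 && wiData0 > 0 && wiData1 > 0);
  // Each work-item owns exactly one row of chunkSize contiguous elements.
  if (wiData0 != 1 || chunkSize % wiData1 != 0)
    return 0;
  return chunkSize / wiData1;
}

// The interpretation argued in this comment: exactly one chunk per
// work-item, which holds only when wi_data == [1, chunk_size].
bool isValidScatterWiData(int64_t chunkSize, int64_t wiData0,
                          int64_t wiData1) {
  return scatterChunksPerWorkItem(chunkSize, wiData0, wiData1) == 1;
}
```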

Am I misinterpreting the documentation?

https://github.com/llvm/llvm-project/pull/124548

