[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Feb 7 06:40:13 PST 2025


================
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank > 2)
+    return emitError() << "desc shape rank exceeds 2";
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError() << "outer layout and data mapping must be 1 "
+                              "for 1D tensor";
+    }
+
+    // For 1D tensor, pad the shape with an outer unit dimension to allow common
+    // validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot map " << tensorShape[i]
+                           << " elements into " << wiLayout[i] << " by "
+                           << wiData[i] << " tiles";
+    }
+
+    if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+      auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered elements";
+
+      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+      if (wiData[1] > chunkSize)
----------------
adam-smnk wrote:

I improved the scattered-tensor verification based on your commit and moved some of these checks into this type verifier.
Please double-check whether the verification now follows the documentation more closely.

https://github.com/llvm/llvm-project/pull/124548


More information about the Mlir-commits mailing list