[Mlir-commits] [mlir] [mlir][xegpu] Tensor descriptor type verifier (PR #124548)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Feb 7 06:40:13 PST 2025
================
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, sg_map);
}
+LogicalResult TensorDescType::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
+ size_t rank = shape.size();
+ if (rank > 2)
+ return emitError() << "desc shape rank exceeds 2";
+
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+ ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+ ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+ if (rank == 1) {
+ if (wiLayout[0] != 1 || wiData[0] != 1)
+ return emitError() << "outer layout and data mapping must be 1 "
+ "for 1D tensor";
+ }
+
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
+ // validation logic.
+ SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+ if (rank == 1)
+ tensorShape = {1, tensorShape.back()};
+
+ size_t dims = tensorShape.size();
+ for (size_t i = 0; i < dims; ++i) {
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+ return emitError() << "cannot map " << tensorShape[i]
+ << " elements into " << wiLayout[i] << " by "
+ << wiData[i] << " tiles";
+ }
+
+ if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+ if (wiData[0] != 1)
+ return emitError()
+ << "cannot map over non-contiguous scattered elements";
+
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ if (wiData[1] > chunkSize)
----------------
adam-smnk wrote:
I improved the scattered verification based on your commit and moved some of these checks into this type verifier.
Please double-check whether the verification now follows the documentation better.
https://github.com/llvm/llvm-project/pull/124548
More information about the Mlir-commits
mailing list