[Mlir-commits] [mlir] [MLIR][XeGPU] Updates XeGPU TensorDescAttr and Refine Gather/Scatter definition. (PR #109144)

Adam Siemieniuk llvmlistbot at llvm.org
Fri Sep 20 09:27:27 PDT 2024


================
@@ -293,28 +308,55 @@ LogicalResult UpdateNdOffsetOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateDescOp
 //===----------------------------------------------------------------------===//
-void CreateDescOp::build(OpBuilder &builder, OperationState &state,
-                         TensorDescType TensorDesc, Value source,
-                         llvm::ArrayRef<OpFoldResult> offsets,
-                         uint32_t chunk_size) {
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<Value> dynamicOffsets;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
-        chunk_size);
-}
 
 LogicalResult CreateDescOp::verify() {
   auto tdescTy = getTensorDescType();
-  auto chunkSize = getChunkSize();
 
   if (getRankOf(getSource()) > 1)
     return emitOpError(
         "Expecting the source is a 1D memref or pointer (uint64_t).");
 
-  if (!tdescTy.getScattered())
+  if (!tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
+  // Memory space of created TensorDesc should match with the source.
+  // Both source and TensorDesc are considered for global memory by default,
+  // if the memory scope attr is not specified. If source is an integer,
+  // it is considered as ptr to global memory.
+  auto srcMemorySpace = getSourceMemorySpace();
+  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
+  if (srcMemorySpace != tdescMemorySpace)
+    return emitOpError("Memory space mismatch.")
+           << " Source: " << srcMemorySpace
+           << ", TensorDesc: " << tdescMemorySpace;
+
+  auto chunkSize = tdescTy.getChunkSize();
+
+  // check chunk_size
+  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
+                                                    16, 32, 64, 128, 256};
+  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
+                       "8, 16, 32, 64, 128, or 256.");
+
+  // check total size
+  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
+  auto bitsPerLane = elemBits * chunkSize;
+  if (chunkSize > 1 && bitsPerLane % 32) {
+    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
+    // For 32-bit data, the hardware can support larger larger chunk size. So
----------------
adam-smnk wrote:

nit: typo — "larger larger chunk size" should be "larger chunk size".

https://github.com/llvm/llvm-project/pull/109144


More information about the Mlir-commits mailing list