[Mlir-commits] [mlir] [MLIR][XeGPU] Updates XeGPU TensorDescAttr and Refine Gather/Scatter definition. (PR #109144)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Sep 20 09:27:27 PDT 2024
================
@@ -293,28 +308,55 @@ LogicalResult UpdateNdOffsetOp::verify() {
//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//
-void CreateDescOp::build(OpBuilder &builder, OperationState &state,
- TensorDescType TensorDesc, Value source,
- llvm::ArrayRef<OpFoldResult> offsets,
- uint32_t chunk_size) {
- llvm::SmallVector<int64_t> staticOffsets;
- llvm::SmallVector<Value> dynamicOffsets;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
- chunk_size);
-}
LogicalResult CreateDescOp::verify() {
auto tdescTy = getTensorDescType();
- auto chunkSize = getChunkSize();
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
- if (!tdescTy.getScattered())
+ if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
+ // Memory space of created TensorDesc should match with the source.
+ // Both source and TensorDesc are considered for global memory by default,
+ // if the memory scope attr is not specified. If source is an integer,
+ // it is considered as ptr to global memory.
+ auto srcMemorySpace = getSourceMemorySpace();
+ auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
+ if (srcMemorySpace != tdescMemorySpace)
+ return emitOpError("Memory space mismatch.")
+ << " Source: " << srcMemorySpace
+ << ", TensorDesc: " << tdescMemorySpace;
+
+ auto chunkSize = tdescTy.getChunkSize();
+
+ // check chunk_size
+ llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
+ 16, 32, 64, 128, 256};
+ if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+ return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
+ "8, 16, 32, 64, 128, or 256.");
+
+ // check total size
+ auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
+ auto bitsPerLane = elemBits * chunkSize;
+ if (chunkSize > 1 && bitsPerLane % 32) {
+ // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
+ // For 32-bit data, the hardware can support larger larger chunk size. So
----------------
adam-smnk wrote:
nit: typo — "larger larger chunk size" in the comment above should read "larger chunk size".
https://github.com/llvm/llvm-project/pull/109144
More information about the Mlir-commits
mailing list