[Mlir-commits] [mlir] [mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. (PR #127920)
Chao Chen
llvmlistbot at llvm.org
Thu Feb 20 08:29:42 PST 2025
================
@@ -307,6 +310,85 @@ LogicalResult TensorDescType::verify(
return success();
}
+// If a tensor descriptor has a sg_map attribute, it is used in SIMT mode.
+// In this mode, the distributed vector shape is determined as follows:
+// Definitions:
+// wi_data_size = wi_data[0] × wi_data[1]
+// subgroup_size = wi_layout[0] × wi_layout[1]
+// distribution_unit_size = subgroup_size × wi_data_size
+// ---------------------------------------------------------------------
+// Case 1: Regular loads/stores.
+// ---------------------------------------------------------------------
+// Distributed vector shape must be:
+// [chunk_size / wi_data_size, wi_data_size]
+// If the tensor descriptor shape is 1D, the first dimension is ignored (set to 1):
+// [wi_data_size]
+// ---------------------------------------------------------------------
+// Case 2: Block loads/stores
+// ---------------------------------------------------------------------
+// Additional definitions:
+// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
+// n_distribution_units = tensor_size / distribution_unit_size
+// Given above definitions, the following conditions must be met:
+// * tensor_desc[0] % (wi_layout[0] × wi_data[0]) == 0
+// * tensor_desc[1] % (wi_layout[1] × wi_data[1]) == 0
+// Distributed vector shape must be:
+// [n_distribution_units, wi_data_size]
+FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
+ auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+ // If no sg_map is provided, tensor desc is not used in SIMT mode.
+ if (!sgMap)
+ return failure();
+
+ SmallVector<int64_t> wiData(sgMap.getWiData());
+ SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
+ auto tdescShape = getShape();
+
+ auto wiDataSize = 1, sgSize = 1;
+ for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
+ wiDataSize *= wiDataDim;
+ sgSize *= wiDim;
+ }
+
+ // Case 1: regular loads/stores
+ auto scatterAttr =
+ llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+ if (scatterAttr) {
+ auto chunkSize = scatterAttr.getChunkSize().getInt();
+ // Check if the first dimension of the tensor descriptor shape is
+ // distributable.
+ if (tdescShape[0] % (wiLayout[0]) != 0)
+ return failure();
+ if (chunkSize > 1)
+ return VectorType::get({chunkSize / wiDataSize, wiDataSize},
+ getElementType());
+ return VectorType::get({wiDataSize}, getElementType());
+ }
+
+ // Case 2: block loads/stores
+ // Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
+ // and wiLayout must be 1.
+ if (tdescShape.size() == 1) {
+ if (wiData[0] != 1 || wiLayout[0] != 1)
+ return failure();
+ wiData = {wiData[1]};
+ wiLayout = {wiLayout[1]};
+ }
+ // Check if the tensor descriptor shape is distributable.
+ int64_t tensorSize = 1;
+ for (auto [tdescDim, wiDim, wiDataDim] :
+ llvm::zip_equal(tdescShape, wiLayout, wiData)) {
+ if (tdescDim % (wiDim * wiDataDim) != 0)
----------------
chencha3 wrote:
Same here.
https://github.com/llvm/llvm-project/pull/127920
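
[Editor's note] For readers working through the quoted comment block: the scattered (Case 1)
path is plain per-lane arithmetic, e.g. with chunk_size = 8 and wi_data = [1, 1] we get
wi_data_size = 1 and a distributed shape of [8, 1]. Below is a small standalone C++ sketch
(not part of the patch; the helper name distributeBlockShape and the 2D-shape assumption are
illustrative only) that mirrors the block load/store rule (Case 2) described above.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical standalone helper mirroring the block load/store rule
    // (Case 2): returns the distributed per-work-item vector shape
    // [n_distribution_units, wi_data_size], or an empty vector if the
    // descriptor is not distributable. Assumes 2D tdescShape/wiLayout/wiData.
    std::vector<int64_t>
    distributeBlockShape(const std::vector<int64_t> &tdescShape,
                         const std::vector<int64_t> &wiLayout,
                         const std::vector<int64_t> &wiData,
                         int64_t arrayLength = 1) {
      int64_t wiDataSize = wiData[0] * wiData[1];     // wi_data_size
      int64_t sgSize = wiLayout[0] * wiLayout[1];     // subgroup_size
      int64_t distUnitSize = sgSize * wiDataSize;     // distribution_unit_size

      // tensor_desc[i] % (wi_layout[i] * wi_data[i]) must be 0 for each dim.
      for (size_t i = 0; i < tdescShape.size(); ++i)
        if (tdescShape[i] % (wiLayout[i] * wiData[i]) != 0)
          return {};

      int64_t tensorSize = arrayLength;               // tensor_size
      for (int64_t dim : tdescShape)
        tensorSize *= dim;

      return {tensorSize / distUnitSize, wiDataSize};
    }

    int main() {
      // An 8x16 descriptor with wi_layout = [1, 16] and wi_data = [1, 1]:
      // subgroup_size = 16, wi_data_size = 1, tensor_size = 128, so each
      // work item holds an [8, 1] vector.
      auto shape = distributeBlockShape({8, 16}, {1, 16}, {1, 1});
      assert(shape.size() == 2 && shape[0] == 8 && shape[1] == 1);
      (void)shape;
      return 0;
    }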