[Mlir-commits] [mlir] [XeGPU] Add sg_map for scatter verification (PR #124300)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Jan 30 04:44:11 PST 2025
================
@@ -551,10 +574,21 @@ LogicalResult StoreScatterOp::verify() {
if (tdescTy.getRank() == 2) {
if (!getTransposeAttr())
- return emitOpError("load_gather has to be transposed.");
+ return emitOpError("Store of a rank-2 tensor has to be transposed.");
transpose({1, 0}, tdescShape);
}
+ if (auto sgMap = tdescTy.getSGMapAttr()) {
+ auto valueVecTy = cast<VectorType>(valueTy);
+ const int32_t wiData =
+ sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
+ if (valueVecTy.getNumElements() != wiData ||
+ valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
+ return emitOpError("Chunk size, vector size and wi_data must match.");
+ }
+ tdescShape[tdescTy.getRank() - 1] = 1;
----------------
adam-smnk wrote:
Could you please add some comments?
I can follow the verification logic with the Xe docs in mind, but it won't be very intuitive once I've forgotten all the details in a week.
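The checks in the diff above can be modeled without MLIR. This is a minimal sketch, not the PR's code. The function name `checkScatterSgMap` and its parameters are hypothetical stand-ins for `getTransposeAttr()`, `sgMap.getWiData()`, `tdescTy.getChunkSize()` and `valueVecTy.getNumElements()`. The excerpt ends right after `tdescShape` is adjusted, so the sketch stops there too and does not model what the shape is later compared against.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical MLIR-free model of the checks in the StoreScatterOp::verify() excerpt.
//   tdescShape  - tensor descriptor shape, adjusted in place as in the excerpt
//   transposed  - whether a transpose attribute is present
//   wiData      - sg_map wi_data, or std::nullopt when no sg_map is attached
//   chunkSize   - tensor descriptor chunk size
//   numElements - number of elements in the stored vector value
// Returns false wherever the real verifier would emit an error.
bool checkScatterSgMap(std::vector<int64_t> &tdescShape, bool transposed,
                       std::optional<std::array<int32_t, 2>> wiData,
                       int64_t chunkSize, int64_t numElements) {
  // Assumption: scatter descriptors are rank 1, or rank 2 when chunked.
  assert(tdescShape.size() == 1 || tdescShape.size() == 2);

  // A rank-2 store must carry a transpose. Swap the two dims,
  // like transpose({1, 0}, tdescShape) in the excerpt.
  if (tdescShape.size() == 2) {
    if (!transposed)
      return false;
    std::swap(tdescShape[0], tdescShape[1]);
  }

  if (wiData) {
    // Use wi_data[0] if it is greater than 1, otherwise wi_data[1].
    const int32_t wi = (*wiData)[0] > 1 ? (*wiData)[0] : (*wiData)[1];
    // Vector size, wi_data and chunk size must all be equal.
    if (numElements != wi || numElements != chunkSize)
      return false;
    // Collapse the innermost dim to 1, as tdescShape[rank - 1] = 1 does.
    tdescShape.back() = 1;
  }
  return true;
}
```

The rank-2 branch is kept separate from the sg_map branch to mirror the excerpt, where the transpose requirement applies whether or not an sg_map is attached.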
https://github.com/llvm/llvm-project/pull/124300