[Mlir-commits] [mlir] [mlir][xegpu] Relax rank restriction of TensorDescType (PR #145916)
Chao Chen
llvmlistbot at llvm.org
Wed Jul 9 12:47:35 PDT 2025
================
@@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
- auto chunkSize = tdescTy.getChunkSize();
+ auto chunkSize = tdescTy.getChunkSizeAsInt();
if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
<< "Value should have the same element type as TensorDesc.";
- if (tdescShape[0] != maskShape[0])
+ llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
return emitError()
- << "dim-0 of the Mask and TensorDesc should be the same.";
+ << "Mask should match TensorDesc except the chunk size dim.";
----------------
chencha3 wrote:
Thanks, added one for each.
https://github.com/llvm/llvm-project/pull/145916
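A rough illustration of the new mask rule in the hunk above. The hypothetical helper `isMaskShapeValid` is not part of MLIR. It uses `std::vector<int64_t>` in place of MLIR's shape types so it compiles on its own, and `chunkSize` stands in for the value returned by `getChunkSizeAsInt()`. Before the change, only dim-0 of the mask and TensorDesc was compared. The sketch instead compares every dimension after dropping the trailing chunk-size dim when the chunk size is greater than 1.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the mask check in the diff, using
// std::vector instead of MLIR shape types (an assumption so the sketch
// builds without MLIR). chunkSize stands in for getChunkSizeAsInt().
bool isMaskShapeValid(const std::vector<int64_t> &tdescShape,
                      const std::vector<int64_t> &maskShape,
                      int64_t chunkSize) {
  // Start from the full TensorDesc shape.
  std::vector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1) {
    // With chunking, the trailing TensorDesc dim is the chunk size. The
    // mask has one element per offset, so that dim is dropped.
    assert(!expectedMaskShape.empty() && "chunked TensorDesc needs a chunk dim");
    expectedMaskShape.pop_back();
  }
  // All remaining dims must match, not only dim-0.
  return expectedMaskShape == maskShape;
}
```

For example, a `16x8` descriptor with chunk size 8 expects a mask of shape `16`. Assuming the relaxed rank permits a `2x16x4` descriptor with chunk size 4, the expected mask would be `2x16`. A `2x32` mask would then be rejected, although the old dim-0-only check would have accepted it.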