[Mlir-commits] [mlir] [mlir][xegpu] Relax rank restriction of TensorDescType (PR #145916)

Chao Chen llvmlistbot at llvm.org
Wed Jul 9 12:47:35 PDT 2025


================
@@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   auto maskShape = getShapeOf(maskTy);
   auto valueShape = getShapeOf(valueTy);
   auto tdescShape = getShapeOf(tdescTy);
-  auto chunkSize = tdescTy.getChunkSize();
+  auto chunkSize = tdescTy.getChunkSizeAsInt();
 
   if (valueTy.getElementType() != tdescTy.getElementType())
     return emitError()
            << "Value should have the same element type as TensorDesc.";
 
-  if (tdescShape[0] != maskShape[0])
+  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
     return emitError()
-           << "dim-0 of the Mask and TensorDesc should be the same.";
+           << "Mask should match TensorDesc except the chunk size dim.";
----------------
chencha3 wrote:

Thanks, added one for each. 
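
For reference, the relaxed mask check in the hunk above can be sketched as standalone C++. This is only an illustration under stated assumptions: it uses `std::vector` in place of `llvm::SmallVector`, and the helper names `expectedMaskShape` and `maskMatches` are hypothetical, not functions from the PR. The empty-shape guard is an extra precaution that the hunk itself does not contain.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not in the PR): the mask shape the verifier would
// accept for a given TensorDesc shape and chunk size.
std::vector<int64_t> expectedMaskShape(std::vector<int64_t> tdescShape,
                                       int64_t chunkSize) {
  // With chunkSize > 1 the innermost dim holds the chunk, so the mask does
  // not cover it. The empty() guard is a defensive addition for this sketch.
  if (chunkSize > 1 && !tdescShape.empty())
    tdescShape.pop_back();
  return tdescShape;
}

// Hypothetical helper (not in the PR): true when the mask shape equals the
// TensorDesc shape with the chunk-size dim (if any) dropped.
bool maskMatches(const std::vector<int64_t> &maskShape,
                 const std::vector<int64_t> &tdescShape, int64_t chunkSize) {
  return expectedMaskShape(tdescShape, chunkSize) == maskShape;
}
```

Under these assumptions, a `4x16x8` TensorDesc with chunk size 8 expects a `4x16` mask, while a `16` TensorDesc with chunk size 1 expects a `16` mask.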

https://github.com/llvm/llvm-project/pull/145916