[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)

Charitha Saumya llvmlistbot at llvm.org
Mon Jun 16 17:02:08 PDT 2025


================
@@ -462,10 +492,32 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      if (targetShape && targetShape->size() > 1) {
----------------
charithaintc wrote:

targetShape is already valid; there is no need to check it again.

https://github.com/llvm/llvm-project/pull/144447


More information about the Mlir-commits mailing list