[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)

Jianhui Li llvmlistbot at llvm.org
Tue Jun 17 14:46:33 PDT 2025


================
@@ -519,30 +570,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
----------------
Jianhui-Li wrote:

The similarity is limited to a very small scope, so introducing an additional abstraction may not be worth it.

https://github.com/llvm/llvm-project/pull/144447


More information about the Mlir-commits mailing list