[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)

Jianhui Li llvmlistbot at llvm.org
Tue Jun 17 14:43:19 PDT 2025


================
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();
 
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
     SmallVector<Type> convertedIndiceTypes =
-        getUnrolledTypes(indiceVecTy, *targetShape);
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
     SmallVector<Value> convertedIndiceVec =
-        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
 
     SmallVector<Value> newOps;
-    for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                        op.getSource(), indice);
-      newOps.push_back(newOp);
+
+    // More indices are needed when chunkSize > 1, since a big load from one
+    // address could be broken into multiple small loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          // Compute the offset
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
----------------
Jianhui-Li wrote:

added

https://github.com/llvm/llvm-project/pull/144447


More information about the Mlir-commits mailing list