[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)

Jianhui Li llvmlistbot at llvm.org
Wed Jun 11 10:06:07 PDT 2025


================
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto indice : convertedIndiceVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+                                                        op.getSource(), indice);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
----------------
Jianhui-Li wrote:

my understanding is that pack [m, n] to [m/bm, n/bn, bm, bn] so it is 1 to N. unpack does reverse so it is N to 1. 

https://github.com/llvm/llvm-project/pull/143602


More information about the Mlir-commits mailing list