[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)

Chao Chen llvmlistbot at llvm.org
Fri Jun 13 08:00:31 PDT 2025


================
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
----------------
chencha3 wrote:

Here, the targetShape for indices should drop the last dim if chunkSize != 1. 

https://github.com/llvm/llvm-project/pull/143602


More information about the Mlir-commits mailing list