[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)
Jianhui Li
llvmlistbot at llvm.org
Wed Jun 11 10:06:07 PDT 2025
================
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
}
};
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+ using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+ VectorType indiceVecTy = indiceVec.getType();
+ SmallVector<Type> convertedIndiceTypes =
+ getUnrolledTypes(indiceVecTy, *targetShape);
+ SmallVector<Value> convertedIndiceVec =
+ pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+ op.getSource(), indice);
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
----------------
Jianhui-Li wrote:
My understanding is that `pack` splits a value of shape [m, n] into tiles, as if reshaping it to [m/bm, n/bn, bm, bn], so it is a 1-to-N mapping. `unpack` does the reverse, so it is N-to-1.
https://github.com/llvm/llvm-project/pull/143602
More information about the Mlir-commits
mailing list