[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)
Jianhui Li
llvmlistbot at llvm.org
Fri Jun 13 14:38:41 PDT 2025
================
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
}
};
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+ using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+ VectorType indiceVecTy = indiceVec.getType();
+ SmallVector<Type> convertedIndiceTypes =
+ getUnrolledTypes(indiceVecTy, *targetShape);
+ SmallVector<Value> convertedIndiceVec =
+ pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+ op.getSource(), indice);
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+
+ return success();
+ }
+};
+
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+ using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ SmallVector<Value> convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+ auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+ using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ for (auto t : convertedTdesc)
+ rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+ using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy;
+ if (op.getMask())
+ maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
----------------
Jianhui-Li wrote:
done
https://github.com/llvm/llvm-project/pull/143602
More information about the Mlir-commits
mailing list