[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)

Jianhui Li llvmlistbot at llvm.org
Fri Jun 13 14:38:41 PDT 2025


================
@@ -396,11 +396,198 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+/// Unrolls an xegpu::CreateDescOp into one CreateDescOp per target-shaped
+/// tile.
+///
+/// The offsets vector is split into target-shaped pieces, a new descriptor is
+/// created from the original source for each piece, and the resulting
+/// descriptors are recombined into a value of the original descriptor type.
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    // The offsets vector is unrolled below with the descriptor's target
+    // shape, which is only consistent when both have the same rank. The
+    // offsets are assumed to be 1-D, so multi-dimensional descriptors
+    // (e.g. ones created with a chunk size) are not handled here.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    // Every tile is expected to have the same descriptor type, so only the
+    // first unrolled type is used for all new ops.
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    newOps.reserve(convertedIndiceVec.size());
+    for (auto indice : convertedIndiceVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+                                                        op.getSource(), indice);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
+/// Unrolls an xegpu::LoadGatherOp into one LoadGatherOp per target-shaped
+/// tile.
+///
+/// The tensor descriptor and mask are split into target-shaped pieces. A
+/// gather load is emitted per (descriptor, mask) pair, carrying over the
+/// original transpose and L1/L2/L3 hint attributes. The loaded pieces are
+/// then recombined into a value of the original result type.
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D descriptors are handled: the value and mask are unrolled with
+    // the descriptor's target shape, which assumes they share its rank.
+    // Multi-dimensional descriptors (e.g. with a chunk size) are left as is.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    // dyn_cast yields a null type for non-vector operands; bail out rather
+    // than dereferencing it in cloneWith / getUnrolledTypes below.
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+    if (!valueTy || !maskTy)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    // Each new load produces a target-shaped vector of the descriptor's
+    // element type.
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    newOps.reserve(convertedTdescs.size());
+    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+/// Unrolls an xegpu::PrefetchOp into one PrefetchOp per target-shaped tile
+/// of its tensor descriptor, forwarding the original op's attributes to every
+/// new prefetch. The original op has no results and is simply erased.
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    // Only 1-D descriptors are unrolled. Multi-dimensional (e.g. chunked)
+    // descriptors are left intact so the descriptor is not split here unless
+    // it can be split consistently with the op that produces it.
+    if (tdescTy.getRank() > 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto t : convertedTdesc)
+      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    VectorType maskTy;
+    if (op.getMask())
+      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
----------------
Jianhui-Li wrote:

done

https://github.com/llvm/llvm-project/pull/143602


More information about the Mlir-commits mailing list