[Mlir-commits] [mlir] [mlir][XeGPU] add unroll patterns for load_matrix and store_matrix (PR #154637)

Jianhui Li llvmlistbot at llvm.org
Tue Sep 2 11:05:26 PDT 2025


================
@@ -682,13 +682,90 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
   }
 };
 
+/// Unrolls an xegpu.load_matrix whose result vector is larger than the target
+/// (native) shape into one smaller load_matrix per tile of that shape, then
+/// reassembles the tiles into a value of the original result type.
+struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
+  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getType();
+    Type elemTy = valueTy.getElementType();
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+
+    // The tiling below walks `shape` in steps of `targetShape` and adds each
+    // tile offset to the op's offsets element by element, so all three must
+    // have the same rank. Bail out rather than trip an assertion.
+    if (targetShape->size() != shape.size() ||
+        mixedOffsets.size() != shape.size())
+      return rewriter.notifyMatchFailure(
+          op, "expected target shape, value and offsets of the same rank");
+
+    // `dyn_cast` asserts on a null attribute, and the layout is dereferenced
+    // unconditionally below, so refuse to rewrite when no LayoutAttr is found
+    // instead of crashing.
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    if (!layout)
+      return rewriter.notifyMatchFailure(op, "expected an xegpu::LayoutAttr");
+    // Every tile carries the same layout minus inst_data; compute it once.
+    auto tileLayout = layout.dropInstData();
+
+    VectorType tileTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    // Emit one load_matrix per tile. A tile's offsets are the op's offsets
+    // shifted by the tile's static position inside `shape`; the ranks were
+    // checked above, so a plain element-wise add suffices.
+    SmallVector<Value> tiles;
+    for (SmallVector<int64_t> tileOffsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+      SmallVector<OpFoldResult> offsets = xegpu::addElementwise(
+          rewriter, loc, mixedOffsets,
+          getAsIndexOpFoldResult(op.getContext(), tileOffsets));
+      tiles.push_back(rewriter.create<xegpu::LoadMatrixOp>(
+          loc, tileTy, op.getMemDesc(), offsets, tileLayout));
+    }
+
+    // `unpack` is defined elsewhere in this file; presumably it stitches the
+    // per-tile results back into one value of the original type.
+    Value result = unpack(tiles, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
+  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getData().getType();
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+      auto adds = xegpu::addWithRightAligned(
----------------
Jianhui-Li wrote:

I don't think we need addWithRightAligned here. The op's offsets should always have the same count as the distributed offsets (computed from shape/targetShape), so a plain element-wise add is enough.

https://github.com/llvm/llvm-project/pull/154637


More information about the Mlir-commits mailing list