[Mlir-commits] [mlir] [MLIR][XeGPU] Matrix load/store subgroup distribution (PR #165008)

Artem Kroviakov llvmlistbot at llvm.org
Fri Oct 24 09:09:14 PDT 2025


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/165008

This PR enables subgroup-to-work-item (sg-to-wi) distribution of the XeGPU matrix load/store ops (`xegpu.load_matrix` and `xegpu.store_matrix`).

>From 887f9781ea3b62cd990d9df7066f28ec049f603b Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 24 Oct 2025 16:08:00 +0000
Subject: [PATCH] [MLIR][XeGPU] Matrix load/store subgroup distribution

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 124 ++++++++++++++++--
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  15 +++
 2 files changed, 131 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc196c0bf7..fe059bb86eba2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType payloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!payloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "The matrix op payload has no layouts, using defaults instead.");
+
+    SmallVector<Value> operands;
+    if constexpr (isLoad)
+      operands = {matrixOp.getMemDesc()};
+    else
+      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    if constexpr (!isLoad)
+      operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+    newOperands[operandIdxToModify] = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+        newWarpOp.getLaneid());
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    if constexpr (isLoad) {
+      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+          newOperands[0], newOffsets, newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+      // Resolve the output type and replace all uses.
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    } else {
+      xegpu::StoreMatrixOp::create(
+          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+          xegpu::DistributeLayoutAttr{});
+      rewriter.eraseOp(matrixOp);
+    }
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+           VectorMultiReductionDistribution, LoadDistribution,
+           StoreDistribution, VectorTransposeDistribution,
+           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+           MatrixOpDistribution<xegpu::StoreMatrixOp>,
+           MemrefExtractAlignedPointerAsIndexDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
 
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..3fcc747217c9d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return 
+  }
+}



More information about the Mlir-commits mailing list