[Mlir-commits] [mlir] [MLIR][XeGPU] Matrix load/store subgroup distribution (PR #165008)

Jianhui Li llvmlistbot at llvm.org
Tue Oct 28 12:18:33 PDT 2025


================
@@ -906,6 +907,186 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Computes the per-lane offsets for a distributed xegpu matrix op
+/// (e.g. load_matrix).
+///
+/// Asks `layout` for the distributed offsets of `laneId` over the subgroup
+/// payload shape `payloadShape`, and adds them to `origOffsets` using
+/// xegpu::addWithRightAligned. When the two offset lists differ in length,
+/// they are right-aligned, and the unmatched leading entries of the longer
+/// list are kept unchanged.
+///
+/// Returns an empty vector if the layout fails to compute distributed offsets.
+static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
+    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  auto maybeDescOffsets =
+      layout.computeDistributedOffsets(rewriter, loc, laneId, payloadShape);
+  if (failed(maybeDescOffsets))
+    return {};
+  assert(maybeDescOffsets.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  // Leading entries that addWithRightAligned passes through unchanged may be
+  // constant attributes created by getAsOpFoldResult, so a plain
+  // cast<Value> can assert. getAsValues materializes any attribute as a
+  // constant index value instead.
+  return vector::getAsValues(rewriter, loc, ofrVec);
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadMatrixOp");
+    const int operandIdx = producedByLastLoad->getOperandNumber();
+
+    VectorType sgPayloadTy =
+        dyn_cast<VectorType>(matrixOp.getResult().getType());
+    VectorType warpResultTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
----------------
Jianhui-Li wrote:

Can you add a verifier check here? For operations without subgroup_block_io, each lane's lane_data must be physically contiguous in SLM.

```mlir
// Correct: stride = [1, 32] makes dim 0 contiguous, so each lane's 2x1 lane_data is contiguous in SLM.
 %1 = xegpu.load_matrix %arg0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, index, index -> vector<2x16xf32>

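// Incorrect: with the default row-major layout (stride = [32, 1]), each lane's 2x1 lane_data spans two rows and is not contiguous.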
 %1 = xegpu.load_matrix %arg0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x16xf32>

```

For operations with subgroup_block_io, the lane_data must be [1, 1].


https://github.com/llvm/llvm-project/pull/165008


More information about the Mlir-commits mailing list