[Mlir-commits] [mlir] [MLIR][XeGPU] Matrix load/store subgroup distribution (PR #165008)
llvmlistbot at llvm.org
Fri Oct 24 09:09:57 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Artem Kroviakov (akroviakov)
Changes:
This PR enables subgroup-to-workitem (sg-to-wi) distribution of the XeGPU matrix load/store ops (`xegpu.load_matrix` / `xegpu.store_matrix`).
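For context, a sketch of the rewrite this enables, based on the test added below (the `%i`/`%j` names in the distributed form are placeholders; the test itself only matches the offsets with `{{.*}}`):

```mlir
// Subgroup-level input (from the test added in this PR): one 2x8 tile per
// subgroup, with a lane layout of [2, 8].
gpu.module @xevm_module {
  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
    %c0 = arith.constant 0 : index
    %1 = xegpu.load_matrix %arg0[%c0, %c0]
      <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>
      : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
    xegpu.store_matrix %1, %arg0[%c0, %c0]
      <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>
      : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
    gpu.return
  }
}

// After sg-to-wi distribution (mirroring the CHECK lines in the test), each
// work-item loads and stores a vector<1x1xf32> fragment, with the lane id
// folded into the per-lane offsets %i, %j:
//   %mat = xegpu.load_matrix %arg0[%i, %j]
//            : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
//   xegpu.store_matrix %mat, %arg0[%i, %j]
//            : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
```

In this example the per-lane shape is each payload dimension divided by the corresponding `lane_layout` entry (2/2 and 8/8), hence `vector<1x1xf32>`.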
---
Full diff: https://github.com/llvm/llvm-project/pull/165008.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+116-8)
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+15)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc196c0bf7..fe059bb86eba2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };

+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType payloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!payloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "The matrix op payload has no layouts, using defaults instead.");
+
+    SmallVector<Value> operands;
+    if constexpr (isLoad)
+      operands = {matrixOp.getMemDesc()};
+    else
+      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    if constexpr (!isLoad)
+      operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+    newOperands[operandIdxToModify] = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+        newWarpOp.getLaneid());
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    if constexpr (isLoad) {
+      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+          newOperands[0], newOffsets, newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+      // Resolve the output type and replace all uses.
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    } else {
+      xegpu::StoreMatrixOp::create(
+          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+          xegpu::DistributeLayoutAttr{});
+      rewriter.eraseOp(matrixOp);
+    }
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+           VectorMultiReductionDistribution, LoadDistribution,
+           StoreDistribution, VectorTransposeDistribution,
+           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+           MatrixOpDistribution<xegpu::StoreMatrixOp>,
+           MemrefExtractAlignedPointerAsIndexDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..3fcc747217c9d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return
+  }
+}
``````````
https://github.com/llvm/llvm-project/pull/165008
More information about the Mlir-commits mailing list