[Mlir-commits] [mlir] c5c6588 - [MLIR][XeGPU] Add distribution pattern for xegpu load & store matrix from sg to wi (#183179)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 5 09:56:56 PST 2026
Author: Nishant Patel
Date: 2026-03-05T09:56:51-08:00
New Revision: c5c6588c115347cdd000cb8a650b19dca0381196
URL: https://github.com/llvm/llvm-project/commit/c5c6588c115347cdd000cb8a650b19dca0381196
DIFF: https://github.com/llvm/llvm-project/commit/c5c6588c115347cdd000cb8a650b19dca0381196.diff
LOG: [MLIR][XeGPU] Add distribution pattern for xegpu load & store matrix from sg to wi (#183179)
This PR adds distribution patterns for the xegpu.load_matrix and
xegpu.store_matrix ops to the new sg-to-wi pass.
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index cffd80f8fcf92..5cd766ed2813e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -597,6 +598,137 @@ struct SgToWiMultiDimReduction
}
};
+/// Computes per-lane coordinates for a matrix op by adding the layout's
+/// distributed coords for gpu.lane_id to `origOffsets`; returns an empty
+/// vector on failure. Callers skip this when subgroup_block_io is set.
+static SmallVector<Value> computeDistributedCoordsForMatrixOp(
+ ConversionPatternRewriter &rewriter, Location loc,
+ xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
+ ValueRange origOffsets) {
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+ /*upperBound=*/mlir::IntegerAttr());
+ auto maybeCoords =
+ layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+ if (failed(maybeCoords))
+ return {};
+ assert(maybeCoords.value().size() == 1 &&
+ "Expected one set of distributed offsets");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+ getAsOpFoldResult(origOffsets));
+ return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
+}
+
+/// Distributes a subgroup-level LoadMatrix op to a layout-free workitem one.
+struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
+ using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto layout = op.getLayoutAttr();
+ // Without a layout there is nothing to distribute; let the match fail.
+ if (!layout)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ op, "the matrix op payload must be a vector type");
+
+ auto loc = op.getLoc();
+ auto offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "the load op must have offsets");
+
+ FailureOr<VectorType> distPayloadTyOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, loc, offsets);
+
+ SmallVector<Value> newCoords = offsetsAsValues;
+ if (!op.getSubgroupBlockIoAttr()) { // otherwise original offsets are kept
+ newCoords = computeDistributedCoordsForMatrixOp(
+ rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
+ if (newCoords.empty())
+ return rewriter.notifyMatchFailure(
+ op, "Failed to compute distributed coordinates.");
+ }
+
+ SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
+ ShapedType::kDynamic); // all offsets are passed as SSA values
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+
+ auto newOp = xegpu::LoadMatrixOp::create(
+ rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
+ ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
+ xegpu::DistributeLayoutAttr{}); // the new op carries no layout
+ rewriter.replaceOp(op, newOp.getResult());
+ return success();
+ }
+};
+
+/// Distributes a subgroup-level StoreMatrix op to a layout-free workitem one.
+struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
+ using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto layout = op.getLayoutAttr();
+ // Without a layout there is nothing to distribute; let the match fail.
+ if (!layout)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ op, "the matrix op payload must be a vector type");
+
+ auto loc = op.getLoc();
+ auto offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "the store op must have offsets");
+
+ FailureOr<VectorType> distPayloadTyOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, loc, offsets);
+
+ SmallVector<Value> newCoords = offsetsAsValues;
+ if (!op.getSubgroupBlockIoAttr()) { // otherwise original offsets are kept
+ newCoords = computeDistributedCoordsForMatrixOp(
+ rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
+ if (newCoords.empty())
+ return rewriter.notifyMatchFailure(
+ op, "Failed to compute distributed coordinates.");
+ }
+
+ SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
+ ShapedType::kDynamic); // all offsets are passed as SSA values
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{},
+ castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
+ distPayloadTyOrFailure.value()), // data cast to the distributed type
+ adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
+ op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(op); // created with no result types, so nothing to replace
+ return success();
+ }
+};
+
/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
/// workitem-level.
///
@@ -901,5 +1033,6 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
- SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
+ SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index bc36554f5c266..d7b4883760c05 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -461,3 +461,67 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
gpu.return
}
}
+
+// -----
+// load_matrix and store_matrix with coordinate computation (offsets [0,0])
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_1
+// CHECK-DAG: %[[LANE_ID1:.*]] = gpu.lane_id
+// CHECK-DAG: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
+// CHECK-DAG: %[[ROW:.*]] = arith.remui %[[R2]], %{{.*}} : index
+// CHECK-DAG: %[[COL:.*]] = arith.remui %[[R1]], %{{.*}} : index
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[COL]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+}
+}
+
+// -----
+// load_matrix and store_matrix with non-zero offsets [0,1]
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_2
+// CHECK-DAG: %[[LANE_ID1:.*]] = gpu.lane_id
+// CHECK-DAG: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
+// CHECK-DAG: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
+// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[R2]], %{{.*}} : index
+// CHECK-DAG: %[[ROW:.*]] = arith.remui %[[MUL]], %{{.*}} : index
+// CHECK-DAG: %[[R3:.*]] = arith.remui %[[R1]], %{{.*}} : index
+// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[R3]], %{{.*}} : index
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+}
+}
+
+// -----
+// load_matrix and store_matrix with subgroup_block_io (no coordinate computation)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @load_store_matrix_3
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
+gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
+ !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
+ gpu.return
+}
+}
More information about the Mlir-commits
mailing list