[Mlir-commits] [mlir] 77600cb - [MLIR][XeGPU] XeGPU Layout adds support for fractional-subgroup-size vector (#183434)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 26 19:49:38 PST 2026
Author: Jianhui Li
Date: 2026-02-26T19:49:33-08:00
New Revision: 77600cbd9798bfde4949d047609866ebb6a7fd25
URL: https://github.com/llvm/llvm-project/commit/77600cbd9798bfde4949d047609866ebb6a7fd25
DIFF: https://github.com/llvm/llvm-project/commit/77600cbd9798bfde4949d047609866ebb6a7fd25.diff
LOG: [MLIR][XeGPU] XeGPU Layout adds support for fractional-subgroup-size vector (#183434)
This PR enhances the layout assignment for XeGPU load/store operations
to handle vectors whose size is smaller than the subgroup size.
For example, a vector[4] is handled with lane_data=[1], lane_layout=[4], and
inst_data=[4].
The fractional-subgroup-size vector support is required to support the
cross-subgroup reduction case. The number of participating subgroups in a
reduction can be small, so each subgroup may need to reduce a
small vector, often a fraction of the subgroup size.
Most layout-based subgroup distribution patterns support
fractional-subgroup-size vectors without any change, except a few: reduction,
insert/extract, constant. We don't expect ND operations (like
load_nd/store_nd/dpas) to accept fractional-subgroup-size vectors.
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
mlir/test/Dialect/XeGPU/propagate-layout.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7ab1e49ac1d95..a3f0c4285da25 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -681,7 +681,7 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
- int maxChunkSize, int valShapeSize, int subgroupSize) {
+ int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
if (layoutKind == xegpu::LayoutKind::Subgroup)
return consumerLayout;
@@ -691,24 +691,24 @@ static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
SmallVector<int64_t> consumerLaneData =
consumerLayout.getEffectiveLaneDataAsInt();
- SmallVector<int> instData(valShapeSize, 1);
- SmallVector<int> laneLayout(valShapeSize, 1);
- SmallVector<int> laneData(valShapeSize, 1);
+ SmallVector<int> instData(resShape.size(), 1);
+ SmallVector<int> laneLayout(resShape.size(), 1);
+ SmallVector<int> laneData(resShape.size(), 1);
if (!isChunkedLoad) {
if (layoutKind == xegpu::LayoutKind::InstData) {
- instData[valShapeSize - 1] =
- std::min(static_cast<int>(consumerInstData[valShapeSize - 1]),
- maxChunkSize * subgroupSize);
+ instData.back() = std::min(static_cast<int>(consumerInstData.back()),
+ maxChunkSize * subgroupSize);
return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
- laneLayout.back() = subgroupSize;
laneData.back() =
std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
+ laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
+ resShape.back() / laneData.back());
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
} else {
- assert(valShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
+ assert(resShape.size() == 2 && "Chunked Store must access 2D tensor tile.");
if (layoutKind == xegpu::LayoutKind::InstData) {
instData[0] = subgroupSize;
instData[1] =
@@ -730,7 +730,7 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
const int subgroupSize = uArch->getSubgroupSize();
- int resShapeSize = resVecTy.getShape().size();
+ ArrayRef<int64_t> resShape = resVecTy.getShape();
auto context = resVecTy.getContext();
auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
@@ -740,8 +740,8 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
- (chunkSize > 1), maxChunkSize,
- resShapeSize, subgroupSize);
+ (chunkSize > 1), maxChunkSize, resShape,
+ subgroupSize);
}
/// Sets up the anchor layout for load matrix operation.
@@ -753,7 +753,7 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
const xegpu::uArch::uArch *uArch) {
const int subgroupSize = uArch->getSubgroupSize();
- int resShapeSize = resVecTy.getShape().size();
+ ArrayRef<int64_t> resShape = resVecTy.getShape();
auto context = resVecTy.getContext();
auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
@@ -762,7 +762,7 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
- false, maxChunkSize, resShapeSize,
+ false, maxChunkSize, resShape,
subgroupSize);
}
@@ -796,10 +796,12 @@ setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
if (!isChunkedStore) {
if (layoutKind == xegpu::LayoutKind::InstData) {
- instData[srcShapeSize - 1] = subgroupSize;
+ instData[srcShapeSize - 1] =
+ std::min(subgroupSize, static_cast<int>(srcShape.back()));
return xegpu::LayoutAttr::get(context, instData);
} else if (layoutKind == xegpu::LayoutKind::Lane) {
- laneLayout[srcShapeSize - 1] = subgroupSize;
+ laneLayout[srcShapeSize - 1] =
+ std::min(subgroupSize, static_cast<int>(srcShape.back()));
return xegpu::LayoutAttr::get(context, laneLayout, laneData);
}
} else {
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index b6c172ecf4ae0..5dd05e6cb0001 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -41,7 +41,6 @@ func.func @load_store_no_array_len(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf
// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>
// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.module @test {
-
func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
@@ -320,3 +319,18 @@ func.func @vector_shape_cast_expand_and_merge(%arg0: memref<256xf16>, %arg1: mem
return
}
}
+
+// -----
+gpu.module @test{
+ // CHECK-LABEL: load_store_matrix
+ // CHECK: xegpu.load_matrix %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 1]>}>
+ // CHECK: xegpu.store_matrix %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 1]>}>
+ func.func @load_store_matrix(%arg0: !xegpu.mem_desc<64x128xf32>, %arg1: i1) {
+ %c0 = arith.constant 0 : index
+ scf.if %arg1 {
+ %0 = xegpu.load_matrix %arg0[%c0, %c0] : !xegpu.mem_desc<64x128xf32>, index, index -> vector<2x1xf32>
+ xegpu.store_matrix %0, %arg0[%c0, %c0] : vector<2x1xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+ }
+ return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 17c9ec131ed70..4f2349a89b1ed 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -792,3 +792,17 @@ func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>
}
}
+// -----
+gpu.module @test{
+ // CHECK-LABEL: load_store_matrix
+ // CHECK: xegpu.load_matrix %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 1], lane_data = [1, 1]>}>
+ // CHECK: xegpu.store_matrix %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 1], lane_data = [1, 1]>}>
+ func.func @load_store_matrix(%arg0: !xegpu.mem_desc<64x128xf32>, %arg1: i1) {
+ %c0 = arith.constant 0 : index
+ scf.if %arg1 {
+ %0 = xegpu.load_matrix %arg0[%c0, %c0] : !xegpu.mem_desc<64x128xf32>, index, index -> vector<2x1xf32>
+ xegpu.store_matrix %0, %arg0[%c0, %c0] : vector<2x1xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+ }
+ return
+ }
+}
More information about the Mlir-commits
mailing list