[Mlir-commits] [mlir] 77600cb - [MLIR][XeGPU] XeGPU Layout adds support for fractional-subgroup-size vector (#183434)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 26 19:49:38 PST 2026


Author: Jianhui Li
Date: 2026-02-26T19:49:33-08:00
New Revision: 77600cbd9798bfde4949d047609866ebb6a7fd25

URL: https://github.com/llvm/llvm-project/commit/77600cbd9798bfde4949d047609866ebb6a7fd25
DIFF: https://github.com/llvm/llvm-project/commit/77600cbd9798bfde4949d047609866ebb6a7fd25.diff

LOG: [MLIR][XeGPU] XeGPU Layout adds support for fractional-subgroup-size vector  (#183434)

This PR enhances the layout assignment for XeGPU load/store operations
to handle vectors whose size is smaller than the subgroup size.
For example, a vector[4] can use lane_data=[1], lane_layout=[4], and
inst_data=[4].
Fractional-subgroup-size vector support is required for the
cross-subgroup reduction case. The number of subgroups participating in
a reduction can be small, so each subgroup may need to reduce a small
vector, often a fraction of the subgroup size.
Most layout-based subgroup distribution patterns support
fractional-subgroup-size vectors without any change, except for a few:
reduction, insert/extract, and constant. We don't expect ND operations
(like load_nd/store_nd/dpas) to accept fractional-subgroup-size vectors.

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
    mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
    mlir/test/Dialect/XeGPU/propagate-layout.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7ab1e49ac1d95..a3f0c4285da25 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -681,7 +681,7 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
 static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
     xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
     xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
-    int maxChunkSize, int valShapeSize, int subgroupSize) {
+    int maxChunkSize, ArrayRef<int64_t> resShape, int subgroupSize) {
 
   if (layoutKind == xegpu::LayoutKind::Subgroup)
     return consumerLayout;
@@ -691,24 +691,24 @@ static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
   SmallVector<int64_t> consumerLaneData =
       consumerLayout.getEffectiveLaneDataAsInt();
 
-  SmallVector<int> instData(valShapeSize, 1);
-  SmallVector<int> laneLayout(valShapeSize, 1);
-  SmallVector<int> laneData(valShapeSize, 1);
+  SmallVector<int> instData(resShape.size(), 1);
+  SmallVector<int> laneLayout(resShape.size(), 1);
+  SmallVector<int> laneData(resShape.size(), 1);
 
   if (!isChunkedLoad) {
     if (layoutKind == xegpu::LayoutKind::InstData) {
-      instData[valShapeSize - 1] =
-          std::min(static_cast<int>(consumerInstData[valShapeSize - 1]),
-                   maxChunkSize * subgroupSize);
+      instData.back() = std::min(static_cast<int>(consumerInstData.back()),
+                                 maxChunkSize * subgroupSize);
       return xegpu::LayoutAttr::get(context, instData);
     } else if (layoutKind == xegpu::LayoutKind::Lane) {
-      laneLayout.back() = subgroupSize;
       laneData.back() =
           std::min(static_cast<int>(consumerLaneData.back()), maxChunkSize);
+      laneLayout.back() = std::min(static_cast<int64_t>(subgroupSize),
+                                   resShape.back() / laneData.back());
       return xegpu::LayoutAttr::get(context, laneLayout, laneData);
     }
   } else {
-    assert(valShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
+    assert(resShape.size() == 2 && "Chunked Store must access 2D tensor tile.");
     if (layoutKind == xegpu::LayoutKind::InstData) {
       instData[0] = subgroupSize;
       instData[1] =
@@ -730,7 +730,7 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
     xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
 
   const int subgroupSize = uArch->getSubgroupSize();
-  int resShapeSize = resVecTy.getShape().size();
+  ArrayRef<int64_t> resShape = resVecTy.getShape();
   auto context = resVecTy.getContext();
   auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
 
@@ -740,8 +740,8 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
   int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
 
   return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
-                                      (chunkSize > 1), maxChunkSize,
-                                      resShapeSize, subgroupSize);
+                                      (chunkSize > 1), maxChunkSize, resShape,
+                                      subgroupSize);
 }
 
 /// Sets up the anchor layout for load matrix operation.
@@ -753,7 +753,7 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
                                    const xegpu::uArch::uArch *uArch) {
 
   const int subgroupSize = uArch->getSubgroupSize();
-  int resShapeSize = resVecTy.getShape().size();
+  ArrayRef<int64_t> resShape = resVecTy.getShape();
   auto context = resVecTy.getContext();
   auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
 
@@ -762,7 +762,7 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
           uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
   int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
   return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
-                                      false, maxChunkSize, resShapeSize,
+                                      false, maxChunkSize, resShape,
                                       subgroupSize);
 }
 
@@ -796,10 +796,12 @@ setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
 
   if (!isChunkedStore) {
     if (layoutKind == xegpu::LayoutKind::InstData) {
-      instData[srcShapeSize - 1] = subgroupSize;
+      instData[srcShapeSize - 1] =
+          std::min(subgroupSize, static_cast<int>(srcShape.back()));
       return xegpu::LayoutAttr::get(context, instData);
     } else if (layoutKind == xegpu::LayoutKind::Lane) {
-      laneLayout[srcShapeSize - 1] = subgroupSize;
+      laneLayout[srcShapeSize - 1] =
+          std::min(subgroupSize, static_cast<int>(srcShape.back()));
       return xegpu::LayoutAttr::get(context, laneLayout, laneData);
     }
   } else {

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index b6c172ecf4ae0..5dd05e6cb0001 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -41,7 +41,6 @@ func.func @load_store_no_array_len(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf
 // CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>
 // CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>>
 gpu.module @test {
-
 func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
@@ -320,3 +319,18 @@ func.func @vector_shape_cast_expand_and_merge(%arg0: memref<256xf16>, %arg1: mem
     return
   }
 }
+
+// -----
+gpu.module @test{
+  // CHECK-LABEL: load_store_matrix
+  // CHECK: xegpu.load_matrix %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 1]>}>
+  // CHECK: xegpu.store_matrix %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 1]>}>
+  func.func @load_store_matrix(%arg0: !xegpu.mem_desc<64x128xf32>, %arg1: i1) {
+    %c0 = arith.constant 0 : index
+    scf.if %arg1 {
+      %0 = xegpu.load_matrix %arg0[%c0, %c0] : !xegpu.mem_desc<64x128xf32>, index, index -> vector<2x1xf32>
+      xegpu.store_matrix %0, %arg0[%c0, %c0] : vector<2x1xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    }
+    return
+  }
+}

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 17c9ec131ed70..4f2349a89b1ed 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -792,3 +792,17 @@ func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>
 }
 }
 
+// -----
+gpu.module @test{
+  // CHECK-LABEL: load_store_matrix
+  // CHECK: xegpu.load_matrix %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 1], lane_data = [1, 1]>}>
+  // CHECK: xegpu.store_matrix %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 1], lane_data = [1, 1]>}>
+  func.func @load_store_matrix(%arg0: !xegpu.mem_desc<64x128xf32>, %arg1: i1) {
+    %c0 = arith.constant 0 : index
+    scf.if %arg1 {
+      %0 = xegpu.load_matrix %arg0[%c0, %c0] : !xegpu.mem_desc<64x128xf32>, index, index -> vector<2x1xf32>
+      xegpu.store_matrix %0, %arg0[%c0, %c0] : vector<2x1xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    }
+    return
+  }
+}


        


More information about the Mlir-commits mailing list