[Mlir-commits] [mlir] [MLIR][XeGPU] Add simple rank-based sg layout creation (PR #172867)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 18 07:22:49 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Artem Kroviakov (akroviakov)

<details>
<summary>Changes</summary>

This PR adds creation of sg_layout and sg_data during subgroup-level layout propagation.
The goals are simplicity and sg_layout stability (few layout conversions).

The logic is to decompose the number of subgroups into its prime factors, then repeatedly combine the two smallest factors until exactly `rank` factors remain. If there are fewer prime factors than dimensions, or a dimension is not divisible by its factor, no layout is created. The utility does not fit the layout to the inst_data size, does not re-adjust the layout when layout[dim] > shape[dim], and does not guarantee the most even factor distribution in all cases. Any case that cannot be distributed this simply requires a user-supplied layout.

Adding more complexity would turn this short utility into a "solver", which is not the goal, at least for now.

---
Full diff: https://github.com/llvm/llvm-project/pull/172867.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+144-13) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+77) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 112d138c73e18..843c59bfe95da 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -538,6 +538,82 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   return false;
 }
 
+/// Chooses a subgroup layout (sg_layout) and per-subgroup data (sg_data) for
+/// a workgroup-level shape `wgShape` distributed over `sgCount` subgroups.
+///
+/// `sgCount` is factorized into primes, then the two smallest factors are
+/// fused repeatedly until exactly `wgShape.size()` factors remain. Factor `i`
+/// becomes sg_layout[i] and sg_data[i] = wgShape[i] / sg_layout[i].
+///
+/// Returns failure when no layout can be built this way: a 0-d shape, a
+/// non-positive `sgCount`, fewer prime factors than dimensions, or a
+/// dimension that is not evenly divisible by its layout factor.
+static FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
+chooseLayout(llvm::ArrayRef<int64_t> wgShape, int64_t sgCount) {
+  const size_t rank = wgShape.size();
+  // A 0-d shape cannot be split across subgroups. Rejecting it here also
+  // guarantees the fusing loop below always has two factors to combine.
+  if (rank == 0 || sgCount < 1)
+    return failure();
+
+  // Step 1. Factorize sgCount into primes (produced in ascending order).
+  SmallVector<int64_t> factors;
+  int64_t rest = sgCount;
+  for (int64_t p = 2; p * p <= rest; ++p) {
+    while (rest % p == 0) {
+      factors.push_back(p);
+      rest /= p;
+    }
+  }
+  if (rest > 1)
+    factors.push_back(rest);
+  if (factors.size() < rank)
+    return failure();
+
+  // Step 2. Fuse the two smallest factors until exactly `rank` remain. There
+  // are at most log2(sgCount) factors, so re-sorting each round is cheap.
+  while (factors.size() > rank) {
+    llvm::sort(factors);
+    const int64_t fused = factors[0] * factors[1];
+    factors.erase(factors.begin());
+    factors[0] = fused;
+  }
+
+  // Step 3. Every dimension must split evenly across its subgroups.
+  SmallVector<int> layout;
+  SmallVector<int> data;
+  for (auto [i, sgDim] : llvm::enumerate(factors)) {
+    if (wgShape[i] % sgDim != 0)
+      return failure();
+    layout.push_back(static_cast<int>(sgDim));
+    data.push_back(static_cast<int>(wgShape[i] / sgDim));
+  }
+  return std::make_pair(std::move(layout), std::move(data));
+}
+
+/// Returns the number of subgroups in the workgroup that executes `op`:
+/// the total work-item count from the enclosing gpu.func's `known_block_size`
+/// divided by `sgSize`. Only the total matters, not the per-dimension shape.
+///
+/// Returns failure if `op` has no enclosing gpu.func, the function has no
+/// known block size, or the block size is not a positive multiple of `sgSize`.
+static FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
+  if (sgSize <= 0)
+    return failure();
+  auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
+  if (!gpuFunc)
+    return failure();
+  auto knownBlockSize = gpuFunc.getKnownBlockSize();
+  if (!knownBlockSize.has_value())
+    return failure();
+  const int64_t flatBlockSize = llvm::product_of(knownBlockSize.value());
+  // Reject a block smaller than one subgroup or with a partial subgroup;
+  // plain integer division would silently truncate these cases.
+  if (flatBlockSize < sgSize || flatBlockSize % sgSize != 0)
+    return failure();
+  return flatBlockSize / sgSize;
+}
+
 void LayoutInfoPropagation::visitPrefetchNdOp(
     xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
@@ -753,30 +801,89 @@ void LayoutInfoPropagation::visitDpasOp(
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
       dpasBLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
-    } else {
+    } else if (layoutKind == LayoutKind::Lane) {
       dpasALayout = getSIMTLayoutInfoForDPASOperand(
           aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
       dpasBLayout = getSIMTLayoutInfoForDPASOperand(
           bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+    } else {
+      auto numSgOrErr = getNumSg(dpas, subgroupSize);
+      if (failed(numSgOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
+      auto layoutDataAOrErr = chooseLayout(aTy.getShape(), numSgOrErr.value());
+      if (failed(layoutDataAOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine suitable subgroup layout and data for A.");
+        return;
+      }
+      auto [sgLayoutA, sgDataA] = layoutDataAOrErr.value();
+
+      dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
+          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+          DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+
+      auto layoutDataBOrErr = chooseLayout(bTy.getShape(), numSgOrErr.value());
+      if (failed(layoutDataBOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine suitable subgroup layout and data for B.");
+        return;
+      }
+      auto [sgLayoutB, sgDataB] = layoutDataBOrErr.value();
+
+      dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
+          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+          DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
     }
 
     if (operands.size() > 2) {
       VectorType cTy = dpas.getAccType();
+      const unsigned dataCLen = bTy.getShape().back();
+      auto supportedCLen =
+          uArchInstruction->getSupportedN(bTy.getElementType());
+      const int maxCLen =
+          xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+      if (maxCLen == -1) {
+        dpas.emitWarning(
+            "No suitable instruction multiple found for the given shape.");
+        return;
+      }
+      SmallVector<int> instDataCD = {maxALen, maxCLen};
       if (layoutKind == LayoutKind::InstData) {
-        const unsigned dataCLen = bTy.getShape().back();
-        auto supportedCLen =
-            uArchInstruction->getSupportedN(bTy.getElementType());
-        const int maxCLen = xegpu::getLargestDivisor(
-            dataCLen, ArrayRef<unsigned>(supportedCLen));
-        if (maxCLen == -1)
-          dpas.emitWarning(
-              "No suitable instruction multiple found for the given shape.");
-        SmallVector<int> instDataC = {maxALen, maxCLen};
         dpasCDLayout =
-            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
-      } else
+            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+      } else if (layoutKind == LayoutKind::Lane) {
         dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
             cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+      } else {
+        auto numSgOrErr = getNumSg(dpas, subgroupSize);
+        if (failed(numSgOrErr)) {
+          dpas.emitWarning(
+              "Unable to determine the number of subgroups for the operation.");
+          return;
+        }
+        auto layoutDataAOrErr =
+            chooseLayout(cTy.getShape(), numSgOrErr.value());
+        if (failed(layoutDataAOrErr)) {
+          dpas.emitWarning(
+              "Unable to determine suitable subgroup layout and data for A.");
+          return;
+        }
+        auto [sgLayoutCD, sgDataCD] = layoutDataAOrErr.value();
+
+        dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
+            cTy.getContext(),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgDataCD),
+            /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+            /*lane_data =*/nullptr, /*order =*/nullptr));
+      }
 
       dpas.setLayoutCdAttr(
           dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
@@ -835,10 +942,34 @@ void LayoutInfoPropagation::visitStoreNdOp(
     if (layoutKind == LayoutKind::InstData)
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
-    else
+    else if (layoutKind == LayoutKind::Lane)
       storeLayout =
           getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
                                    uArchInstruction->getPackedFormatBitSize());
+    else { // LayoutKind::Subgroup
+      auto sgSize = uArch->getSubgroupSize();
+      auto numSgOrErr = getNumSg(store, sgSize);
+      if (failed(numSgOrErr)) {
+        store.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
+      auto layoutDataOrErr =
+          chooseLayout(dataTy.getShape(), numSgOrErr.value());
+      if (failed(layoutDataOrErr)) {
+        store.emitWarning(
+            "Unable to determine suitable subgroup layout and data.");
+        return;
+      }
+      auto [sgLayout, sgData] = layoutDataOrErr.value();
+
+      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
+          dataTy.getContext(),
+          DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
+          DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+    }
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index c7dfc9fb7b1f1..a1f84a8fe57a2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -51,3 +51,80 @@ gpu.module @test {
     return
   }
 }
+
+// -----
+gpu.module @test {
+  // CHECK-LABEL: vector_transpose
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
+  gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 32, 16>} {
+    // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>>
+    // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>} :
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>> -> vector<256x128xf32>
+
+    // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[LOAD]], [1, 0]
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>} : vector<256x128xf32> to vector<128x256xf32>
+
+    // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>}> : vector<128x256xf32>,
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
+    %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
+    %trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
+    xegpu.store_nd %trans, %tdesc1[0, 0] : vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
+    gpu.return
+  }
+}
+
+// -----
+gpu.module @test {
+  // CHECK-LABEL: dpas
+  // CHECK-SAME: %[[A_MEMREF:.*]]: memref<128x128xf16>, %[[B_MEMREF:.*]]: memref<128x128xf16>
+  // CHECK-SAME: %[[CD_MEMREF:.*]]: memref<128x128xf32>
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 64, 16>} {
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+    // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+    // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+    // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+    // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+
+    // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
+    // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+    // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+    // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+    // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+
+    // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[CD_MEMREF]] : memref<128x128xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+    // CHECK: xegpu.store_nd %[[DPAS_RES]], %[[TDESC_ST]][0, 0]
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}> :
+    // CHECK-SAME: vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+    %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+    %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    %tdesc_cd = xegpu.create_nd_tdesc %d : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32>
+    xegpu.store_nd %dpas, %tdesc_cd[0, 0] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32>
+    gpu.return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/172867


More information about the Mlir-commits mailing list