[Mlir-commits] [mlir] [MLIR][XeGPU] Add uArch limitation to scatter load store (PR #172845)

Jianhui Li llvmlistbot at llvm.org
Wed Jan 14 22:30:10 PST 2026


================
@@ -973,42 +973,75 @@ void LayoutInfoPropagation::visitLoadGatherOp(
 
   LayoutInfo loadLayout;
   LayoutInfo maskLayout;
+  auto uArch = getUArch(getChipStr(load).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
   xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
   if (hasParamsOfLayoutKind(anchorLayout)) {
     loadLayout = LayoutInfo(anchorLayout);
-    maskLayout = loadLayout;
   } else {
+    LayoutInfo valueLayout = results[0]->getValue();
+    // Need the layout of the value to propagate to the tensor descriptor.
+    if (!valueLayout.isAssigned())
+      return;
+
+    auto resAttr = dyn_cast<xegpu::DistributeLayoutAttr>(valueLayout.get());
+    auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
+      instDataIncoming = SmallVector<int64_t>(
+          cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
+              .getInstData()
+              .asArrayRef());
 
-    // The layout is strictly determined by the payload type.
     VectorType payloadTy = load.getValueType();
     if (!payloadTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-    auto uArch = getUArch(getChipStr(load).value_or(""));
-    const int subgroupSize = uArch->getSubgroupSize();
-    SmallVector<int> instData{subgroupSize};
-    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
-      instData.push_back(chunkSize);
-    else if (auto srcTdescTy =
-                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
-      if (srcTdescTy.getChunkSizeAsInt() > 1)
-        instData.push_back(chunkSize);
+    const auto *uArchInstruction =
+        dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+            uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+
+    // Check if value inst_data complies with uArch
+    if (!instDataIncoming.empty()) {
+      const int maxElemsPerInst =
+          uArchInstruction->getMaxBitSize() /
+          payloadTy.getElementType().getIntOrFloatBitWidth();
+
+      // Each lane loads either one element
+      SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
+      // Or multiple elements as 2D with lane's elements in the inner dimension
+      if (payloadTy.getRank() == 1) {
+        instDataUarch.back() = subgroupSize;
+      } else {
+        *std::prev(instDataUarch.end(), 2) = subgroupSize;
----------------
Jianhui-Li wrote:

The load op has a hardware restriction, so the inst_data layout should support only 1D or 2D shapes, regardless of how the result is used (i.e., what layout is propagated to the load's result). I think we should limit instDataUarch's rank to 2D or lower, regardless of the propagated layout (instDataIncoming).

Once this restriction (1D or 2D only) is applied, you can simply use Charitha's suggestion, `instDataUarch[0] = ...`, here, since this branch handles only the 2D case.

https://github.com/llvm/llvm-project/pull/172845


More information about the Mlir-commits mailing list