[Mlir-commits] [mlir] [MLIR][XeGPU] Add anchor_layout and update propagation to honor user-specified layouts (PR #169267)

Jianhui Li llvmlistbot at llvm.org
Wed Nov 26 09:23:47 PST 2025


================
@@ -475,48 +477,72 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   return success();
 }
 
+/// Returns true if `anchorLayout` is non-null and carries the parameters
+/// required by the current `layoutKind`: a non-empty effective inst_data for
+/// LayoutKind::InstData, or non-empty effective lane_layout AND lane_data for
+/// LayoutKind::Lane. Any other layout kind yields false.
+bool LayoutInfoPropagation::hasParamsOfLayoutKind(
+    xegpu::DistributeLayoutAttr anchorLayout) {
+  if (!anchorLayout)
+    return false;
+  if (layoutKind == LayoutKind::InstData)
+    return !anchorLayout.getEffectiveInstDataAsInt().empty();
+  if (layoutKind == LayoutKind::Lane)
+    return !anchorLayout.getEffectiveLaneLayoutAsInt().empty() &&
+           !anchorLayout.getEffectiveLaneDataAsInt().empty();
+  // Layout kinds other than InstData and Lane never accept an anchor layout.
+  return false;
+}
+
 void LayoutInfoPropagation::visitPrefetchNdOp(
     xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  // Here we assign the default layout to the tensor descriptor operand of
-  // prefetch.
-  auto tdescTy = prefetch.getTensorDescType();
-
-  auto uArch = getUArch(getChipStr(prefetch).value_or(""));
-  const auto *uArchInstruction =
-      dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
-          uArch->getInstruction(
-              xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
-
-  auto blockWHC =
-      uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
-  if (!blockWHC)
-    prefetch.emitWarning("No known block params found for the element type.");
-  auto [bWidth, bHeight, bCount] = blockWHC.value();
-  SmallVector<int> instData;
-  int instWidth = xegpu::getLargestDivisor(
-      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
-  if (instWidth == -1)
-    prefetch.emitWarning(
-        "No suitable instruction multiple found for the given shape.");
-  if (tdescTy.getRank() == 1)
-    instData = {instWidth};
-  else {
-    int instHeight = xegpu::getLargestDivisor(
-        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
-    if (instHeight == -1)
+
+  LayoutInfo prefetchLayout;
+  xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
+  if (hasParamsOfLayoutKind(anchorLayout)) {
+    prefetchLayout = LayoutInfo(anchorLayout);
+  } else {
+    // Here we assign the default layout to the tensor descriptor operand of
+    // prefetch.
+    auto tdescTy = prefetch.getTensorDescType();
+
+    auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+    const auto *uArchInstruction =
+        dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+            uArch->getInstruction(
+                xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+    auto blockWHC =
+        uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+    if (!blockWHC)
+      prefetch.emitWarning("No known block params found for the element type.");
+    auto [bWidth, bHeight, bCount] = blockWHC.value();
+    SmallVector<int> instData;
+    int instWidth = xegpu::getLargestDivisor(
+        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
+        bCount);
+    if (instWidth == -1)
       prefetch.emitWarning(
           "No suitable instruction multiple found for the given shape.");
-    instData = {instHeight, instWidth};
-  }
-  LayoutInfo prefetchLayout;
-  if (layoutKind == LayoutKind::InstData)
-    prefetchLayout =
-        LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
-  else
-    prefetchLayout = getDefaultSIMTLayoutInfo(
-        tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+    if (tdescTy.getRank() == 1)
+      instData = {instWidth};
+    else {
+      int instHeight = xegpu::getLargestDivisor(
+          static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+      if (instHeight == -1)
+        prefetch.emitWarning(
+            "No suitable instruction multiple found for the given shape.");
+      instData = {instHeight, instWidth};
+    }
+
+    if (layoutKind == LayoutKind::InstData)
+      prefetchLayout =
+          LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+    else
+      prefetchLayout = getDefaultSIMTLayoutInfo(
+          tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
+    prefetch.setLayoutAttr(
----------------
Jianhui-Li wrote:

The updateop() function needs to be refactored. We should not update the "operand_result_*" attributes and hope that another pass can use them, so we can try the interface idea then.

https://github.com/llvm/llvm-project/pull/169267


More information about the Mlir-commits mailing list