[Mlir-commits] [mlir] [MLIR][XeGPU] Add anchor_layout and update propagation to honor user-specified layouts (PR #169267)
Jianhui Li
llvmlistbot at llvm.org
Wed Nov 26 09:23:47 PST 2025
================
@@ -475,48 +477,72 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}
+bool LayoutInfoPropagation::hasParamsOfLayoutKind(
+ xegpu::DistributeLayoutAttr anchorLayout) {
+ if (anchorLayout == nullptr) {
+ return false;
+ }
+ if (layoutKind == LayoutKind::InstData) {
+ return !(anchorLayout.getEffectiveInstDataAsInt().empty());
+ } else if (layoutKind == LayoutKind::Lane) {
+ return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
+ anchorLayout.getEffectiveLaneDataAsInt().empty());
+ }
+ return false;
+}
+
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // Here we assign the default layout to the tensor descriptor operand of
- // prefetch.
- auto tdescTy = prefetch.getTensorDescType();
-
- auto uArch = getUArch(getChipStr(prefetch).value_or(""));
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
- uArch->getInstruction(
- xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
-
- auto blockWHC =
- uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
- if (!blockWHC)
- prefetch.emitWarning("No known block params found for the element type.");
- auto [bWidth, bHeight, bCount] = blockWHC.value();
- SmallVector<int> instData;
- int instWidth = xegpu::getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
- if (instWidth == -1)
- prefetch.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- if (tdescTy.getRank() == 1)
- instData = {instWidth};
- else {
- int instHeight = xegpu::getLargestDivisor(
- static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
- if (instHeight == -1)
+
+ LayoutInfo prefetchLayout;
+ xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr();
+ if (hasParamsOfLayoutKind(anchorLayout)) {
+ prefetchLayout = LayoutInfo(anchorLayout);
+ } else {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+
+ auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+ auto blockWHC =
+ uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+ if (!blockWHC)
+ prefetch.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
+ bCount);
+ if (instWidth == -1)
prefetch.emitWarning(
"No suitable instruction multiple found for the given shape.");
- instData = {instHeight, instWidth};
- }
- LayoutInfo prefetchLayout;
- if (layoutKind == LayoutKind::InstData)
- prefetchLayout =
- LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
- else
- prefetchLayout = getDefaultSIMTLayoutInfo(
- tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+ if (tdescTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = xegpu::getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+
+ if (layoutKind == LayoutKind::InstData)
+ prefetchLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+ else
+ prefetchLayout = getDefaultSIMTLayoutInfo(
+ tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+ prefetch.setLayoutAttr(
----------------
Jianhui-Li wrote:
The updateOp() needs to be refactored. We should not update the "operand_result_*" attribute and hope it can be used by another pass, so we can try the interface idea then.
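As a rough illustration only (not part of this PR; the interface and method names below are placeholders), the interface idea could take the shape of a TableGen op interface that the propagation queries and updates directly, instead of attaching per-operand/result attributes for a later pass to pick up:

    // Sketch only: interface and method names are placeholders, not an existing API.
    def XeGPU_AnchorLayoutOpInterface : OpInterface<"AnchorLayoutOpInterface"> {
      let description = [{
        Placeholder interface: lets the layout propagation read a user-specified
        anchor layout and write back the layout it decides on, without going
        through discardable "operand_result_*" style attributes.
      }];
      let cppNamespace = "::mlir::xegpu";
      let methods = [
        InterfaceMethod<"Return the user-specified anchor layout, if any.",
          "::mlir::xegpu::DistributeLayoutAttr", "getAnchorLayout">,
        InterfaceMethod<"Record the layout chosen by the propagation.",
          "void", "setAnchorLayout",
          (ins "::mlir::xegpu::DistributeLayoutAttr":$layout)>,
      ];
    }

The propagation could then dyn_cast to such an interface in visitOperation and honor layouts that satisfy hasParamsOfLayoutKind(), much as the PrefetchNd handler above does, while the attribute-based hand-off in updateOp() goes away.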
https://github.com/llvm/llvm-project/pull/169267