[Mlir-commits] [mlir] [MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (PR #175857)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 13 14:24:25 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/175857.diff
1 File Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+34-47)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..48906be79a54b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -287,58 +287,47 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
- const xegpu::uArch::uArch *uArch,
- unsigned packingSize,
- bool isScattered = false) {
+/// Helper to get the default layout for 2D block operations.
+template <typename Ty>
+static LayoutInfo
+getSIMTLayoutInfoFor2DBlockOp(Ty ty, const xegpu::uArch::uArch *uArch,
+ unsigned packingSize) {
// Expecting a 1D or 2D vector.
- assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+ assert((ty.getRank() == 1 || ty.getRank() == 2) &&
"Expected 1D or 2D vector.");
// Expecting int or float element type.
- assert(vectorTy.getElementType().isIntOrFloat() &&
+ assert(ty.getElementType().isIntOrFloat() &&
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
- if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
+ if (ty.getRank() == 1)
+ return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
- unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
- if (isScattered) {
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
- {uArch->getSubgroupSize(), 1},
- {1, packingFactor}));
- }
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
- {1, uArch->getSubgroupSize()},
- {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(
+ ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
- const xegpu::uArch::uArch *uArch,
- unsigned packingSize,
- bool isScattered = false) {
+static LayoutInfo
+getSIMTLayoutInfoForScatterIOVector(VectorType vectorTy,
+ const xegpu::uArch::uArch *uArch,
+ unsigned packingSize) {
// Expecting a 1D or 2D vector.
- assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
- "Expected 1D or 2D TensorDesc.");
+ assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+ "Expected 1D or 2D vector.");
// Expecting int or float element type.
- assert(tdescTy.getElementType().isIntOrFloat() &&
+ assert(vectorTy.getElementType().isIntOrFloat() &&
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
- if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
+ if (vectorTy.getRank() == 1)
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
- unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
- int subgroupSize = uArch->getSubgroupSize();
+ unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
- if (isScattered) {
- return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
- }
-
- return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+ {uArch->getSubgroupSize(), 1},
+ {1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
+ return getSIMTLayoutInfoFor2DBlockOp(vectorTy, uArch, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -587,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
prefetchLayout =
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
else
- prefetchLayout = getDefaultSIMTLayoutInfo(
+ prefetchLayout = getSIMTLayoutInfoFor2DBlockOp(
tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
prefetch.setLayoutAttr(
@@ -840,9 +829,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
else
- storeLayout =
- getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
- uArchInstruction->getPackedFormatBitSize());
+ storeLayout = getSIMTLayoutInfoFor2DBlockOp(
+ store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
store.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
}
@@ -1000,9 +989,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
loadLayout =
LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
else
- loadLayout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
- /*scattered*/ true);
+ loadLayout = getSIMTLayoutInfoForScatterIOVector(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
// Mask operand should have 1D default layout.
maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1078,9 +1066,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
"Expected the first dimension of 2D tensor descriptor to be "
"equal to "
"subgroup size.");
- payloadLayout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
- /*scattered=*/true);
+ payloadLayout = getSIMTLayoutInfoForScatterIOVector(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
}
maskLayout =
``````````
</details>
https://github.com/llvm/llvm-project/pull/175857
More information about the Mlir-commits mailing list