[Mlir-commits] [mlir] [MLIR][XeGPU] Introduce `xegpu::uArch` usage in target-sensitive passes (PR #163801)
Charitha Saumya
llvmlistbot at llvm.org
Mon Oct 27 15:05:58 PDT 2025
================
@@ -557,23 +599,54 @@ void LayoutInfoPropagation::visitDpasOp(
ArrayRef<const LayoutInfoLattice *> results) {
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
- propagateIfChanged(
- operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
- propagateIfChanged(
- operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
+
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ auto uArchInstruction =
+ std::static_pointer_cast<xegpu::uArch::DPASInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::DPAS));
+ const int maxALen =
+ uArchInstruction->getSupportedM(aTy.getElementType()).back();
+ const int maxBLen =
+ uArchInstruction->getSupportedK(bTy.getElementType()).back();
+ SmallVector<int> instDataA = {maxALen, subgroupSize};
+ SmallVector<int> instDataB = {subgroupSize, maxBLen};
+
+ propagateIfChanged(operands[0],
+ operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, instDataA)));
+ propagateIfChanged(operands[1],
+ operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, instDataB)));
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
- propagateIfChanged(
- operands[2],
- operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
+ const int maxCLen =
+ uArchInstruction->getSupportedN(bTy.getElementType()).back();
+ SmallVector<int> instDataC = {maxALen, maxCLen};
+ propagateIfChanged(operands[2],
+ operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, instDataC)));
}
}
/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
+
+ auto uArch = getUArch(getChipStr(store).value_or(""));
+ int subgroupSize = uArch->getSubgroupSize();
+ auto uArchInstruction =
+ std::static_pointer_cast<xegpu::uArch::StoreNdInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::STORE_ND));
+ int maxVecLength = uArchInstruction->getSortedLaneVectorLengths().back();
+ SmallVector<int> instData;
+ if (store.getValueType().getRank() == 1)
+ instData = {subgroupSize};
+ else
+ instData = {maxVecLength, subgroupSize};
----------------
charithaintc wrote:
Same comment as above: we need to check the instruction data against the provided store shape.
https://github.com/llvm/llvm-project/pull/163801
More information about the Mlir-commits mailing list