[Mlir-commits] [mlir] 56d3169 - [MLIR][XeGPU] Reorganize uArch for easier extension (#178907)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 3 00:47:17 PST 2026
Author: Artem Kroviakov
Date: 2026-02-03T09:47:11+01:00
New Revision: 56d31697c2a2c38844bdc0f43e537c29a5115d87
URL: https://github.com/llvm/llvm-project/commit/56d31697c2a2c38844bdc0f43e537c29a5115d87
DIFF: https://github.com/llvm/llvm-project/commit/56d31697c2a2c38844bdc0f43e537c29a5115d87.diff
LOG: [MLIR][XeGPU] Reorganize uArch for easier extension (#178907)
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 29e75b57f4a5f..05b4dbdbb0317 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,26 +215,16 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
const unsigned packedFormatBitSizeB;
};
-struct StoreScatterInstruction : public Instruction {
- StoreScatterInstruction()
- : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
- static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::StoreScatter;
+struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
+ int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
+ return 16;
}
-
- // SPIRV restricts vector size
- int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
-struct LoadGatherInstruction : public Instruction {
- LoadGatherInstruction()
- : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
- static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::LoadGather;
+struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
+ int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
+ return 16;
}
-
- // SPIRV restricts vector size
- int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
//===----------------------------------------------------------------------===//
@@ -247,8 +237,8 @@ struct PVCuArch final : public Xe2Plus {
static const Subgroup2DBlockLoadInstruction loadNdInst;
static const Subgroup2DBlockStoreInstruction storeNdInst;
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
- static const StoreScatterInstruction storeScatterInst;
- static const LoadGatherInstruction loadGatherInst;
+ static const SpirvStoreScatterInstruction storeScatterInst;
+ static const SpirvLoadGatherInstruction loadGatherInst;
static const Instruction *arr[] = {&dpasInst, &loadNdInst,
&storeNdInst, &prefetchNdInst,
&storeScatterInst, &loadGatherInst};
@@ -273,8 +263,8 @@ struct BMGuArch : public Xe2Plus {
static const Subgroup2DBlockLoadInstruction loadNdInst;
static const Subgroup2DBlockStoreInstruction storeNdInst;
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
- static const StoreScatterInstruction storeScatterInst;
- static const LoadGatherInstruction loadGatherInst;
+ static const SpirvStoreScatterInstruction storeScatterInst;
+ static const SpirvLoadGatherInstruction loadGatherInst;
static const Instruction *arr[] = {&dpasInst, &loadNdInst,
&storeNdInst, &prefetchNdInst,
&storeScatterInst, &loadGatherInst};
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index db1984b2edb1d..ee3d5a5a8c398 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -250,6 +250,32 @@ struct MMAInstructionInterface {
virtual ~MMAInstructionInterface() = default;
};
+//===----------------------------------------------------------------------===//
+// Common instructions (shared across architectures)
+//===----------------------------------------------------------------------===//
+
+struct StoreScatterInstructionInterface : public Instruction {
+ StoreScatterInstructionInterface()
+ : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::StoreScatter;
+ }
+
+ virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
+ virtual ~StoreScatterInstructionInterface() = default;
+};
+
+struct LoadGatherInstructionInterface : public Instruction {
+ LoadGatherInstructionInterface()
+ : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::LoadGather;
+ }
+
+ virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
+ virtual ~LoadGatherInstructionInterface() = default;
+};
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 96fdced39d9ab..faafb7e8cee61 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1168,7 +1168,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
return;
}
const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+ dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
// Check if value inst_data complies with uArch
@@ -1181,9 +1181,10 @@ void LayoutInfoPropagation::visitLoadGatherOp(
load.emitWarning("Expected 2D payload for LoadGatherOp.");
return;
}
- instDataUarch.push_back(
- (std::min(static_cast<int>(payloadTy.getShape().back()),
- uArchInstruction->getMaxLaneLoadStoreSize())));
+ int elemBitWidth = payloadTy.getElementTypeBitWidth();
+ instDataUarch.push_back((
+ std::min(static_cast<int>(payloadTy.getShape().back()),
+ uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
}
// If inst data does not match, enforce the uArch-based one
if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1274,8 +1275,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
if (layoutKind == xegpu::LayoutKind::InstData) {
const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
- xegpu::uArch::InstructionKind::StoreScatter));
+ dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::StoreScatter));
const int subgroupSize = uArch->getSubgroupSize();
SmallVector<int> instDataUarch{subgroupSize};
if (payloadTy.getRank() != 1) {
@@ -1283,9 +1285,10 @@ void LayoutInfoPropagation::visitStoreScatterOp(
storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
return;
}
- instDataUarch.push_back(
- (std::min(static_cast<int>(payloadTy.getShape().back()),
- uArchInstruction->getMaxLaneLoadStoreSize())));
+ int elemBitWidth = payloadTy.getElementTypeBitWidth();
+ instDataUarch.push_back((
+ std::min(static_cast<int>(payloadTy.getShape().back()),
+ uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
}
payloadLayout = LayoutInfo(
xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
More information about the Mlir-commits mailing list