[Mlir-commits] [mlir] [MLIR][XeGPU] Improve `xegpu::uArch` design (PR #163986)
Artem Kroviakov
llvmlistbot at llvm.org
Sun Oct 26 02:27:35 PDT 2025
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/163986
>From 416eee13a5b76e69cce3d4f0f7436c99719ee5c8 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 17 Oct 2025 16:22:34 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU] Improve uArch design
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 205 +++++++++++-------
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 99 ++++-----
2 files changed, 173 insertions(+), 131 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..f264be5181b2a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -23,8 +23,6 @@
#include <map>
#include <string>
-#define DEBUG_TYPE "xegpu-uarch"
-
using namespace mlir;
using namespace mlir::xegpu::uArch;
@@ -33,21 +31,80 @@ namespace xegpu {
namespace uArch {
struct Xe2Plus : public uArch {
+ Xe2Plus(StringRef archName, StringRef archDescription,
+ llvm::ArrayRef<const Instruction *> instructionRegistry,
+ const XeCoreInfo &xeCore)
+ : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
+ int getSubgroupSize() const override { return 16; }
+ unsigned getPackedFormatBitSize() const override { return 16; }
+ unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
+
+protected:
XeCoreInfo xeCore;
- Xe2Plus(const std::string &archName, const std::string &archDescription,
- const XeCoreInfo &xeCore,
- const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
- const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instrs = {})
- : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
- xeCore(xeCore) {}
};
-// struct to represent DPAS instruction
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct StoreNdInstruction : public Instruction {
+ StoreNdInstruction()
+ : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+ // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+ // the specified pointer
+ llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
+ const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
+ return sortedLaneVectorLengths;
+ }
+};
+
+struct LoadNdInstruction : public Instruction {
+ LoadNdInstruction()
+ : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+ // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+ // the specified pointer.
+ llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
+ const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
+ return sortedLaneVectorLengths;
+ }
+};
+
+struct PrefetchNdInstruction : public Instruction {
+ PrefetchNdInstruction()
+ : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+ llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
+ const static int sortedNarrowTypesLengths[] = {1, 2, 4, 8, 16};
+ const static int sortedWideTypesLengths[] = {1, 2, 4, 8};
+ switch (elementBitwidth) {
+ case 8:
+ case 16:
+ return sortedNarrowTypesLengths;
+ case 32:
+ case 64:
+ return sortedWideTypesLengths;
+ default:
+ llvm_unreachable("Unsupported element bitwidth");
+ }
+ }
+};
+
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
- DPASInstruction()
- : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+ DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
+ unsigned packedFormatBitSizeC)
+ : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
+ packedFormatBitSizeA(packedFormatBitSizeA),
+ packedFormatBitSizeB(packedFormatBitSizeB),
+ packedFormatBitSizeC(packedFormatBitSizeC) {}
+ // Source:
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -67,82 +124,82 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedM(Type type) const override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedK(Type type) const override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedN(Type type) const override;
+
+ unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
+ unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
+ unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }
+
+protected:
+ const unsigned packedFormatBitSizeA;
+ const unsigned packedFormatBitSizeB;
+ const unsigned packedFormatBitSizeC;
};
-struct PVCuArch : public Xe2Plus {
- // Maintaines ownership of the instructions owned by PVUarch
- llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+//===----------------------------------------------------------------------===//
+// uArch instances
+//===----------------------------------------------------------------------===//
+
+struct PVCuArch final : public Xe2Plus {
+ inline static const DPASInstruction dpasInst{16, 32, 32};
+ inline static const StoreNdInstruction loadNdInst;
+ inline static const StoreNdInstruction storeNdInst;
+ inline static const PrefetchNdInstruction prefetchNdInst;
+ inline static const Instruction *const instructionRegistryArr[] = {
+ &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
PVCuArch()
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
- XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
- {/* registerFileInfo */}, // Optional: empty
- {/* cacheInfo */}, // Optional: empty
- {/* instructions */} // Optional: empty
- ) {
- // Intialize register file info
- // GRF
- this->registerFileInfo.emplace(
- RegisterFileType::GRF,
- RegisterFileInfo(
- 64 * 1024, // size in bits
- {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
- {128, 256} // registers per thread per mode
- ));
- // Initialize cache info
- // L1 cache, XeCore level
- this->cacheInfo.push_back(
- CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
- // L2 cache, XeStack level
- this->cacheInfo.push_back(
- CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
-
- // Add the instructions-
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+ instructionRegistryArr,
+ XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
+ ) {}
+ static const uArch *getInstance() { // Accessor for the single PVC uArch object.
+ static const PVCuArch instance; // Function-local static, built on first call.
+ return &instance; // Implicit derived-to-base conversion; no cast is needed.
}
};
struct BMGuArch : public Xe2Plus {
- // Maintaines ownership of the instructions owned by PVUarch
- llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+ inline static const DPASInstruction dpasInst{16, 32, 32};
+ inline static const StoreNdInstruction loadNdInst;
+ inline static const StoreNdInstruction storeNdInst;
+ inline static const PrefetchNdInstruction prefetchNdInst;
+ inline static const Instruction *const instructionRegistryArr[] = {
+ &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
BMGuArch()
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
- XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
- {/* registerFileInfo */}, // Optional: empty
- {/* cacheInfo */}, // Optional: empty
- {/* instructions */} // Optional: empty
- ) {
- // Intialize register file info
- // GRF
- this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
- 64 * 1024, // size in bits
- {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
- {128, 256} // registers per thread per mode
- );
- // Initialize cache info
- // L1 cache, XeCore level
- this->cacheInfo.push_back(
- CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
- // L2 cache, XeStack level
- this->cacheInfo.push_back(
- CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
-
- // Add the instructions
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+ instructionRegistryArr,
+ XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
+ ) {}
+ static const uArch *getInstance() { // Accessor for the single BMG uArch object.
+ static const BMGuArch instance; // Function-local static, built on first call.
+ return &instance; // Implicit derived-to-base conversion; no cast is needed.
}
};
+
+inline const uArch *getUArch(llvm::StringRef archName) {
+ if (archName.equals_insensitive("pvc"))
+ return PVCuArch::getInstance();
+ else if (archName.equals_insensitive("bmg"))
+ return BMGuArch::getInstance();
+
+ return nullptr;
+}
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
@@ -257,12 +314,12 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedM(Type type) {
+DPASInstruction::getSupportedM(Type type) const {
return {1, 2, 3, 4, 5, 6, 7, 8};
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedK(Type type) {
+DPASInstruction::getSupportedK(Type type) const {
// assert if data type is not int or float type
assert(type.isIntOrFloat() && "Matrix type must be int or float");
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -290,7 +347,7 @@ DPASInstruction::getSupportedK(Type type) {
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedN(Type type) {
+DPASInstruction::getSupportedN(Type type) const {
return {16};
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..f5844d62f374c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
- DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
- // multiply-add operation
+ DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
+ // multiply-add operation
+ STORE_ND, // Subgroup-level 2D block write instruction
+ LOAD_ND, // Subgroup-level 2D block load instruction
+ PREFETCH_ND // Subgroup-level 2D block prefetch instruction
// @TODO: Add more instructions as needed
};
@@ -48,12 +51,18 @@ struct Instruction {
virtual ~Instruction() = default;
// Get methods
- InstructionKind getInstructionKind() { return instKind; }
- InstructionScope getScope() { return scope; }
+ InstructionKind getInstructionKind() const { return instKind; }
+ InstructionScope getScope() const { return scope; }
static llvm::StringRef toString(InstructionKind instKind) {
switch (instKind) {
case InstructionKind::DPAS:
return "dpas";
+ case InstructionKind::STORE_ND:
+ return "store_nd";
+ case InstructionKind::LOAD_ND:
+ return "load_nd";
+ case InstructionKind::PREFETCH_ND:
+ return "prefetch_nd";
}
llvm_unreachable("Unknown InstructionKind");
}
@@ -66,9 +75,9 @@ struct Instruction {
}
protected:
- InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
- InstructionScope scope; // scope of the instruction (e.g., lane, subgroup,
- // workgroup, cluster)
+ const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+ const InstructionScope scope; // scope of the instruction (e.g., lane,
+ // subgroup, workgroup, cluster)
// @TODO: Add more fields as needed
};
@@ -129,61 +138,37 @@ struct CacheInfo {
// latency, throughput, bandwidth)
};
-// A struct to represent the uArch
-// This struct is used to represent the microarchitecture of a target device.
struct uArch {
// Constructor
- uArch(
- const std::string &name, const std::string &description,
- const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
- const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instructions = {})
- : name(name), description(description),
- registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
- instructions(instructions) {}
-
- // Get methods
- const std::string &getName() const { return name; }
-
- const std::string &getDescription() const { return description; }
-
- const std::map<RegisterFileType, RegisterFileInfo> &
- getRegisterFileInfo() const {
- return registerFileInfo;
- }
-
- const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
- return cacheInfo;
- }
-
- const std::map<InstructionKind, std::shared_ptr<Instruction>> &
- getInstructions() const {
- return instructions;
+ uArch(StringRef name, StringRef description,
+ llvm::ArrayRef<const Instruction *> instructionRegistry)
+ : name(name), description(description) {
+ for (const Instruction *instr : instructionRegistry)
+ this->instructionRegistry[instr->getInstructionKind()] = instr;
}
-
- // Get the name of the supported instruction names for that
- // architecture. It returns the names of the instructions added to the uArch.
- llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
- llvm::SmallVector<StringRef, 8> instructionNames;
- for (const auto &inst : instructions) {
- instructionNames.push_back(Instruction::toString(inst.first));
- }
- return instructionNames;
+ virtual ~uArch() = default;
+ StringRef getName() const { return name; }
+ StringRef getDescription() const { return description; }
+ virtual int getSubgroupSize() const = 0;
+ virtual unsigned getPackedFormatBitSize() const = 0;
+ virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
+
+ const Instruction *getInstruction(InstructionKind instKind) const {
+ auto it = instructionRegistry.find(instKind);
+ assert(it != instructionRegistry.end() &&
+ "Instruction not found in registry");
+ return it->second;
}
- // Checks if an instruction is supported in this uArch
- bool checkSupportedInstruction(InstructionKind instr) const {
- return instructions.find(instr) != instructions.end();
+ bool isSupportedInstruction(InstructionKind instr) const {
+ return instructionRegistry.contains(instr);
}
protected:
- std::string name; // Name of the uArch, similar to target triple
- std::string description;
- std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
- llvm::SmallVector<CacheInfo, 4> cacheInfo;
- std::map<InstructionKind, std::shared_ptr<Instruction>>
- instructions; // set of instructions supported by the uArch
+ StringRef name;
+ StringRef description;
+ llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
+ instructionRegistry;
};
// A struct to represent shared memory information
@@ -251,9 +236,9 @@ struct MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) const = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) const = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) const = 0;
virtual ~MMAInstructionInterface() = default;
};
>From 92c80cf7c9f995b7e9bd424304c94e6e1bff3453 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 17 Oct 2025 16:34:50 +0000
Subject: [PATCH 2/5] Fix warning
---
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index f5844d62f374c..82a5223c43651 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -49,7 +49,7 @@ struct Instruction {
Instruction(InstructionKind kind, InstructionScope scope)
: instKind(kind), scope(scope) {}
- virtual ~Instruction() = default;
+ ~Instruction() = default;
// Get methods
InstructionKind getInstructionKind() const { return instKind; }
InstructionScope getScope() const { return scope; }
>From 1bf2b652676a702aeb81dcc6f0443096bc834ee2 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 22 Oct 2025 12:59:53 +0000
Subject: [PATCH 3/5] Fix warnings
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 51 ++++++++++++-------
1 file changed, 34 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index f264be5181b2a..f00644b9dea09 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -49,7 +49,9 @@ struct Xe2Plus : public uArch {
struct StoreNdInstruction : public Instruction {
StoreNdInstruction()
: Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
-
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::STORE_ND;
+ }
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
// Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
@@ -63,7 +65,9 @@ struct StoreNdInstruction : public Instruction {
struct LoadNdInstruction : public Instruction {
LoadNdInstruction()
: Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
-
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::LOAD_ND;
+ }
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
// Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
@@ -77,7 +81,9 @@ struct LoadNdInstruction : public Instruction {
struct PrefetchNdInstruction : public Instruction {
PrefetchNdInstruction()
: Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
-
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::PREFETCH_ND;
+ }
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
@@ -103,6 +109,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
packedFormatBitSizeA(packedFormatBitSizeA),
packedFormatBitSizeB(packedFormatBitSizeB),
packedFormatBitSizeC(packedFormatBitSizeC) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::DPAS;
+ }
// Source:
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
@@ -146,16 +155,20 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
//===----------------------------------------------------------------------===//
struct PVCuArch final : public Xe2Plus {
- inline static const DPASInstruction dpasInst{16, 32, 32};
- inline static const StoreNdInstruction loadNdInst;
- inline static const StoreNdInstruction storeNdInst;
- inline static const PrefetchNdInstruction prefetchNdInst;
- inline static const Instruction *const instructionRegistryArr[] = {
- &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
+ static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
+ static const DPASInstruction dpasInst{16, 32, 32};
+ static const StoreNdInstruction loadNdInst;
+ static const StoreNdInstruction storeNdInst;
+ static const PrefetchNdInstruction prefetchNdInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
+ &prefetchNdInst};
+ return arr;
+ }
+
PVCuArch()
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
- instructionRegistryArr,
+ getInstructionRegistryArr(),
XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
) {}
static const uArch *getInstance() {
@@ -165,16 +178,20 @@ struct PVCuArch final : public Xe2Plus {
};
struct BMGuArch : public Xe2Plus {
- inline static const DPASInstruction dpasInst{16, 32, 32};
- inline static const StoreNdInstruction loadNdInst;
- inline static const StoreNdInstruction storeNdInst;
- inline static const PrefetchNdInstruction prefetchNdInst;
- inline static const Instruction *const instructionRegistryArr[] = {
- &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
+ static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
+ static const DPASInstruction dpasInst{16, 32, 32};
+ static const StoreNdInstruction loadNdInst;
+ static const StoreNdInstruction storeNdInst;
+ static const PrefetchNdInstruction prefetchNdInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
+ &prefetchNdInst};
+ return arr;
+ }
+
BMGuArch()
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
- instructionRegistryArr,
+ getInstructionRegistryArr(),
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
) {}
static const uArch *getInstance() {
>From e834a9fb0690c4c25e1c2d7e3b4f08df166eec6a Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sun, 26 Oct 2025 08:54:49 +0000
Subject: [PATCH 4/5] Update instructions
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 116 +++++++++++++-----
1 file changed, 87 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index f00644b9dea09..7e9e714c6007a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -53,12 +53,26 @@ struct StoreNdInstruction : public Instruction {
return B->getInstructionKind() == InstructionKind::STORE_ND;
}
// Source :
- // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
// Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
// the specified pointer
- llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
- const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
- return sortedLaneVectorLengths;
+ std::optional< // Result is (widths, heights, counts), or nullopt if unsupported.
+ std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+ getBlockWidthHeightCount(Type elemTy) const {
+ const static int kHeight[] = {1, 2, 4, 8}; // Block heights, shared by all sizes.
+ const static int kWidth16[] = {16}; // Block width for 2- and 4-byte elements.
+ const static int kWidth32[] = {32}; // Block width for 1-byte elements.
+ const static int kCount[] = {1}; // Only a block count of 1 is listed.
+ const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8; // Bits to bytes.
+ if (elemByteSize == 1)
+ return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
+ llvm::ArrayRef<int>(kHeight),
+ llvm::ArrayRef<int>(kCount));
+ else if (elemByteSize == 2 || elemByteSize == 4)
+ return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
+ llvm::ArrayRef<int>(kHeight),
+ llvm::ArrayRef<int>(kCount));
+ return std::nullopt; // Any other element size (including sub-byte) is rejected.
}
};
@@ -68,13 +82,49 @@ struct LoadNdInstruction : public Instruction {
static bool classof(const Instruction *B) {
return B->getInstructionKind() == InstructionKind::LOAD_ND;
}
+
// Source :
- // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
// Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
// the specified pointer.
- llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
- const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
- return sortedLaneVectorLengths;
+ std::optional<
+ std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+ getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
+ bool upConv = false) const {
+ static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
+ static const int kHeightAtLeast8[] = {8, 16, 32};
+ static const int kHeightAtLeast16[] = {16, 32};
+ static const int kHeightAtLeast32[] = {32};
+
+ static const int kWidth32[] = {32};
+ static const int kWidth16[] = {16};
+ static const int kWidth8[] = {8};
+
+ static const int32_t kCount1[] = {1};
+ static const int32_t kCount2[] = {1, 2};
+ static const int32_t kCount4[] = {1, 2, 4};
+ static const int32_t kCount4Only[] = {4};
+ // (elemBytes, transform, transpose, upConvert)
+ using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
+ // (widths, heights, counts)
+ using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
+ llvm::ArrayRef<int32_t>>;
+ static const llvm::DenseMap<Key, Value> kMap = {
+ {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
+ {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
+ {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
+ {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
+ // Block Loads with Transform:
+ {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
+ {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
+ // Block Loads with Transpose:
+ {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
+ };
+ const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+ auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
+ if (it != kMap.end())
+ return it->second;
+ return std::nullopt;
}
};
@@ -86,29 +136,39 @@ struct PrefetchNdInstruction : public Instruction {
}
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
- llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
- const static int sortedNarrowTypesLengths[] = {1, 2, 4, 8, 16};
- const static int sortedWideTypesLengths[] = {1, 2, 4, 8};
- switch (elementBitwidth) {
- case 8:
- case 16:
- return sortedNarrowTypesLengths;
- case 32:
- case 64:
- return sortedWideTypesLengths;
- default:
- llvm_unreachable("Unsupported element bitwidth");
- }
+ std::optional<
+ std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+ getBlockWidthHeightCount(Type elemTy) const {
+ static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
+
+ static const int kWidth32[] = {32};
+ static const int kWidth16[] = {16};
+
+ static const int32_t kCount1[] = {1};
+ static const int32_t kCount2[] = {1, 2};
+ // elemBytes
+ using Key = int;
+ // (widths, heights, counts)
+ using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
+ llvm::ArrayRef<int32_t>>;
+ static const llvm::DenseMap<Key, Value> kMap = {
+ {1, {kWidth32, kHeightAtLeast1, kCount2}},
+ {2, {kWidth16, kHeightAtLeast1, kCount2}},
+ {4, {kWidth16, kHeightAtLeast1, kCount1}},
+ };
+ const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+ auto it = kMap.find(elemByteSize);
+ if (it != kMap.end())
+ return it->second;
+ return std::nullopt;
}
};
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
- DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
- unsigned packedFormatBitSizeC)
+ DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
packedFormatBitSizeA(packedFormatBitSizeA),
- packedFormatBitSizeB(packedFormatBitSizeB),
- packedFormatBitSizeC(packedFormatBitSizeC) {}
+ packedFormatBitSizeB(packedFormatBitSizeB) {}
static bool classof(const Instruction *B) {
return B->getInstructionKind() == InstructionKind::DPAS;
}
@@ -142,12 +202,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
- unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }
protected:
const unsigned packedFormatBitSizeA;
const unsigned packedFormatBitSizeB;
- const unsigned packedFormatBitSizeC;
};
//===----------------------------------------------------------------------===//
@@ -156,7 +214,7 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
struct PVCuArch final : public Xe2Plus {
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
- static const DPASInstruction dpasInst{16, 32, 32};
+ static const DPASInstruction dpasInst{16, 32};
static const StoreNdInstruction loadNdInst;
static const StoreNdInstruction storeNdInst;
static const PrefetchNdInstruction prefetchNdInst;
@@ -179,7 +237,7 @@ struct PVCuArch final : public Xe2Plus {
struct BMGuArch : public Xe2Plus {
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
- static const DPASInstruction dpasInst{16, 32, 32};
+ static const DPASInstruction dpasInst{16, 32};
static const StoreNdInstruction loadNdInst;
static const StoreNdInstruction storeNdInst;
static const PrefetchNdInstruction prefetchNdInst;
>From 32a2531b73a6e9f79693c913b444e51040ab8e38 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sun, 26 Oct 2025 09:27:18 +0000
Subject: [PATCH 5/5] Address feedback
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 98 ++++++++++---------
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 23 +++--
2 files changed, 64 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 7e9e714c6007a..dcb2ad5d67a25 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -36,8 +36,7 @@ struct Xe2Plus : public uArch {
const XeCoreInfo &xeCore)
: uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
int getSubgroupSize() const override { return 16; }
- unsigned getPackedFormatBitSize() const override { return 16; }
- unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
+ unsigned getGeneralPackedFormatBitSize() const override { return 32; }
protected:
XeCoreInfo xeCore;
@@ -46,16 +45,15 @@ struct Xe2Plus : public uArch {
//===----------------------------------------------------------------------===//
// uArch instructions
//===----------------------------------------------------------------------===//
-struct StoreNdInstruction : public Instruction {
- StoreNdInstruction()
- : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+struct Subgroup2DBlockStoreInstruction : public Instruction {
+ Subgroup2DBlockStoreInstruction()
+ : Instruction(InstructionKind::Subgroup2DBlockStore,
+ InstructionScope::Subgroup) {}
static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::STORE_ND;
+ return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
}
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
- // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
- // the specified pointer
std::optional<
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
getBlockWidthHeightCount(Type elemTy) const {
@@ -74,19 +72,20 @@ struct StoreNdInstruction : public Instruction {
llvm::ArrayRef<int>(kCount));
return std::nullopt;
}
+
+ int32_t getPackedFormatBitSize() const { return 16; }
};
-struct LoadNdInstruction : public Instruction {
- LoadNdInstruction()
- : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+struct Subgroup2DBlockLoadInstruction : public Instruction {
+ Subgroup2DBlockLoadInstruction()
+ : Instruction(InstructionKind::Subgroup2DBlockLoad,
+ InstructionScope::Subgroup) {}
static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::LOAD_ND;
+ return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
}
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
- // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
- // the specified pointer.
std::optional<
std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
@@ -126,13 +125,16 @@ struct LoadNdInstruction : public Instruction {
return it->second;
return std::nullopt;
}
+
+ int32_t getPackedFormatBitSize() const { return 16; }
};
-struct PrefetchNdInstruction : public Instruction {
- PrefetchNdInstruction()
- : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+struct Subgroup2DBlockPrefetchInstruction : public Instruction {
+ Subgroup2DBlockPrefetchInstruction()
+ : Instruction(InstructionKind::Subgroup2DBlockPrefetch,
+ InstructionScope::Subgroup) {}
static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::PREFETCH_ND;
+ return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
}
// Source :
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
@@ -162,15 +164,20 @@ struct PrefetchNdInstruction : public Instruction {
return it->second;
return std::nullopt;
}
+ int32_t getPackedFormatBitSize() const { return 16; }
};
-struct DPASInstruction : public Instruction, public MMAInstructionInterface {
- DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
- : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
+struct SubgroupMatrixMultiplyAcc : public Instruction,
+ public MMAInstructionInterface {
+ SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
+ unsigned packedFormatBitSizeB)
+ : Instruction(InstructionKind::SubgroupMatrixMultiplyAcc,
+ InstructionScope::Subgroup),
packedFormatBitSizeA(packedFormatBitSizeA),
packedFormatBitSizeB(packedFormatBitSizeB) {}
static bool classof(const Instruction *B) {
- return B->getInstructionKind() == InstructionKind::DPAS;
+ return B->getInstructionKind() ==
+ InstructionKind::SubgroupMatrixMultiplyAcc;
}
// Source:
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
@@ -214,10 +221,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
struct PVCuArch final : public Xe2Plus {
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
- static const DPASInstruction dpasInst{16, 32};
- static const StoreNdInstruction loadNdInst;
- static const StoreNdInstruction storeNdInst;
- static const PrefetchNdInstruction prefetchNdInst;
+ static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+ static const Subgroup2DBlockLoadInstruction loadNdInst;
+ static const Subgroup2DBlockStoreInstruction storeNdInst;
+ static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
&prefetchNdInst};
return arr;
@@ -237,10 +244,10 @@ struct PVCuArch final : public Xe2Plus {
struct BMGuArch : public Xe2Plus {
static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
- static const DPASInstruction dpasInst{16, 32};
- static const StoreNdInstruction loadNdInst;
- static const StoreNdInstruction storeNdInst;
- static const PrefetchNdInstruction prefetchNdInst;
+ static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+ static const Subgroup2DBlockLoadInstruction loadNdInst;
+ static const Subgroup2DBlockStoreInstruction storeNdInst;
+ static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
&prefetchNdInst};
return arr;
@@ -276,7 +283,8 @@ inline const uArch *getUArch(llvm::StringRef archName) {
//===----------------------------------------------------------------------===//
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
-DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+SubgroupMatrixMultiplyAcc::getSupportedShapes(Type dataType,
+ MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
const llvm::SmallVector<uint32_t, 8> &b)
-> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
@@ -312,8 +320,8 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
}
inline llvm::SmallVector<Type, 8>
-DPASInstruction::getSupportedTypes(MLIRContext &context,
- MMAOpndKind matrixType) {
+SubgroupMatrixMultiplyAcc::getSupportedTypes(MLIRContext &context,
+ MMAOpndKind matrixType) {
Type bf16Type = BFloat16Type::get(&context);
Type f16Type = Float16Type::get(&context);
Type tf32Type = FloatTF32Type::get(&context);
@@ -332,8 +340,10 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
return {};
}
-inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
- Type CType, Type DType) {
+inline bool SubgroupMatrixMultiplyAcc::checkSupportedTypes(Type AType,
+ Type BType,
+ Type CType,
+ Type DType) {
if (AType.isF16() || BType.isF16()) {
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
(!DType.isF32() && !DType.isF16())) {
@@ -363,7 +373,7 @@ inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
return true;
}
-inline bool DPASInstruction::checkSupportedShapesAndTypes(
+inline bool SubgroupMatrixMultiplyAcc::checkSupportedShapesAndTypes(
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
Type AType, Type BType, Type CType, Type DType) {
@@ -378,23 +388,21 @@ inline bool DPASInstruction::checkSupportedShapesAndTypes(
checkSupportedTypes(AType, BType, CType, DType);
}
-inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
- std::pair<uint32_t, uint32_t> BShape,
- std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape,
- Type AType, Type BType, Type CType,
- Type DType) {
+inline bool SubgroupMatrixMultiplyAcc::validate(
+ std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+ Type AType, Type BType, Type CType, Type DType) {
return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
BType, CType, DType);
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedM(Type type) const {
+SubgroupMatrixMultiplyAcc::getSupportedM(Type type) const {
return {1, 2, 3, 4, 5, 6, 7, 8};
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedK(Type type) const {
+SubgroupMatrixMultiplyAcc::getSupportedK(Type type) const {
// assert if data type is not int or float type
assert(type.isIntOrFloat() && "Matrix type must be int or float");
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -422,7 +430,7 @@ DPASInstruction::getSupportedK(Type type) const {
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedN(Type type) const {
+SubgroupMatrixMultiplyAcc::getSupportedN(Type type) const {
return {16};
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 82a5223c43651..ea33e885c78ff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,11 +32,11 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
- DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
- // multiply-add operation
- STORE_ND, // Subgroup-level 2D block write instruction
- LOAD_ND, // Subgroup-level 2D block load instruction
- PREFETCH_ND // Subgroup-level 2D block prefetch instruction
+ SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS) is a
+ // matrix multiply-add operation
+ Subgroup2DBlockStore, // Subgroup-level 2D block write instruction
+ Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
+ Subgroup2DBlockPrefetch // Subgroup-level 2D block prefetch instruction
// @TODO: Add more instructions as needed
};
@@ -55,13 +55,13 @@ struct Instruction {
InstructionScope getScope() const { return scope; }
static llvm::StringRef toString(InstructionKind instKind) {
switch (instKind) {
- case InstructionKind::DPAS:
+ case InstructionKind::SubgroupMatrixMultiplyAcc:
return "dpas";
- case InstructionKind::STORE_ND:
+ case InstructionKind::Subgroup2DBlockStore:
return "store_nd";
- case InstructionKind::LOAD_ND:
+ case InstructionKind::Subgroup2DBlockLoad:
return "load_nd";
- case InstructionKind::PREFETCH_ND:
+ case InstructionKind::Subgroup2DBlockPrefetch:
return "prefetch_nd";
}
llvm_unreachable("Unknown InstructionKind");
@@ -70,7 +70,7 @@ struct Instruction {
static std::optional<InstructionKind>
parseInstructionKind(llvm::StringRef str) {
if (str.equals_insensitive("dpas"))
- return InstructionKind::DPAS;
+ return InstructionKind::SubgroupMatrixMultiplyAcc;
return std::nullopt;
}
@@ -150,8 +150,7 @@ struct uArch {
StringRef getName() const { return name; }
StringRef getDescription() const { return description; }
virtual int getSubgroupSize() const = 0;
- virtual unsigned getPackedFormatBitSize() const = 0;
- virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
+ virtual unsigned getGeneralPackedFormatBitSize() const = 0;
const Instruction *getInstruction(InstructionKind instKind) const {
auto it = instructionRegistry.find(instKind);
More information about the Mlir-commits
mailing list