[Mlir-commits] [mlir] [MLIR][XeGPU] Improve `xegpu::uArch` design (PR #163986)

Artem Kroviakov llvmlistbot at llvm.org
Sun Oct 26 02:27:35 PDT 2025


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/163986

>From 416eee13a5b76e69cce3d4f0f7436c99719ee5c8 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 17 Oct 2025 16:22:34 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU] Improve uArch design

---
Review notes (placed after the "---" line, so `git am` does not put them
in the commit message):
- IntelGpuXe2.h, PVCuArch and BMGuArch: `loadNdInst` is declared as
  `StoreNdInstruction` rather than `LoadNdInstruction`. The new uArch
  constructor keys `instructionRegistry` by `getInstructionKind()`, so
  `loadNdInst` and `storeNdInst` both land on InstructionKind::STORE_ND and
  nothing is registered for LOAD_ND: `isSupportedInstruction(LOAD_ND)`
  returns false and `getInstruction(LOAD_ND)` fails its assert.
- uArchBase.h, `uArch::getInstruction()`: the miss case is only guarded by
  `assert`; in NDEBUG builds a missing kind dereferences `end()`. Returning
  nullptr on a miss (callers can pair it with `isSupportedInstruction`)
  would make release builds well-defined.
- `getInstance()` converts `&instance` to `const uArch *` with
  `reinterpret_cast`. That is a derived-to-base conversion, which is
  implicit (`return &instance;`); reinterpret_cast performs no base-offset
  adjustment and is only correct by accident of layout.
- `uArch` now keeps `name`/`description` as non-owning `StringRef`s, so the
  constructor arguments must outlive the object (string literals in
  PVCuArch/BMGuArch, which is fine today); worth a comment on the
  constructor.
- The register-file and cache descriptions, their getters and
  `getSupportedInstructionNames()` are removed, and
  `checkSupportedInstruction()` is renamed to `isSupportedInstruction()`;
  please confirm no in-tree caller still uses the old API.
- Nits: `PVCuArch` is `final` but `BMGuArch` is not; `getUArch()` uses
  `else if` after a `return` (LLVM coding standards: no else after return).

 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 205 +++++++++++-------
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      |  99 ++++-----
 2 files changed, 173 insertions(+), 131 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..f264be5181b2a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -23,8 +23,6 @@
 #include <map>
 #include <string>
 
-#define DEBUG_TYPE "xegpu-uarch"
-
 using namespace mlir;
 using namespace mlir::xegpu::uArch;
 
@@ -33,21 +31,80 @@ namespace xegpu {
 namespace uArch {
 
 struct Xe2Plus : public uArch {
+  Xe2Plus(StringRef archName, StringRef archDescription,
+          llvm::ArrayRef<const Instruction *> instructionRegistry,
+          const XeCoreInfo &xeCore)
+      : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
+  int getSubgroupSize() const override { return 16; }
+  unsigned getPackedFormatBitSize() const override { return 16; }
+  unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
+
+protected:
   XeCoreInfo xeCore;
-  Xe2Plus(const std::string &archName, const std::string &archDescription,
-          const XeCoreInfo &xeCore,
-          const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
-          const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
-          const std::map<InstructionKind, std::shared_ptr<Instruction>>
-              &instrs = {})
-      : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
-        xeCore(xeCore) {}
 };
 
-// struct to represent DPAS instruction
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct StoreNdInstruction : public Instruction {
+  StoreNdInstruction()
+      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+  // the specified pointer
+  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
+    const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
+    return sortedLaneVectorLengths;
+  }
+};
+
+struct LoadNdInstruction : public Instruction {
+  LoadNdInstruction()
+      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+  // the specified pointer.
+  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
+    const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
+    return sortedLaneVectorLengths;
+  }
+};
+
+struct PrefetchNdInstruction : public Instruction {
+  PrefetchNdInstruction()
+      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+  llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
+    const static int sortedNarrowTypesLengths[] = {1, 2, 4, 8, 16};
+    const static int sortedWideTypesLengths[] = {1, 2, 4, 8};
+    switch (elementBitwidth) {
+    case 8:
+    case 16:
+      return sortedNarrowTypesLengths;
+    case 32:
+    case 64:
+      return sortedWideTypesLengths;
+    default:
+      llvm_unreachable("Unsupported element bitwidth");
+    }
+  }
+};
+
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
-  DPASInstruction()
-      : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+  DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
+                  unsigned packedFormatBitSizeC)
+      : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
+        packedFormatBitSizeA(packedFormatBitSizeA),
+        packedFormatBitSizeB(packedFormatBitSizeB),
+        packedFormatBitSizeC(packedFormatBitSizeC) {}
+  // Source:
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
 
   // Override all virtuals from MatrixOpInterface
   virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -67,82 +124,82 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
                          std::pair<uint32_t, uint32_t> CShape,
                          std::pair<uint32_t, uint32_t> DShape, Type AType,
                          Type BType, Type CType, Type DType) override;
-  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
-  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
-  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
+  virtual llvm::SmallVector<uint32_t, 8>
+  getSupportedM(Type type) const override;
+  virtual llvm::SmallVector<uint32_t, 8>
+  getSupportedK(Type type) const override;
+  virtual llvm::SmallVector<uint32_t, 8>
+  getSupportedN(Type type) const override;
+
+  unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
+  unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
+  unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }
+
+protected:
+  const unsigned packedFormatBitSizeA;
+  const unsigned packedFormatBitSizeB;
+  const unsigned packedFormatBitSizeC;
 };
 
-struct PVCuArch : public Xe2Plus {
-  // Maintaines ownership of the instructions owned by PVUarch
-  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+//===----------------------------------------------------------------------===//
+// uArch instances
+//===----------------------------------------------------------------------===//
+
+struct PVCuArch final : public Xe2Plus {
+  inline static const DPASInstruction dpasInst{16, 32, 32};
+  inline static const StoreNdInstruction loadNdInst;
+  inline static const StoreNdInstruction storeNdInst;
+  inline static const PrefetchNdInstruction prefetchNdInst;
+  inline static const Instruction *const instructionRegistryArr[] = {
+      &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
   PVCuArch()
       : Xe2Plus("pvc",                        // archName
                 "Ponte Vecchio Architecture", // archDescription
-                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
-                {/* registerFileInfo */}, // Optional: empty
-                {/* cacheInfo */},        // Optional: empty
-                {/* instructions */}      // Optional: empty
-        ) {
-    // Intialize register file info
-    // GRF
-    this->registerFileInfo.emplace(
-        RegisterFileType::GRF,
-        RegisterFileInfo(
-            64 * 1024,                                          // size in bits
-            {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
-            {128, 256} // registers per thread per mode
-            ));
-    // Initialize cache info
-    // L1 cache, XeCore level
-    this->cacheInfo.push_back(
-        CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
-    // L2 cache, XeStack level
-    this->cacheInfo.push_back(
-        CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
-
-    // Add the instructions-
-    auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getInstructionKind(), dpas);
-    owned_instructions.push_back(dpas);
+                instructionRegistryArr,
+                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
+        ) {}
+  static const uArch *getInstance() {
+    static const PVCuArch instance;
+    return reinterpret_cast<const uArch *>(&instance);
   }
 };
 
 struct BMGuArch : public Xe2Plus {
-  // Maintaines ownership of the instructions owned by PVUarch
-  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+  inline static const DPASInstruction dpasInst{16, 32, 32};
+  inline static const StoreNdInstruction loadNdInst;
+  inline static const StoreNdInstruction storeNdInst;
+  inline static const PrefetchNdInstruction prefetchNdInst;
+  inline static const Instruction *const instructionRegistryArr[] = {
+      &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
   BMGuArch()
       : Xe2Plus("bmg",                     // archName
                 "Battlemage Architecture", // archDescription
-                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
-                {/* registerFileInfo */}, // Optional: empty
-                {/* cacheInfo */},        // Optional: empty
-                {/* instructions */}      // Optional: empty
-        ) {
-    // Intialize register file info
-    // GRF
-    this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
-        64 * 1024,                                          // size in bits
-        {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
-        {128, 256} // registers per thread per mode
-    );
-    // Initialize cache info
-    // L1 cache, XeCore level
-    this->cacheInfo.push_back(
-        CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
-    // L2 cache, XeStack level
-    this->cacheInfo.push_back(
-        CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
-
-    // Add the instructions
-    auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getInstructionKind(), dpas);
-    owned_instructions.push_back(dpas);
+                instructionRegistryArr,
+                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
+        ) {}
+  static const uArch *getInstance() {
+    static const BMGuArch instance;
+    return reinterpret_cast<const uArch *>(&instance);
   }
 };
+
+inline const uArch *getUArch(llvm::StringRef archName) {
+  if (archName.equals_insensitive("pvc"))
+    return PVCuArch::getInstance();
+  else if (archName.equals_insensitive("bmg"))
+    return BMGuArch::getInstance();
+
+  return nullptr;
+}
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
 inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
 DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
   auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
@@ -257,12 +314,12 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
 }
 
 inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedM(Type type) {
+DPASInstruction::getSupportedM(Type type) const {
   return {1, 2, 3, 4, 5, 6, 7, 8};
 }
 
 inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedK(Type type) {
+DPASInstruction::getSupportedK(Type type) const {
   // assert if data type is not int or float type
   assert(type.isIntOrFloat() && "Matrix type must be int or float");
   auto bitWidth = type.getIntOrFloatBitWidth();
@@ -290,7 +347,7 @@ DPASInstruction::getSupportedK(Type type) {
 }
 
 inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedN(Type type) {
+DPASInstruction::getSupportedN(Type type) const {
   return {16};
 }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..f5844d62f374c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
-  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
-        // multiply-add operation
+  DPAS,       // Dot Product Accumulate Systolic (DPAS) is a matrix
+              // multiply-add operation
+  STORE_ND,   // Subgroup-level 2D block write instruction
+  LOAD_ND,    // Subgroup-level 2D block load instruction
+  PREFETCH_ND // Subgroup-level 2D block prefetch instruction
   // @TODO: Add more instructions as needed
 };
 
@@ -48,12 +51,18 @@ struct Instruction {
 
   virtual ~Instruction() = default;
   // Get methods
-  InstructionKind getInstructionKind() { return instKind; }
-  InstructionScope getScope() { return scope; }
+  InstructionKind getInstructionKind() const { return instKind; }
+  InstructionScope getScope() const { return scope; }
   static llvm::StringRef toString(InstructionKind instKind) {
     switch (instKind) {
     case InstructionKind::DPAS:
       return "dpas";
+    case InstructionKind::STORE_ND:
+      return "store_nd";
+    case InstructionKind::LOAD_ND:
+      return "load_nd";
+    case InstructionKind::PREFETCH_ND:
+      return "prefetch_nd";
     }
     llvm_unreachable("Unknown InstructionKind");
   }
@@ -66,9 +75,9 @@ struct Instruction {
   }
 
 protected:
-  InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
-  InstructionScope scope;   // scope of the instruction (e.g., lane, subgroup,
-                            // workgroup, cluster)
+  const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+  const InstructionScope scope;   // scope of the instruction (e.g., lane,
+                                  // subgroup, workgroup, cluster)
   // @TODO: Add more fields as needed
 };
 
@@ -129,61 +138,37 @@ struct CacheInfo {
   // latency, throughput, bandwidth)
 };
 
-// A struct to represent the uArch
-// This struct is used to represent the microarchitecture of a target device.
 struct uArch {
   // Constructor
-  uArch(
-      const std::string &name, const std::string &description,
-      const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
-      const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
-      const std::map<InstructionKind, std::shared_ptr<Instruction>>
-          &instructions = {})
-      : name(name), description(description),
-        registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
-        instructions(instructions) {}
-
-  // Get methods
-  const std::string &getName() const { return name; }
-
-  const std::string &getDescription() const { return description; }
-
-  const std::map<RegisterFileType, RegisterFileInfo> &
-  getRegisterFileInfo() const {
-    return registerFileInfo;
-  }
-
-  const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
-    return cacheInfo;
-  }
-
-  const std::map<InstructionKind, std::shared_ptr<Instruction>> &
-  getInstructions() const {
-    return instructions;
+  uArch(StringRef name, StringRef description,
+        llvm::ArrayRef<const Instruction *> instructionRegistry)
+      : name(name), description(description) {
+    for (const Instruction *instr : instructionRegistry)
+      this->instructionRegistry[instr->getInstructionKind()] = instr;
   }
-
-  // Get the name of the supported instruction names for that
-  // architecture. It returns the names of the instructions added to the uArch.
-  llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
-    llvm::SmallVector<StringRef, 8> instructionNames;
-    for (const auto &inst : instructions) {
-      instructionNames.push_back(Instruction::toString(inst.first));
-    }
-    return instructionNames;
+  virtual ~uArch() = default;
+  StringRef getName() const { return name; }
+  StringRef getDescription() const { return description; }
+  virtual int getSubgroupSize() const = 0;
+  virtual unsigned getPackedFormatBitSize() const = 0;
+  virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
+
+  const Instruction *getInstruction(InstructionKind instKind) const {
+    auto it = instructionRegistry.find(instKind);
+    assert(it != instructionRegistry.end() &&
+           "Instruction not found in registry");
+    return it->second;
   }
 
-  // Checks if an instruction is supported in this uArch
-  bool checkSupportedInstruction(InstructionKind instr) const {
-    return instructions.find(instr) != instructions.end();
+  bool isSupportedInstruction(InstructionKind instr) const {
+    return instructionRegistry.contains(instr);
   }
 
 protected:
-  std::string name; // Name of the uArch, similar to target triple
-  std::string description;
-  std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
-  llvm::SmallVector<CacheInfo, 4> cacheInfo;
-  std::map<InstructionKind, std::shared_ptr<Instruction>>
-      instructions; // set of instructions supported by the uArch
+  StringRef name;
+  StringRef description;
+  llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
+      instructionRegistry;
 };
 
 // A struct to represent shared memory information
@@ -251,9 +236,9 @@ struct MMAInstructionInterface {
                          std::pair<uint32_t, uint32_t> CShape,
                          std::pair<uint32_t, uint32_t> DShape, Type AType,
                          Type BType, Type CType, Type DType) = 0;
-  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
-  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
-  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) const = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) const = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) const = 0;
 
   virtual ~MMAInstructionInterface() = default;
 };

>From 92c80cf7c9f995b7e9bd424304c94e6e1bff3453 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 17 Oct 2025 16:34:50 +0000
Subject: [PATCH 2/5] Fix warning

---
Review notes (placed after the "---" line, so `git am` does not put them
in the commit message):
- This makes `Instruction`'s destructor public and non-virtual, while
  `StoreNdInstruction`, `LoadNdInstruction`, `PrefetchNdInstruction` and
  `DPASInstruction` derive from it. Nothing visible in this series deletes
  through `Instruction *` (instances are statics referenced by non-owning
  `const Instruction *`). However, with a public non-virtual destructor,
  such a delete would still compile and would be undefined behaviour.
  Declaring it `protected: ~Instruction() = default;` rejects that at
  compile time. Please check that it still silences the warning in
  question.
- The subject does not say which warning is being fixed. Naming it would
  make the trade-off reviewable.

 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index f5844d62f374c..82a5223c43651 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -49,7 +49,7 @@ struct Instruction {
   Instruction(InstructionKind kind, InstructionScope scope)
       : instKind(kind), scope(scope) {}
 
-  virtual ~Instruction() = default;
+  ~Instruction() = default;
   // Get methods
   InstructionKind getInstructionKind() const { return instKind; }
   InstructionScope getScope() const { return scope; }

>From 1bf2b652676a702aeb81dcc6f0443096bc834ee2 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 22 Oct 2025 12:59:53 +0000
Subject: [PATCH 3/5] Fix warnings

---
Review notes (placed after the "---" line, so `git am` does not put them
in the commit message):
- The new `classof` hooks let `llvm::isa` / `cast` / `dyn_cast` work on
  `const Instruction *`, keyed on `getInstructionKind()`.
- `getInstructionRegistryArr()` in PVCuArch and BMGuArch still declares
  `loadNdInst` as `StoreNdInstruction`. The uArch constructor keys the
  registry by instruction kind, so LOAD_ND stays unregistered and STORE_ND
  is inserted twice.
- `static const Instruction *arr[]` drops the top-level `const` that the
  previous `instructionRegistryArr` had (`const Instruction *const`), so
  the array elements are now writable. `static const Instruction *const
  arr[] = {...};` still converts to `llvm::ArrayRef<const Instruction *>`.
- The two `getInstructionRegistryArr()` bodies are identical. If the
  instruction set is meant to be shared by all Xe2 targets, it could live
  once in Xe2Plus.

 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 51 ++++++++++++-------
 1 file changed, 34 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index f264be5181b2a..f00644b9dea09 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -49,7 +49,9 @@ struct Xe2Plus : public uArch {
 struct StoreNdInstruction : public Instruction {
   StoreNdInstruction()
       : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
-
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::STORE_ND;
+  }
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
   // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
@@ -63,7 +65,9 @@ struct StoreNdInstruction : public Instruction {
 struct LoadNdInstruction : public Instruction {
   LoadNdInstruction()
       : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
-
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LOAD_ND;
+  }
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
   // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
@@ -77,7 +81,9 @@ struct LoadNdInstruction : public Instruction {
 struct PrefetchNdInstruction : public Instruction {
   PrefetchNdInstruction()
       : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
-
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::PREFETCH_ND;
+  }
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
   llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
@@ -103,6 +109,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
         packedFormatBitSizeA(packedFormatBitSizeA),
         packedFormatBitSizeB(packedFormatBitSizeB),
         packedFormatBitSizeC(packedFormatBitSizeC) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::DPAS;
+  }
   // Source:
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
 
@@ -146,16 +155,20 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
 //===----------------------------------------------------------------------===//
 
 struct PVCuArch final : public Xe2Plus {
-  inline static const DPASInstruction dpasInst{16, 32, 32};
-  inline static const StoreNdInstruction loadNdInst;
-  inline static const StoreNdInstruction storeNdInst;
-  inline static const PrefetchNdInstruction prefetchNdInst;
-  inline static const Instruction *const instructionRegistryArr[] = {
-      &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
+  static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
+    static const DPASInstruction dpasInst{16, 32, 32};
+    static const StoreNdInstruction loadNdInst;
+    static const StoreNdInstruction storeNdInst;
+    static const PrefetchNdInstruction prefetchNdInst;
+    static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
+                                       &prefetchNdInst};
+    return arr;
+  }
+
   PVCuArch()
       : Xe2Plus("pvc",                        // archName
                 "Ponte Vecchio Architecture", // archDescription
-                instructionRegistryArr,
+                getInstructionRegistryArr(),
                 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
         ) {}
   static const uArch *getInstance() {
@@ -165,16 +178,20 @@ struct PVCuArch final : public Xe2Plus {
 };
 
 struct BMGuArch : public Xe2Plus {
-  inline static const DPASInstruction dpasInst{16, 32, 32};
-  inline static const StoreNdInstruction loadNdInst;
-  inline static const StoreNdInstruction storeNdInst;
-  inline static const PrefetchNdInstruction prefetchNdInst;
-  inline static const Instruction *const instructionRegistryArr[] = {
-      &dpasInst, &loadNdInst, &storeNdInst, &prefetchNdInst};
+  static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
+    static const DPASInstruction dpasInst{16, 32, 32};
+    static const StoreNdInstruction loadNdInst;
+    static const StoreNdInstruction storeNdInst;
+    static const PrefetchNdInstruction prefetchNdInst;
+    static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
+                                       &prefetchNdInst};
+    return arr;
+  }
+
   BMGuArch()
       : Xe2Plus("bmg",                     // archName
                 "Battlemage Architecture", // archDescription
-                instructionRegistryArr,
+                getInstructionRegistryArr(),
                 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
         ) {}
   static const uArch *getInstance() {

>From e834a9fb0690c4c25e1c2d7e3b4f08df166eec6a Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sun, 26 Oct 2025 08:54:49 +0000
Subject: [PATCH 4/5] Update instructions

---
Review notes (placed after the "---" line, so `git am` does not put them
in the commit message):
- PVCuArch/BMGuArch still declare `loadNdInst` as `StoreNdInstruction`.
  As a result, the new `LoadNdInstruction::getBlockWidthHeightCount()`
  cannot be reached through `getInstruction(InstructionKind::LOAD_ND)`.
- StoreNdInstruction: `kWidth32[] = {16}` holds the same value as
  `kWidth16`. Either the name or the value looks wrong; please confirm the
  1-byte store width against cl_intel_subgroup_2d_block_io.
- The comments on StoreNdInstruction ("Reads ... from the specified
  pointer") and LoadNdInstruction ("Writes ... to the specified pointer")
  are swapped. Both still describe the 1/2/4/8-uint sub_group read/write
  functions rather than the 2D block shapes the code now encodes. The
  "Source" link on PrefetchNdInstruction still points at
  cl_intel_subgroup_buffer_prefetch; confirm it is the intended reference
  for the new width/height/count table.
- LoadNdInstruction/PrefetchNdInstruction put `int` arrays into
  `ArrayRef<int32_t>` tuples but return `tuple<ArrayRef<int>, ...>`. This
  only compiles while `int32_t` is `int`; using one element type
  throughout would avoid that.
- All three helpers call `elemTy.getIntOrFloatBitWidth()` unguarded.
  `getSupportedK` asserts `type.isIntOrFloat()` before the same call, so
  callers here presumably need the same precondition (or the helpers could
  return std::nullopt for non int/float types). Sub-byte types already
  yield elemByteSize == 0 and fall through to std::nullopt.

 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 116 +++++++++++++-----
 1 file changed, 87 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index f00644b9dea09..7e9e714c6007a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -53,12 +53,26 @@ struct StoreNdInstruction : public Instruction {
     return B->getInstructionKind() == InstructionKind::STORE_ND;
   }
   // Source :
-  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
   // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
   // the specified pointer
-  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
-    const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
-    return sortedLaneVectorLengths;
+  std::optional<
+      std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+  getBlockWidthHeightCount(Type elemTy) const {
+    const static int kHeight[] = {1, 2, 4, 8};
+    const static int kWidth16[] = {16};
+    const static int kWidth32[] = {16};
+    const static int kCount[] = {1};
+    const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+    if (elemByteSize == 1)
+      return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
+                             llvm::ArrayRef<int>(kHeight),
+                             llvm::ArrayRef<int>(kCount));
+    else if (elemByteSize == 2 || elemByteSize == 4)
+      return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
+                             llvm::ArrayRef<int>(kHeight),
+                             llvm::ArrayRef<int>(kCount));
+    return std::nullopt;
   }
 };
 
@@ -68,13 +82,49 @@ struct LoadNdInstruction : public Instruction {
   static bool classof(const Instruction *B) {
     return B->getInstructionKind() == InstructionKind::LOAD_ND;
   }
+
   // Source :
-  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
   // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
   // the specified pointer.
-  llvm::ArrayRef<int> getSortedLaneVectorLengths() const {
-    const static int sortedLaneVectorLengths[] = {1, 2, 4, 8};
-    return sortedLaneVectorLengths;
+  std::optional<
+      std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+  getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
+                           bool upConv = false) const {
+    static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
+    static const int kHeightAtLeast8[] = {8, 16, 32};
+    static const int kHeightAtLeast16[] = {16, 32};
+    static const int kHeightAtLeast32[] = {32};
+
+    static const int kWidth32[] = {32};
+    static const int kWidth16[] = {16};
+    static const int kWidth8[] = {8};
+
+    static const int32_t kCount1[] = {1};
+    static const int32_t kCount2[] = {1, 2};
+    static const int32_t kCount4[] = {1, 2, 4};
+    static const int32_t kCount4Only[] = {4};
+    // (elemBytes, transform, transpose, upConvert)
+    using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
+    // (widths, heights, counts)
+    using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
+                             llvm::ArrayRef<int32_t>>;
+    static const llvm::DenseMap<Key, Value> kMap = {
+        {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
+        {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
+        {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
+        {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
+        // Block Loads with Transform:
+        {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
+        {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
+        // Block Loads with Transpose:
+        {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
+    };
+    const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+    auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
+    if (it != kMap.end())
+      return it->second;
+    return std::nullopt;
   }
 };
 
@@ -86,29 +136,39 @@ struct PrefetchNdInstruction : public Instruction {
   }
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
-  llvm::ArrayRef<int> getSortedLaneVectorLengths(int elementBitwidth) const {
-    const static int sortedNarrowTypesLengths[] = {1, 2, 4, 8, 16};
-    const static int sortedWideTypesLengths[] = {1, 2, 4, 8};
-    switch (elementBitwidth) {
-    case 8:
-    case 16:
-      return sortedNarrowTypesLengths;
-    case 32:
-    case 64:
-      return sortedWideTypesLengths;
-    default:
-      llvm_unreachable("Unsupported element bitwidth");
-    }
+  std::optional<
+      std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+  getBlockWidthHeightCount(Type elemTy) const {
+    static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
+
+    static const int kWidth32[] = {32};
+    static const int kWidth16[] = {16};
+
+    static const int32_t kCount1[] = {1};
+    static const int32_t kCount2[] = {1, 2};
+    // elemBytes
+    using Key = int;
+    // (widths, heights, counts)
+    using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
+                             llvm::ArrayRef<int32_t>>;
+    static const llvm::DenseMap<Key, Value> kMap = {
+        {1, {kWidth32, kHeightAtLeast1, kCount2}},
+        {2, {kWidth16, kHeightAtLeast1, kCount2}},
+        {4, {kWidth16, kHeightAtLeast1, kCount1}},
+    };
+    const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+    auto it = kMap.find(elemByteSize);
+    if (it != kMap.end())
+      return it->second;
+    return std::nullopt;
   }
 };
 
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
-  DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB,
-                  unsigned packedFormatBitSizeC)
+  DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
       : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
         packedFormatBitSizeA(packedFormatBitSizeA),
-        packedFormatBitSizeB(packedFormatBitSizeB),
-        packedFormatBitSizeC(packedFormatBitSizeC) {}
+        packedFormatBitSizeB(packedFormatBitSizeB) {}
   static bool classof(const Instruction *B) {
     return B->getInstructionKind() == InstructionKind::DPAS;
   }
@@ -142,12 +202,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
 
   unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
   unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
-  unsigned getPackedFormatBitSizeC() const { return packedFormatBitSizeC; }
 
 protected:
   const unsigned packedFormatBitSizeA;
   const unsigned packedFormatBitSizeB;
-  const unsigned packedFormatBitSizeC;
 };
 
 //===----------------------------------------------------------------------===//
@@ -156,7 +214,7 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
 
 struct PVCuArch final : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
-    static const DPASInstruction dpasInst{16, 32, 32};
+    static const DPASInstruction dpasInst{16, 32};
     static const StoreNdInstruction loadNdInst;
     static const StoreNdInstruction storeNdInst;
     static const PrefetchNdInstruction prefetchNdInst;
@@ -179,7 +237,7 @@ struct PVCuArch final : public Xe2Plus {
 
 struct BMGuArch : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
-    static const DPASInstruction dpasInst{16, 32, 32};
+    static const DPASInstruction dpasInst{16, 32};
     static const StoreNdInstruction loadNdInst;
     static const StoreNdInstruction storeNdInst;
     static const PrefetchNdInstruction prefetchNdInst;

>From 32a2531b73a6e9f79693c913b444e51040ab8e38 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sun, 26 Oct 2025 09:27:18 +0000
Subject: [PATCH 5/5] Address feedback

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 98 ++++++++++---------
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 23 +++--
 2 files changed, 64 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 7e9e714c6007a..dcb2ad5d67a25 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -36,8 +36,7 @@ struct Xe2Plus : public uArch {
           const XeCoreInfo &xeCore)
       : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
   int getSubgroupSize() const override { return 16; }
-  unsigned getPackedFormatBitSize() const override { return 16; }
-  unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
+  unsigned getGeneralPackedFormatBitSize() const override { return 32; }
 
 protected:
   XeCoreInfo xeCore;
@@ -46,16 +45,15 @@ struct Xe2Plus : public uArch {
 //===----------------------------------------------------------------------===//
 // uArch instructions
 //===----------------------------------------------------------------------===//
-struct StoreNdInstruction : public Instruction {
-  StoreNdInstruction()
-      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+struct Subgroup2DBlockStoreInstruction : public Instruction {
+  Subgroup2DBlockStoreInstruction()
+      : Instruction(InstructionKind::Subgroup2DBlockStore,
+                    InstructionScope::Subgroup) {}
   static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::STORE_ND;
+    return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
   }
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
-  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
-  // the specified pointer
   std::optional<
       std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
   getBlockWidthHeightCount(Type elemTy) const {
@@ -74,19 +72,20 @@ struct StoreNdInstruction : public Instruction {
                              llvm::ArrayRef<int>(kCount));
     return std::nullopt;
   }
+
+  int32_t getPackedFormatBitSize() const { return 16; }
 };
 
-struct LoadNdInstruction : public Instruction {
-  LoadNdInstruction()
-      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+struct Subgroup2DBlockLoadInstruction : public Instruction {
+  Subgroup2DBlockLoadInstruction()
+      : Instruction(InstructionKind::Subgroup2DBlockLoad,
+                    InstructionScope::Subgroup) {}
   static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::LOAD_ND;
+    return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
   }
 
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
-  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
-  // the specified pointer.
   std::optional<
       std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
   getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
@@ -126,13 +125,16 @@ struct LoadNdInstruction : public Instruction {
       return it->second;
     return std::nullopt;
   }
+
+  int32_t getPackedFormatBitSize() const { return 16; }
 };
 
-struct PrefetchNdInstruction : public Instruction {
-  PrefetchNdInstruction()
-      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+struct Subgroup2DBlockPrefetchInstruction : public Instruction {
+  Subgroup2DBlockPrefetchInstruction()
+      : Instruction(InstructionKind::Subgroup2DBlockPrefetch,
+                    InstructionScope::Subgroup) {}
   static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::PREFETCH_ND;
+    return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
   }
   // Source :
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
@@ -162,15 +164,20 @@ struct PrefetchNdInstruction : public Instruction {
       return it->second;
     return std::nullopt;
   }
+  int32_t getPackedFormatBitSize() const { return 16; }
 };
 
-struct DPASInstruction : public Instruction, public MMAInstructionInterface {
-  DPASInstruction(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
-      : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup),
+struct SubgroupMatrixMultiplyAcc : public Instruction,
+                                   public MMAInstructionInterface {
+  SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
+                            unsigned packedFormatBitSizeB)
+      : Instruction(InstructionKind::SubgroupMatrixMultiplyAcc,
+                    InstructionScope::Subgroup),
         packedFormatBitSizeA(packedFormatBitSizeA),
         packedFormatBitSizeB(packedFormatBitSizeB) {}
   static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::DPAS;
+    return B->getInstructionKind() ==
+           InstructionKind::SubgroupMatrixMultiplyAcc;
   }
   // Source:
   // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
@@ -214,10 +221,10 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
 
 struct PVCuArch final : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
-    static const DPASInstruction dpasInst{16, 32};
-    static const StoreNdInstruction loadNdInst;
-    static const StoreNdInstruction storeNdInst;
-    static const PrefetchNdInstruction prefetchNdInst;
+    static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+    static const Subgroup2DBlockLoadInstruction loadNdInst;
+    static const Subgroup2DBlockStoreInstruction storeNdInst;
+    static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
     static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
                                        &prefetchNdInst};
     return arr;
@@ -237,10 +244,10 @@ struct PVCuArch final : public Xe2Plus {
 
 struct BMGuArch : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
-    static const DPASInstruction dpasInst{16, 32};
-    static const StoreNdInstruction loadNdInst;
-    static const StoreNdInstruction storeNdInst;
-    static const PrefetchNdInstruction prefetchNdInst;
+    static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+    static const Subgroup2DBlockLoadInstruction loadNdInst;
+    static const Subgroup2DBlockStoreInstruction storeNdInst;
+    static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
     static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
                                        &prefetchNdInst};
     return arr;
@@ -276,7 +283,8 @@ inline const uArch *getUArch(llvm::StringRef archName) {
 //===----------------------------------------------------------------------===//
 
 inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
-DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+SubgroupMatrixMultiplyAcc::getSupportedShapes(Type dataType,
+                                              MMAOpndKind matrixType) {
   auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
                            const llvm::SmallVector<uint32_t, 8> &b)
       -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
@@ -312,8 +320,8 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
 }
 
 inline llvm::SmallVector<Type, 8>
-DPASInstruction::getSupportedTypes(MLIRContext &context,
-                                   MMAOpndKind matrixType) {
+SubgroupMatrixMultiplyAcc::getSupportedTypes(MLIRContext &context,
+                                             MMAOpndKind matrixType) {
   Type bf16Type = BFloat16Type::get(&context);
   Type f16Type = Float16Type::get(&context);
   Type tf32Type = FloatTF32Type::get(&context);
@@ -332,8 +340,10 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
   return {};
 }
 
-inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
-                                                 Type CType, Type DType) {
+inline bool SubgroupMatrixMultiplyAcc::checkSupportedTypes(Type AType,
+                                                           Type BType,
+                                                           Type CType,
+                                                           Type DType) {
   if (AType.isF16() || BType.isF16()) {
     if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
         (!DType.isF32() && !DType.isF16())) {
@@ -363,7 +373,7 @@ inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
   return true;
 }
 
-inline bool DPASInstruction::checkSupportedShapesAndTypes(
+inline bool SubgroupMatrixMultiplyAcc::checkSupportedShapesAndTypes(
     std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
     std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
     Type AType, Type BType, Type CType, Type DType) {
@@ -378,23 +388,21 @@ inline bool DPASInstruction::checkSupportedShapesAndTypes(
          checkSupportedTypes(AType, BType, CType, DType);
 }
 
-inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
-                                      std::pair<uint32_t, uint32_t> BShape,
-                                      std::pair<uint32_t, uint32_t> CShape,
-                                      std::pair<uint32_t, uint32_t> DShape,
-                                      Type AType, Type BType, Type CType,
-                                      Type DType) {
+inline bool SubgroupMatrixMultiplyAcc::validate(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
   return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
                                       BType, CType, DType);
 }
 
 inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedM(Type type) const {
+SubgroupMatrixMultiplyAcc::getSupportedM(Type type) const {
   return {1, 2, 3, 4, 5, 6, 7, 8};
 }
 
 inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedK(Type type) const {
+SubgroupMatrixMultiplyAcc::getSupportedK(Type type) const {
   // assert if data type is not int or float type
   assert(type.isIntOrFloat() && "Matrix type must be int or float");
   auto bitWidth = type.getIntOrFloatBitWidth();
@@ -422,7 +430,7 @@ DPASInstruction::getSupportedK(Type type) const {
 }
 
 inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedN(Type type) const {
+SubgroupMatrixMultiplyAcc::getSupportedN(Type type) const {
   return {16};
 }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 82a5223c43651..ea33e885c78ff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,11 +32,11 @@ namespace uArch {
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
-  DPAS,       // Dot Product Accumulate Systolic (DPAS) is a matrix
-              // multiply-add operation
-  STORE_ND,   // Subgroup-level 2D block write instruction
-  LOAD_ND,    // Subgroup-level 2D block load instruction
-  PREFETCH_ND // Subgroup-level 2D block prefetch instruction
+  SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS) is a
+                             // matrix multiply-add operation
+  Subgroup2DBlockStore,      // Subgroup-level 2D block write instruction
+  Subgroup2DBlockLoad,       // Subgroup-level 2D block load instruction
+  Subgroup2DBlockPrefetch    // Subgroup-level 2D block prefetch instruction
   // @TODO: Add more instructions as needed
 };
 
@@ -55,13 +55,13 @@ struct Instruction {
   InstructionScope getScope() const { return scope; }
   static llvm::StringRef toString(InstructionKind instKind) {
     switch (instKind) {
-    case InstructionKind::DPAS:
+    case InstructionKind::SubgroupMatrixMultiplyAcc:
       return "dpas";
-    case InstructionKind::STORE_ND:
+    case InstructionKind::Subgroup2DBlockStore:
       return "store_nd";
-    case InstructionKind::LOAD_ND:
+    case InstructionKind::Subgroup2DBlockLoad:
       return "load_nd";
-    case InstructionKind::PREFETCH_ND:
+    case InstructionKind::Subgroup2DBlockPrefetch:
       return "prefetch_nd";
     }
     llvm_unreachable("Unknown InstructionKind");
@@ -70,7 +70,7 @@ struct Instruction {
   static std::optional<InstructionKind>
   parseInstructionKind(llvm::StringRef str) {
     if (str.equals_insensitive("dpas"))
-      return InstructionKind::DPAS;
+      return InstructionKind::SubgroupMatrixMultiplyAcc;
     return std::nullopt;
   }
 
@@ -150,8 +150,7 @@ struct uArch {
   StringRef getName() const { return name; }
   StringRef getDescription() const { return description; }
   virtual int getSubgroupSize() const = 0;
-  virtual unsigned getPackedFormatBitSize() const = 0;
-  virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
+  virtual unsigned getGeneralPackedFormatBitSize() const = 0;
 
   const Instruction *getInstruction(InstructionKind instKind) const {
     auto it = instructionRegistry.find(instKind);



More information about the Mlir-commits mailing list