[Mlir-commits] [mlir] [uArch][XeGPU] Add XeGPU uArch definition. (PR #153706)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Wed Oct 8 06:28:43 PDT 2025


https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/153706

>From b41756194fb488e9cbf7664233c2361fbd60ed8e Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Fri, 28 Mar 2025 13:04:30 +0000
Subject: [PATCH 01/13] [uArch][XeGPU] Add XeGPU uArch definition.

The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary components
  (e.g., instructions, register-files, caches).
- A set of utility interfaces that are common to a family of ops
  (e.g., mma ops, 2DBlockIO ops). The implementation of these interfaces
  is provided by the specific instructions. Each family of ops provides
  these 5 common APIs. However, some families of ops may have more
  utility APIs. The common 5 APIs are:
	- getSupportedShapes
	- getSupportedTypes
	- checkSupportedShapesAndTypes
	- checkSupportedTypes
	- validate

Add support for PVC and BMG architectures.
Add support for DPAS instruction.
---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 182 ++++++++++++
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 266 ++++++++++++++++++
 .../Dialect/XeGPU/uArch/uArchInterfaces.h     |  75 +++++
 mlir/lib/Dialect/LLVMIR/CMakeLists.txt        |   1 +
 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp    |   9 +
 mlir/lib/Dialect/XeGPU/CMakeLists.txt         |   1 +
 mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt      |   1 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |   9 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt   |  11 +
 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp  | 197 +++++++++++++
 11 files changed, 753 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
 create mode 100644 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..9179838f8c148
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,182 @@
+//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Xe2 uArch definition.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  // Constructor
+  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+             uint32_t num_vector_units, uint32_t num_matrix_units)
+      : num_threads(num_threads), shared_memory(shared_memory),
+        num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
+  }
+};
+
+struct Xe2Plus : public uArch {
+  XeCoreInfo xe_core;
+  Xe2Plus(
+      const std::string &archName, const std::string &archDescription,
+      const XeCoreInfo &xeCore,
+      const std::vector<uArchHierarchyComponent> &hierarchy = {},
+      const std::map<std::string, RegisterFileInfo> &regInfo = {},
+      const std::vector<CacheInfo> &cacheInfo = {},
+      const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+      : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+        xe_core(xeCore) {}
+};
+
+// struct to represent DPAS instruction
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+  DPASInstruction()
+      : Instruction("dpas",                   // name
+                    "Dot Product Accumulate") // description
+  { scope = InstructionScopeEnum::Subgroup; }
+  // ^ The base constructor takes no scope, so record it here: DPAS is a
+  // subgroup-level instruction. Below: overrides of MMAInstructionInterface.
+  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+  virtual std::vector<mlir::Type>
+  getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) override;
+  virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                   mlir::Type CType, mlir::Type DType) override;
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+                        mlir::Type BType, mlir::Type CType,
+                        mlir::Type DType) override;
+  virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override;
+  virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override;
+  virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+};
+
+namespace PVCuArch {
+struct PVCuArch : public Xe2Plus {
+  // Holds a shared_ptr to each instruction this PVCuArch registers below.
+  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  PVCuArch()
+      : Xe2Plus("pvc",                        // archName
+                "Ponte Vecchio Architecture", // archDescription
+                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
+                {/* hierarchy */},          // empty here, filled in below
+                {/* register_file_info */}, // empty here, filled in below
+                {/* cache_info */}          // empty here, filled in below
+        ) {
+    // Initialize uArchHierarchy
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
+    // Initialize register file info
+    // GRF
+    this->register_file_info.emplace(
+        "GRF",
+        RegisterFileInfo(64 * 1024,          // size in bits
+                         {"small", "large"}, // GRF modes
+                         {128, 256},         // registers per thread per mode
+                         0,                  // number of banks
+                         0                   // bank size
+                         ));
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cache_info.push_back(
+        CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
+    // L3 cache, XeStack level
+    this->cache_info.push_back(
+        CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
+
+    // Add the instructions
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getName(), dpas);
+    // The same shared_ptr is also kept in owned_instructions.
+    owned_instructions.push_back(dpas);
+  }
+};
+} // namespace PVCuArch
+
+namespace BMGuArch {
+struct BMGuArch : public Xe2Plus {
+  // Holds a shared_ptr to each instruction this BMGuArch registers below.
+  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  BMGuArch()
+      : Xe2Plus("bmg",                     // archName
+                "Battlemage Architecture", // archDescription
+                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
+                {/* hierarchy */},          // empty here, filled in below
+                {/* register_file_info */}, // empty here, filled in below
+                {/* cache_info */},         // empty here, filled in below
+                {/* instructions */}        // empty here, filled in below
+        ) {
+    // Initialize uArchHierarchy
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
+    // Initialize register file info
+    // GRF
+    this->register_file_info["GRF"] =
+        RegisterFileInfo(64 * 1024,          // size in bits
+                         {"small", "large"}, // GRF modes
+                         {128, 256},         // registers per thread per mode
+                         0,                  // number of banks
+                         0                   // bank size
+        );
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cache_info.push_back(
+        CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
+    // L3 cache, XeStack level
+    this->cache_info.push_back(
+        CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+
+    // Add the instructions
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getName(), dpas);
+    // The same shared_ptr is also kept in owned_instructions.
+    owned_instructions.push_back(dpas);
+  }
+};
+} // namespace BMGuArch
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..9bda86df2aff9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,266 @@
+//===--- uArch.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Base uArch definition for different architectures.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+
+#include <any>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <mutex>
+#include <shared_mutex>
+#include <tuple>
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+// Architecture HW component hierarchy to present thread, core, socket ...
+struct uArchHierarchyComponent {
+  std::string name = ""; // optional name of the hierarchy component
+  // no. of lower hierarchy component it contains, e.g., for PVC XeCore it
+  // contains 8 threads, so no_of_component=8
+  uint32_t no_of_component;
+  // Constructor
+  uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
+      : name(name), no_of_component(no_of_component) {}
+};
+
+// An enum class to represent the scope of an instruction
+enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
+
+// A struct to represent basic information about an instruction
+// This struct is used to represent the information about an instruction in the
+// uArch The information includes:
+// - the name of the instruction,
+// - the description of the instruction
+// - the scope of the instruction,
+//
+// The information is represented as strings
+// For example, the information about an instruction can be represented as:
+// Instruction instr = {"dpas", "Dot Product Accumulate Systolic  (DPAS) is a
+// matrix multiply-add operation", "subgroup"};
+
+// The primary purpose of the Instruction struct is to provide a generic way to
+// represent information about an instruction and to use this information to
+// generate the uArch. Specific instructions in a uArch can inherit from this
+// struct and add more fields as needed
+
+struct Instruction {
+  // @TODO: Add more fields as needed
+  Instruction(std::string name, std::string desc)
+      : name(std::move(name)), description(std::move(desc)) {}
+
+  virtual ~Instruction() = default;
+  // Get methods; const so they can be called through a const Instruction.
+  std::string getName() const { return name; }
+  std::string getDescription() const { return description; }
+  InstructionScopeEnum getScope() const { return scope; }
+protected:
+  std::string name;
+  std::string description;
+  // The constructor takes no scope, so give the member a defined default;
+  // derived instructions assign their real scope (the member is protected).
+  InstructionScopeEnum scope = InstructionScopeEnum::WorkItem;
+};
+
+// A struct to represent register file information
+struct RegisterFileInfo {
+  // Constructor
+  RegisterFileInfo() = default;
+  RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
+                   const std::vector<uint32_t> &numRegs, uint32_t num_banks,
+                   uint32_t bank_size)
+      : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
+        num_banks(num_banks), bank_size(bank_size) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+
+  const std::vector<std::string> &getModes() const { return mode; }
+
+  const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+    return num_regs_per_thread_per_mode;
+  }
+
+  uint32_t getNumBanks() const { return num_banks; }
+
+  uint32_t getBankSize() const { return bank_size; }
+
+protected:
+  uint32_t size;                 // size per register in bits
+  std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+  std::vector<uint32_t>
+      num_regs_per_thread_per_mode; // number of registers per thread per mode
+  uint32_t num_banks;
+  uint32_t bank_size;
+};
+
+// A struct to represent cache information
+
+struct CacheInfo {
+  // Constructor
+  CacheInfo(uint32_t size, uint32_t line_size,
+            const uArchHierarchyComponent &component)
+      : size(size), line_size(line_size), component(component) {}
+
+  virtual ~CacheInfo() = default;
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+  uint32_t getLineSize() const { return line_size; }
+  const uArchHierarchyComponent &getComponent() const { return component; }
+
+protected:
+  uint32_t size;
+  uint32_t line_size;
+  // At which component level the cache is shared
+  uArchHierarchyComponent component;
+
+  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
+  // bank_size, num_ports, port_width, bank_conflicts)
+};
+
+// A struct to represent the uArch
+// This struct is used to represent the microarchitecture of a target device
+// The uArch includes:
+// - the name of the uArch,
+// - the description of the uArch,
+// - uArch hierarchy
+// - Register File information
+// - Cache information
+// - the set of instructions supported by the uArch,
+struct uArch {
+  // Constructor
+  uArch() = default;
+  uArch(const std::string &name, const std::string &description,
+        const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
+        const std::map<std::string, RegisterFileInfo> &register_file_info = {},
+        const std::vector<CacheInfo> &cache_info = {},
+        const std::map<std::string, std::shared_ptr<Instruction>>
+            &instructions = {})
+      : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
+        register_file_info(register_file_info), cache_info(cache_info),
+        instructions(instructions) {}
+
+  // Get methods
+  const std::string &getName() const { return name; }
+
+  const std::string &getDescription() const { return description; }
+
+  const std::vector<uArchHierarchyComponent> &getHierarchy() const {
+    return uArch_hierarchy;
+  }
+
+  const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
+    return register_file_info;
+  }
+
+  const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+
+  const std::map<std::string, std::shared_ptr<Instruction>> &
+  getInstructions() const {
+    return instructions;
+  }
+
+  // Get the name of the supported instruction names for that
+  // architecture. It returns the names of the instructions added to the uArch.
+  std::vector<std::string> getSupportedInstructionNames() const {
+    std::vector<std::string> instructionNames;
+    for (const auto &inst : instructions) {
+      instructionNames.push_back(inst.first);
+    }
+    return instructionNames;
+  }
+
+  // Checks if an instruction is supported in this uArch
+  bool checkSupportedInstruction(const std::string &instructionName) const {
+    return instructions.find(instructionName) != instructions.end();
+  }
+
+protected:
+  std::string name; // Similar to target triple
+  std::string description;
+  std::vector<uArchHierarchyComponent> uArch_hierarchy;
+  std::map<std::string, RegisterFileInfo> register_file_info;
+  std::vector<CacheInfo> cache_info;
+  std::map<std::string, std::shared_ptr<Instruction>> instructions;
+};
+
+// A struct to represent shared memory information
+struct SharedMemory {
+  // Constructor
+  SharedMemory(uint32_t size, uint32_t alignment)
+      : size(size), alignment(alignment) {}
+
+  // Getters
+  uint32_t getSize() const { return size; }
+  uint32_t getAlignment() const { return alignment; }
+
+protected:
+  uint32_t size;      // in bytes
+  uint32_t alignment; // in bytes
+  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
+};
+
+struct uArchMap {
+public:
+  // Singleton registry of uArch objects keyed by string; see instance().
+  static uArchMap &instance() {
+    static uArchMap instance;
+    return instance;
+  }
+
+  // Insert a key-value pair under an exclusive lock.
+  void insert(const std::string &key, std::shared_ptr<uArch> value) {
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    // std::map::emplace does NOT overwrite: an existing entry for `key` wins.
+    map_.emplace(key, std::move(value)); // no-op if `key` is already present
+  }
+
+  // Get a value by key (concurrent safe read)
+  std::shared_ptr<uArch> get(const std::string &key) const {
+    std::shared_lock<std::shared_mutex> lock(mutex_);
+    auto it = map_.find(key);
+    if (it != map_.end())
+      return it->second;
+    return nullptr;
+  }
+
+  // Check if a key exists
+  bool contains(const std::string &key) const {
+    std::shared_lock<std::shared_mutex> lock(mutex_);
+    return map_.find(key) != map_.end();
+  }
+
+  // Remove a key; returns true if an entry was actually erased
+  bool erase(const std::string &key) {
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    return map_.erase(key) > 0;
+  }
+
+private:
+  uArchMap() = default;
+  uArchMap(const uArchMap &) = delete;
+  uArchMap &operator=(const uArchMap &) = delete;
+
+  mutable std::shared_mutex mutex_;
+  std::map<std::string, std::shared_ptr<uArch>> map_;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
new file mode 100644
index 0000000000000..27d44c38317a1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -0,0 +1,75 @@
+//===--- uArchInterfaces.h ---*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Defines the utility interfaces that are implemented by individual
+/// instructions.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+struct MMAInstructionInterface {
+  // Get supported Matrix shapes
+  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
+
+  // @TODO: This method takes an context object as a parameter, this is to
+  // create the mlir::Type objects from the same context. Since type objects are
+  // uniqued in a specific context, to do things like "aType == bType" (where
+  // aType and bType are both same type) kind of checks, the both types should
+  // be from the same context.
+  //
+  // One alternative to this is to create enum to represent each types, but this
+  // adds an extra burden to user to convert these enums to specific types. In
+  // fact the utility that would convert enumToType() and vice versa would still
+  // have to use the context object.
+  //
+  // Until we have a better solution, we stick to passing the context object to
+  // this method.
+  virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
+                                                    MMAOpndEnum matrixType) = 0;
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) = 0;
+  virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                   mlir::Type CType, mlir::Type DType) = 0;
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+                        mlir::Type BType, mlir::Type CType,
+                        mlir::Type DType) = 0;
+  virtual std::vector<uint32_t> getSupportedM(mlir::Type type) = 0;
+  virtual std::vector<uint32_t> getSupportedK(mlir::Type type) = 0;
+  virtual std::vector<uint32_t> getSupportedN(mlir::Type type) = 0;
+
+  virtual ~MMAInstructionInterface() = default;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac7277e3..9e40d41de1c73 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -129,5 +129,6 @@ add_mlir_dialect_library(MLIRXeVMDialect
   MLIRDialectUtils
   MLIRIR
   MLIRLLVMDialect
+  MLIRXeGPUuArch
   MLIRSideEffectInterfaces
 )
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 04e8836c00359..37ab1fcdd1c0e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -380,6 +381,14 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
 }
 
 void XeVMDialect::initialize() {
+  // Populate the uArchMap with the supported target devices
+  auto pvcuArch =
+      std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
+  mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
+  auto bmguArch =
+      std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
+  mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..9079df050ab2b 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(uArch)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 7869a28dfed57..e1c51deefbee3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect
   MLIRGPUDialect
   MLIRXeVMDialect
   MLIRIR
+  MLIRXeGPUuArch
   MLIRViewLikeInterface
   MLIRVectorDialect
 )
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..359517432d82e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -35,6 +36,14 @@ void XeGPUDialect::initialize() {
 #define GET_ATTRDEF_LIST
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
       >();
+
+  // Populate the uArchMap with the supported target devices
+  auto pvcuArch =
+      std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
+  mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
+  auto bmguArch =
+      std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
+  mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
 }
 
 /// Generates instructions to compute offsets for a subgroup identified by
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..e56fe3a1e193d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRTransforms
   MLIRGPUDialect
   MLIRXeGPUUtils
+  MLIRXeGPUuArch
   MLIRGPUUtils
   MLIRVectorTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
new file mode 100644
index 0000000000000..c7f691cb6dda7
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRXeGPUuArch
+  IntelGpuXe2.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRDialectUtils
+)
+
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
new file mode 100644
index 0000000000000..d80d1439692e8
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -0,0 +1,197 @@
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/YAMLTraits.h"
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace mlir::xegpu::uArch;
+using namespace mlir::xegpu::uArch::Xe2Plus;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+
+// Returns every (rows, cols) shape DPAS supports for the operand selected by
+// `matrixType`, given that operand's element type `dataType`: M x K for A,
+// K x N for B, and M x N for both C and D.
+std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(mlir::Type dataType,
+                                    MMAOpndEnum matrixType) {
+  // Cartesian product of two dimension lists; the first list varies slowest.
+  auto combineVectors = [](const std::vector<uint32_t> &a,
+                           const std::vector<uint32_t> &b)
+      -> std::vector<std::pair<uint32_t, uint32_t>> {
+    std::vector<std::pair<uint32_t, uint32_t>> result;
+    result.reserve(a.size() * b.size());
+    for (uint32_t x : a)
+      for (uint32_t y : b)
+        result.emplace_back(x, y);
+    return result;
+  };
+
+  // Only the dimensions an operand actually has are queried. In particular K
+  // is never derived from the C/D element type: getSupportedK() asserts on
+  // types that are not int/float, and checkSupportedTypes() treats a null
+  // accumulator type as valid.
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+    return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndEnum::MatrixB:
+    return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
+  }
+
+  // Every enumerator returns above; an out-of-range value yields no shapes,
+  // which makes any shape check against the result fail.
+  return {};
+}
+
+// Returns the element types listed for operand `matrixType`. The types are
+// created in `context` so callers can compare them to their own by equality.
+std::vector<mlir::Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndEnum matrixType) {
+  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
+  mlir::Type f16Type = mlir::Float16Type::get(&context);
+  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
+  mlir::Type f32Type = mlir::Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+  case MMAOpndEnum::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return {bf16Type, f16Type, f32Type};
+  }
+  // Every enumerator returns above. This return keeps control from flowing
+  // off the end of a non-void function (undefined behaviour) if an
+  // out-of-range value is ever passed, and silences the compiler warning.
+  return {};
+}
+
+// Checks whether the A/B/C/D element types form a combination DPAS supports.
+// The family is selected by A/B: f16, bf16, tf32, or 2/4/8-bit integers. CType
+// may be null (no accumulator), in which case it is not checked. On failure the
+// supported-type table is printed to llvm::errs() and false is returned.
+bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                          mlir::Type CType, mlir::Type DType) {
+  // Prints the diagnostic common to all families and yields the failure value.
+  auto reject = [&](const char *supportedTable) {
+    llvm::errs()
+        << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+        << "Supported types are:\n"
+        << supportedTable << "AType: " << AType << " BType: " << BType
+        << " CType: " << CType << " DType: " << DType;
+    return false;
+  };
+
+  auto isNarrowInt = [](mlir::Type type) {
+    return type.isInteger(2) || type.isInteger(4) || type.isInteger(8);
+  };
+
+  if (AType.isF16() || BType.isF16()) {
+    if (AType != BType || (CType && !CType.isF32() && !CType.isF16()) ||
+        (!DType.isF32() && !DType.isF16()))
+      return reject("  Dst    |   Acc   |   A   |  B  \n"
+                    " f, hf   |  f, hf  |   hf  |  hf \n");
+    return true;
+  }
+  if (AType.isBF16() || BType.isBF16()) {
+    if (AType != BType || (CType && !CType.isF32() && !CType.isBF16()) ||
+        (!DType.isF32() && !DType.isBF16()))
+      return reject("  Dst    |   Acc   |   A   |  B  \n"
+                    " f, bf   |  f, bf  |   bf  |  bf \n");
+    return true;
+  }
+  if (AType.isTF32() || BType.isTF32()) {
+    // The accumulator, when present, must itself be f32. Testing DType in
+    // the accumulator clause would let a non-f32 accumulator through
+    // whenever the destination is f32.
+    if (AType != BType || (CType && !CType.isF32()) || !DType.isF32())
+      return reject("  Dst    |   Acc   |   A    |   B  \n"
+                    "   f     |    f    |  tf32  |  tf32 \n");
+    return true;
+  }
+  // Integer family: A and B must each be a 2/4/8-bit integer. Rejecting only
+  // when both fail the test would accept e.g. an i8 A paired with an f32 B,
+  // so either operand failing is an error.
+  if (!isNarrowInt(AType) || !isNarrowInt(BType))
+    return reject(
+        "  Dst     |   Acc    |         A           |         B          "
+        " \n"
+        " ud, d    |  ud,d    |  ub,b,u4,s4,u2,s2   |  ub,b,u4,s4,u2,s2  "
+        "\n");
+
+  return true;
+}
+
+bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
+  auto supportedAShapes = getSupportedShapes(AType, MMAOpndEnum::MatrixA);
+  auto supportedBShapes = getSupportedShapes(BType, MMAOpndEnum::MatrixB);
+  auto supportedCShapes = getSupportedShapes(CType, MMAOpndEnum::MatrixC);
+  auto supportedDShapes = getSupportedShapes(DType, MMAOpndEnum::MatrixD);
+  return llvm::is_contained(supportedAShapes, AShape) &&
+         llvm::is_contained(supportedBShapes, BShape) &&
+         llvm::is_contained(supportedCShapes, CShape) &&
+         llvm::is_contained(supportedDShapes, DShape) &&
+         checkSupportedTypes(AType, BType, CType, DType);
+}
+
+bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) {
+  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
+                                      BType, CType, DType);
+}
+
+std::vector<uint32_t> DPASInstruction::getSupportedM(mlir::Type type) {
+  return {1, 2, 3, 4, 5, 6, 7, 8};
+}
+
+std::vector<uint32_t> DPASInstruction::getSupportedK(mlir::Type type) {
+  // assert if data type is not int or float type
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  auto bitWidth = type.getIntOrFloatBitWidth();
+  uint32_t kSize = 0;
+  switch (bitWidth) {
+  case 2:
+    kSize = 64;
+    break;
+  case 4:
+    kSize = 64;
+    break;
+  case 8:
+    kSize = 32;
+    break;
+  case 16:
+    kSize = 16;
+    break;
+  case 32:
+    kSize = 8;
+    break;
+  default:
+    llvm_unreachable("Invalid int or float");
+  }
+  return {kSize};
+}
+
+std::vector<uint32_t> DPASInstruction::getSupportedN(mlir::Type type) {
+  return {16};
+}
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir

>From 6dace4bcb469f5813b512aba5cef3df48ebb7a6d Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 26 Aug 2025 00:08:11 +0000
Subject: [PATCH 02/13] Address review comments.

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 15 +++----
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 14 +++----
 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp  | 39 ++++---------------
 3 files changed, 22 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 9179838f8c148..e49fe69c47b2e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -1,4 +1,4 @@
-//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 //
-/// \file
-/// Xe2 uArch definition.
-///
+// \file
+// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
+// This file defines the uArch details for Xe2 and its derived architectures.
+// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
-#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
 
 #include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -179,4 +180,4 @@ struct BMGuArch : public Xe2Plus {
 } // namespace xegpu
 } // namespace mlir
 
-#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 9bda86df2aff9..4ef02f32913f7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -1,4 +1,4 @@
-//===--- uArch.h ---------------------------------------*- C++ -*-===//
+//===--- uArchBase.h --------------------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-/// \file
-/// Base uArch definition for different architectures.
-///
+// \file
+// Base uArch definition for different architectures.
+//
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
-#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
 
 #include <any>
 #include <functional>
@@ -263,4 +263,4 @@ struct uArchMap {
 } // namespace xegpu
 } // namespace mlir
 
-#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
index d80d1439692e8..4db4300028c46 100644
--- a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -1,11 +1,11 @@
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "llvm/Support/YAMLTraits.h"
+#include "llvm/Support/DebugLog.h"
 #include <algorithm>
-#include <iostream>
-#include <string>
 #include <vector>
 
+#define DEBUG_TYPE "xegpu-uarch"
+
 using namespace mlir::xegpu::uArch;
 using namespace mlir::xegpu::uArch::Xe2Plus;
 
@@ -80,51 +80,26 @@ bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
   if (AType.isF16() || BType.isF16()) {
     if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
         (!DType.isF32() && !DType.isF16())) {
-      llvm::errs()
-          << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
-          << "Supported types are:\n"
-          << "  Dst    |   Acc   |   A   |  B  \n"
-          << " f, hf   |  f, hf  |   hf  |  hf \n"
-          << "AType: " << AType << " BType: " << BType << " CType: " << CType
-          << " DType: " << DType;
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
       return false;
     }
   } else if (AType.isBF16() || BType.isBF16()) {
     if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
         (!DType.isF32() && !DType.isBF16())) {
-      llvm::errs()
-          << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
-          << "Supported types are:\n"
-          << "  Dst    |   Acc   |   A   |  B  \n"
-          << " f, bf   |  f, bf  |   bf  |  bf \n"
-          << "AType: " << AType << " BType: " << BType << " CType: " << CType
-          << " DType: " << DType;
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
       return false;
     }
   } else if (AType.isTF32() || BType.isTF32()) {
     if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
         (!DType.isF32())) {
-      llvm::errs()
-          << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
-          << "Supported types are:\n"
-          << "  Dst    |   Acc   |   A    |   B  \n"
-          << "   f     |    f    |  tf32  |  tf32 \n"
-          << "AType: " << AType << " BType: " << BType << " CType: " << CType
-          << " DType: " << DType;
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
       return false;
     }
   } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
                AType.isInteger(8)) &&
              !(BType.isInteger(2) || BType.isInteger(4) ||
                BType.isInteger(8))) {
-    llvm::errs()
-        << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
-        << "Supported types are:\n"
-        << "  Dst     |   Acc    |         A           |         B          "
-           " \n"
-        << " ud, d    |  ud,d    |  ub,b,u4,s4,u2,s2   |  ub,b,u4,s4,u2,s2  "
-        << "AType: " << AType << " BType: " << BType << " CType: " << CType
-        << " DType: " << DType;
+    LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
     return false;
   }
 

>From f33b7f73840a3d0fd2c3c1c725a72da5374c6c56 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 24 Sep 2025 17:10:46 +0000
Subject: [PATCH 03/13] Address review comments.

Simplify the design:
- Remove uArchHierarchyComponent

LLVMize names.
Replace String usage with enum whenever possible.
---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 71 ++++++--------
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 93 ++++++++-----------
 .../Dialect/XeGPU/uArch/uArchInterfaces.h     |  7 +-
 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp  | 28 +++---
 4 files changed, 83 insertions(+), 116 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index e49fe69c47b2e..a7ae0bdf4378b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -45,11 +45,10 @@ struct Xe2Plus : public uArch {
   Xe2Plus(
       const std::string &archName, const std::string &archDescription,
       const XeCoreInfo &xeCore,
-      const std::vector<uArchHierarchyComponent> &hierarchy = {},
-      const std::map<std::string, RegisterFileInfo> &regInfo = {},
+      const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
       const std::vector<CacheInfo> &cacheInfo = {},
       const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
-      : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+      : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
         xe_core(xeCore) {}
 };
 
@@ -62,9 +61,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
 
   // Override all virtuals from MatrixOpInterface
   virtual std::vector<std::pair<uint32_t, uint32_t>>
-  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+  getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override;
   virtual std::vector<mlir::Type>
-  getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
   virtual bool
   checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
                                std::pair<uint32_t, uint32_t> BShape,
@@ -97,29 +96,22 @@ struct PVCuArch : public Xe2Plus {
                 {/* cache_info */},         // Optional: empty
                 {/* instructions */}        // Optional: empty
         ) {
-    // Initialize uArchHierarchy
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
     // Intialize register file info
     // GRF
-    this->register_file_info.emplace(
-        "GRF",
-        RegisterFileInfo(64 * 1024,          // size in bits
-                         {"small", "large"}, // GRF modes
-                         {128, 256},         // registers per thread per mode
-                         0,                  // number of banks
-                         0                   // bank size
-                         ));
+    this->registerFileInfo.emplace(
+        RegisterFileType::GRF,
+        RegisterFileInfo(
+            64 * 1024,                                          // size in bits
+            {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
+            {128, 256} // registers per thread per mode
+            ));
     // Initialize cache info
     // L1 cache, XeCore level
-    this->cache_info.push_back(
-        CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
-    // L3 cache, XeStack level
-    this->cache_info.push_back(
-        CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
+    this->cacheInfo.push_back(
+        CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
+    // L2 cache, XeStack level
+    this->cacheInfo.push_back(
+        CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
 
     // Add the instructions
     auto dpas = std::make_shared<DPASInstruction>();
@@ -140,31 +132,22 @@ struct BMGuArch : public Xe2Plus {
                 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
                 {/* register_file_info */}, // Optional: empty
                 {/* cache_info */},         // Optional: empty
-                {/* instructions */},       // Optional: empty
-                {/* restrictions */}        // Optional: empty
+                {/* instructions */}        // Optional: empty
         ) {
-    // Initialize uArchHierarchy
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
-    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
     // Intialize register file info
     // GRF
-    this->register_file_info["GRF"] =
-        RegisterFileInfo(64 * 1024,          // size in bits
-                         {"small", "large"}, // GRF modes
-                         {128, 256},         // registers per thread per mode
-                         0,                  // number of banks
-                         0                   // bank size
-        );
+    this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
+        64 * 1024,                                          // size in bits
+        {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
+        {128, 256} // registers per thread per mode
+    );
     // Initialize cache info
     // L1 cache, XeCore level
-    this->cache_info.push_back(
-        CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
-    // L3 cache, XeStack level
-    this->cache_info.push_back(
-        CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+    this->cacheInfo.push_back(
+        CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
+    // L2 cache, XeStack level
+    this->cacheInfo.push_back(
+        CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
 
     // Add the instructions
     auto dpas = std::make_shared<DPASInstruction>();
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 4ef02f32913f7..2416b173e505a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -27,19 +27,15 @@
 namespace mlir {
 namespace xegpu {
 namespace uArch {
-// Architecture HW component hierarchy to present thread, core, socket ...
-struct uArchHierarchyComponent {
-  std::string name = ""; // optional name of the hierarchy component
-  // no. of lower hierarchy component it contains, e.g., for PVC XeCore it
-  // contains 8 threads, so no_of_component=8
-  uint32_t no_of_component;
-  // Constructor
-  uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
-      : name(name), no_of_component(no_of_component) {}
-};
 
 // An enum class to represent the scope of an instruction
-enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
+enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
+
+enum class InstructionName {
+  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add
+        // operation
+  // Add more instructions as needed
+};
 
 // A struct to represent basic information about an instruction
 // This struct is used to represent the information about an instruction in the
@@ -67,69 +63,62 @@ struct Instruction {
   // Get methods
   std::string getName() { return name; }
   std::string getDescription() { return description; }
-  InstructionScopeEnum getScope() { return scope; }
+  InstructionScope getScope() { return scope; }
 
 protected:
   std::string name;
   std::string description;
-  InstructionScopeEnum scope;
+  InstructionScope scope;
 };
 
+enum class RegisterFileMode : uint8_t { Small, Large };
+enum class RegisterFileType : uint8_t { GRF, ARF };
+
 // A struct to represent register file information
 struct RegisterFileInfo {
   // Constructor
   RegisterFileInfo() = default;
-  RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
-                   const std::vector<uint32_t> &numRegs, uint32_t num_banks,
-                   uint32_t bank_size)
-      : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
-        num_banks(num_banks), bank_size(bank_size) {}
+  RegisterFileInfo(uint32_t size, const std::vector<RegisterFileMode> &mode,
+                   const std::vector<uint32_t> &numRegs)
+      : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
 
-  // Get methods
   uint32_t getSize() const { return size; }
-
-  const std::vector<std::string> &getModes() const { return mode; }
-
+  const std::vector<RegisterFileMode> &getModes() const { return mode; }
   const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
-    return num_regs_per_thread_per_mode;
+    return numRegsPerThreadPerMode;
   }
 
-  uint32_t getNumBanks() const { return num_banks; }
-
-  uint32_t getBankSize() const { return bank_size; }
-
 protected:
-  uint32_t size;                 // size per register in bits
-  std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+  uint32_t size;                      // size per register in bits
+  std::vector<RegisterFileMode> mode; // e.g., "small", "large" GRF modes
   std::vector<uint32_t>
-      num_regs_per_thread_per_mode; // number of registers per thread per mode
-  uint32_t num_banks;
-  uint32_t bank_size;
+      numRegsPerThreadPerMode; // number of registers per thread per mode
+  // TODO: Add more fields as needed (e.g., num_banks, bank_size, num_ports,
+  // port_width, bank_conflicts)
 };
 
+enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
 // A struct to represent cache information
-
 struct CacheInfo {
   // Constructor
   CacheInfo(uint32_t size, uint32_t line_size,
-            const uArchHierarchyComponent &component)
-      : size(size), line_size(line_size), component(component) {}
+            CacheHierarchyLevel hierarchy_level)
+      : size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}
 
   virtual ~CacheInfo() = default;
 
   // Get methods
   uint32_t getSize() const { return size; }
   uint32_t getLineSize() const { return line_size; }
-  const uArchHierarchyComponent &getComponent() const { return component; }
+  CacheHierarchyLevel getHierarchyLevel() const { return hierarchy_level; }
 
 protected:
   uint32_t size;
   uint32_t line_size;
-  // At which component level the cache is shared
-  uArchHierarchyComponent component;
-
+  CacheHierarchyLevel hierarchy_level;
   // @TODO: Add more fields as needed (e.g., associativity, num_banks,
-  // bank_size, num_ports, port_width, bank_conflicts)
+  // bank_size, num_ports, port_width, bank_conflicts, latency, throughput,
+  // bandwidth)
 };
 
 // A struct to represent the uArch
@@ -145,13 +134,13 @@ struct uArch {
   // Constructor
   uArch() = default;
   uArch(const std::string &name, const std::string &description,
-        const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
-        const std::map<std::string, RegisterFileInfo> &register_file_info = {},
+        const std::map<RegisterFileType, RegisterFileInfo> &register_file_info =
+            {},
         const std::vector<CacheInfo> &cache_info = {},
         const std::map<std::string, std::shared_ptr<Instruction>>
             &instructions = {})
-      : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
-        register_file_info(register_file_info), cache_info(cache_info),
+      : name(name), description(description),
+        registerFileInfo(register_file_info), cacheInfo(cache_info),
         instructions(instructions) {}
 
   // Get methods
@@ -159,15 +148,12 @@ struct uArch {
 
   const std::string &getDescription() const { return description; }
 
-  const std::vector<uArchHierarchyComponent> &getHierarchy() const {
-    return uArch_hierarchy;
-  }
-
-  const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
-    return register_file_info;
+  const std::map<RegisterFileType, RegisterFileInfo> &
+  getRegisterFileInfo() const {
+    return registerFileInfo;
   }
 
-  const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+  const std::vector<CacheInfo> &getCacheInfo() const { return cacheInfo; }
 
   const std::map<std::string, std::shared_ptr<Instruction>> &
   getInstructions() const {
@@ -192,9 +178,8 @@ struct uArch {
 protected:
   std::string name; // Similar to target triple
   std::string description;
-  std::vector<uArchHierarchyComponent> uArch_hierarchy;
-  std::map<std::string, RegisterFileInfo> register_file_info;
-  std::vector<CacheInfo> cache_info;
+  std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
+  std::vector<CacheInfo> cacheInfo;
   std::map<std::string, std::shared_ptr<Instruction>> instructions;
 };
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
index 27d44c38317a1..087313c357476 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -26,12 +26,11 @@ namespace mlir {
 namespace xegpu {
 namespace uArch {
 
-enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
 struct MMAInstructionInterface {
   // Get supported Matrix shapes
   virtual std::vector<std::pair<uint32_t, uint32_t>>
-  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
-
+  getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) = 0;
   // @TODO: This method takes an context object as a parameter, this is to
   // create the mlir::Type objects from the same context. Since type objects are
   // uniqued in a specific context, to do things like "aType == bType" (where
@@ -46,7 +45,7 @@ struct MMAInstructionInterface {
   // Untill we have a better solution, we stick to passing context object to
   // this method.
   virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
-                                                    MMAOpndEnum matrixType) = 0;
+                                                    MMAOpndKind matrixType) = 0;
   virtual bool
   checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
                                std::pair<uint32_t, uint32_t> BShape,
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
index 4db4300028c46..e19df1e5235ba 100644
--- a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -16,7 +16,7 @@ namespace Xe2Plus {
 
 std::vector<std::pair<uint32_t, uint32_t>>
 DPASInstruction::getSupportedShapes(mlir::Type dataType,
-                                    MMAOpndEnum matrixType) {
+                                    MMAOpndKind matrixType) {
   auto combineVectors = [](const std::vector<uint32_t> &a,
                            const std::vector<uint32_t> &b)
       -> std::vector<std::pair<uint32_t, uint32_t>> {
@@ -35,16 +35,16 @@ DPASInstruction::getSupportedShapes(mlir::Type dataType,
   std::vector<std::pair<unsigned, unsigned>> resultMatrix;
 
   switch (matrixType) {
-  case MMAOpndEnum::MatrixA:
+  case MMAOpndKind::MatrixA:
     resultMatrix = combineVectors(M, K);
     break;
-  case MMAOpndEnum::MatrixB:
+  case MMAOpndKind::MatrixB:
     resultMatrix = combineVectors(K, N);
     break;
-  case MMAOpndEnum::MatrixC:
+  case MMAOpndKind::MatrixC:
     resultMatrix = combineVectors(M, N);
     break;
-  case MMAOpndEnum::MatrixD:
+  case MMAOpndKind::MatrixD:
     resultMatrix = combineVectors(M, N);
     break;
   }
@@ -53,23 +53,23 @@ DPASInstruction::getSupportedShapes(mlir::Type dataType,
 
 std::vector<mlir::Type>
 DPASInstruction::getSupportedTypes(MLIRContext &context,
-                                   MMAOpndEnum matrixType) {
+                                   MMAOpndKind matrixType) {
   mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
   mlir::Type f16Type = mlir::Float16Type::get(&context);
   mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
   mlir::Type f32Type = mlir::Float32Type::get(&context);
 
   switch (matrixType) {
-  case MMAOpndEnum::MatrixA:
+  case MMAOpndKind::MatrixA:
     return {bf16Type, f16Type, tf32Type};
     break;
-  case MMAOpndEnum::MatrixB:
+  case MMAOpndKind::MatrixB:
     return {bf16Type, f16Type, tf32Type};
     break;
-  case MMAOpndEnum::MatrixC:
+  case MMAOpndKind::MatrixC:
     return {bf16Type, f16Type, f32Type};
     break;
-  case MMAOpndEnum::MatrixD:
+  case MMAOpndKind::MatrixD:
     return {bf16Type, f16Type, f32Type};
     break;
   }
@@ -110,10 +110,10 @@ bool DPASInstruction::checkSupportedShapesAndTypes(
     std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
     std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
     mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
-  auto supportedAShapes = getSupportedShapes(AType, MMAOpndEnum::MatrixA);
-  auto supportedBShapes = getSupportedShapes(BType, MMAOpndEnum::MatrixB);
-  auto supportedCShapes = getSupportedShapes(CType, MMAOpndEnum::MatrixC);
-  auto supportedDShapes = getSupportedShapes(DType, MMAOpndEnum::MatrixD);
+  auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
+  auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
+  auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
+  auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
   return llvm::is_contained(supportedAShapes, AShape) &&
          llvm::is_contained(supportedBShapes, BShape) &&
          llvm::is_contained(supportedCShapes, CShape) &&

>From bba25dd7c6cf23042b8d19f19a8b8d420e03ce5d Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 24 Sep 2025 17:43:59 +0000
Subject: [PATCH 04/13] Address review comments.

Simplify design:
  - Remove the uArch registration from dialect initialization, along with
    the uArchMap mechanism that supported it.
---
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 45 -------------------
 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp    |  8 ----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  8 ----
 3 files changed, 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 2416b173e505a..ed9e842c56b62 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -199,51 +199,6 @@ struct SharedMemory {
   // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
 };
 
-struct uArchMap {
-public:
-  // Singleton instance
-  static uArchMap &instance() {
-    static uArchMap instance;
-    return instance;
-  }
-
-  // Insert or update a key-value pair
-  void insert(const std::string &key, std::shared_ptr<uArch> value) {
-    std::unique_lock<std::shared_mutex> lock(mutex_);
-    // map_[key] = std::move(value); // safe to overwrite
-    map_.emplace(key, std::move(value)); // safe to overwrite
-  }
-
-  // Get a value by key (concurrent safe read)
-  std::shared_ptr<uArch> get(const std::string &key) const {
-    std::shared_lock<std::shared_mutex> lock(mutex_);
-    auto it = map_.find(key);
-    if (it != map_.end())
-      return it->second;
-    return nullptr;
-  }
-
-  // Check if a key exists
-  bool contains(const std::string &key) const {
-    std::shared_lock<std::shared_mutex> lock(mutex_);
-    return map_.find(key) != map_.end();
-  }
-
-  // Remove a key
-  bool erase(const std::string &key) {
-    std::unique_lock<std::shared_mutex> lock(mutex_);
-    return map_.erase(key) > 0;
-  }
-
-private:
-  uArchMap() = default;
-  uArchMap(const uArchMap &) = delete;
-  uArchMap &operator=(const uArchMap &) = delete;
-
-  mutable std::shared_mutex mutex_;
-  std::map<std::string, std::shared_ptr<uArch>> map_;
-};
-
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 37ab1fcdd1c0e..6af69bb549bea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -381,14 +381,6 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
 }
 
 void XeVMDialect::initialize() {
-  // Populate the uArchMap with the supported target devices
-  auto pvcuArch =
-      std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
-  mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
-  auto bmguArch =
-      std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
-  mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
-
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 359517432d82e..9beb22d517473 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -36,14 +36,6 @@ void XeGPUDialect::initialize() {
 #define GET_ATTRDEF_LIST
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
       >();
-
-  // Populate the uArchMap with the supported target devices
-  auto pvcuArch =
-      std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
-  mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
-  auto bmguArch =
-      std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
-  mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
 }
 
 /// Generates instructions to compute offsets for a subgroup identified by

>From b0e6f3400e265c35491c7b4931c9e34191cffde6 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 24 Sep 2025 23:22:18 +0000
Subject: [PATCH 05/13] Address review comments.

Move all the implementation to the .h file.
Move uArchInterfaces to uArchBase.
---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 192 ++++++++++++++++--
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      |  43 ++++
 .../Dialect/XeGPU/uArch/uArchInterfaces.h     |  74 -------
 mlir/lib/Dialect/LLVMIR/CMakeLists.txt        |   1 -
 mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp    |   1 -
 mlir/lib/Dialect/XeGPU/CMakeLists.txt         |   1 -
 mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt      |   1 -
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 -
 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt   |  11 -
 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp  | 172 ----------------
 10 files changed, 213 insertions(+), 284 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
 delete mode 100644 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
 delete mode 100644 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index a7ae0bdf4378b..7b1381e9efc2d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -15,17 +15,22 @@
 #ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
 #define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
 
-#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/DebugLog.h"
 #include <map>
 #include <string>
 #include <vector>
 
+#define DEBUG_TYPE "xegpu-uarch"
+
+using namespace mlir;
+using namespace mlir::xegpu::uArch;
+
 namespace mlir {
 namespace xegpu {
 namespace uArch {
-namespace Xe2Plus {
 struct XeCoreInfo {
   uint32_t num_threads;
   SharedMemory shared_memory;
@@ -61,30 +66,27 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
 
   // Override all virtuals from MatrixOpInterface
   virtual std::vector<std::pair<uint32_t, uint32_t>>
-  getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override;
-  virtual std::vector<mlir::Type>
-  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
+  virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
+                                              MMAOpndKind matrixType) override;
   virtual bool
   checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
                                std::pair<uint32_t, uint32_t> BShape,
                                std::pair<uint32_t, uint32_t> CShape,
-                               std::pair<uint32_t, uint32_t> DShape,
-                               mlir::Type AType, mlir::Type BType,
-                               mlir::Type CType, mlir::Type DType) override;
-  virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
-                                   mlir::Type CType, mlir::Type DType) override;
+                               std::pair<uint32_t, uint32_t> DShape, Type AType,
+                               Type BType, Type CType, Type DType) override;
+  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                                   Type DType) override;
   virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
                         std::pair<uint32_t, uint32_t> BShape,
                         std::pair<uint32_t, uint32_t> CShape,
-                        std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
-                        mlir::Type BType, mlir::Type CType,
-                        mlir::Type DType) override;
-  virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override;
-  virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override;
-  virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+                        std::pair<uint32_t, uint32_t> DShape, Type AType,
+                        Type BType, Type CType, Type DType) override;
+  virtual std::vector<uint32_t> getSupportedM(Type type) override;
+  virtual std::vector<uint32_t> getSupportedK(Type type) override;
+  virtual std::vector<uint32_t> getSupportedN(Type type) override;
 };
 
-namespace PVCuArch {
 struct PVCuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
   std::vector<std::shared_ptr<Instruction>> owned_instructions;
@@ -120,9 +122,7 @@ struct PVCuArch : public Xe2Plus {
     owned_instructions.push_back(dpas);
   }
 };
-} // namespace PVCuArch
 
-namespace BMGuArch {
 struct BMGuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
   std::vector<std::shared_ptr<Instruction>> owned_instructions;
@@ -156,11 +156,159 @@ struct BMGuArch : public Xe2Plus {
     owned_instructions.push_back(dpas);
   }
 };
-} // namespace BMGuArch
-
-} // namespace Xe2Plus
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
 
+inline std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+  auto combineVectors = [](const std::vector<uint32_t> &a,
+                           const std::vector<uint32_t> &b)
+      -> std::vector<std::pair<uint32_t, uint32_t>> {
+    std::vector<std::pair<uint32_t, uint32_t>> result;
+    for (unsigned x : a) {
+      for (unsigned y : b) {
+        result.emplace_back(x, y);
+      }
+    }
+    return result;
+  };
+
+  auto M = getSupportedM(dataType);
+  auto K = getSupportedK(dataType);
+  auto N = getSupportedN(dataType);
+  std::vector<std::pair<unsigned, unsigned>> resultMatrix;
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    resultMatrix = combineVectors(M, K);
+    break;
+  case MMAOpndKind::MatrixB:
+    resultMatrix = combineVectors(K, N);
+    break;
+  case MMAOpndKind::MatrixC:
+    resultMatrix = combineVectors(M, N);
+    break;
+  case MMAOpndKind::MatrixD:
+    resultMatrix = combineVectors(M, N);
+    break;
+  }
+  return resultMatrix;
+}
+
+inline std::vector<Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndKind matrixType) {
+  Type bf16Type = BFloat16Type::get(&context);
+  Type f16Type = Float16Type::get(&context);
+  Type tf32Type = FloatTF32Type::get(&context);
+  Type f32Type = Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    return {bf16Type, f16Type, tf32Type};
+    break;
+  case MMAOpndKind::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+    break;
+  case MMAOpndKind::MatrixC:
+    return {bf16Type, f16Type, f32Type};
+    break;
+  case MMAOpndKind::MatrixD:
+    return {bf16Type, f16Type, f32Type};
+    break;
+  }
+}
+
+inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
+                                                 Type CType, Type DType) {
+  if (AType.isF16() || BType.isF16()) {
+    if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
+        (!DType.isF32() && !DType.isF16())) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  } else if (AType.isBF16() || BType.isBF16()) {
+    if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
+        (!DType.isF32() && !DType.isBF16())) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  } else if (AType.isTF32() || BType.isTF32()) {
+    // TF32 inputs: Acc (when present) and Dst must each be F32 on their own.
+    if (AType != BType || (CType && !CType.isF32()) || !DType.isF32()) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
+               AType.isInteger(8)) ||
+             !(BType.isInteger(2) || BType.isInteger(4) ||
+               BType.isInteger(8))) {
+    // Neither a float case above nor a pair of 2/4/8-bit integer operands.
+    LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+    return false;
+  }
+  return true;
+}
+
+inline bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
+  // Types are checked first; C's shape is skipped when it has no type.
+  auto fits = [this](Type type, MMAOpndKind kind, auto shape) {
+    return llvm::is_contained(getSupportedShapes(type, kind), shape);
+  };
+  return checkSupportedTypes(AType, BType, CType, DType) &&
+         fits(AType, MMAOpndKind::MatrixA, AShape) &&
+         fits(BType, MMAOpndKind::MatrixB, BShape) &&
+         (!CType || fits(CType, MMAOpndKind::MatrixC, CShape)) &&
+         fits(DType, MMAOpndKind::MatrixD, DShape);
+}
+
+inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
+                                      std::pair<uint32_t, uint32_t> BShape,
+                                      std::pair<uint32_t, uint32_t> CShape,
+                                      std::pair<uint32_t, uint32_t> DShape,
+                                      Type AType, Type BType, Type CType,
+                                      Type DType) {
+  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
+                                      BType, CType, DType);
+}
+
+inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) {
+  return {1, 2, 3, 4, 5, 6, 7, 8};
+}
+
+inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
+  // assert if data type is not int or float type
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  auto bitWidth = type.getIntOrFloatBitWidth();
+  uint32_t kSize = 0;
+  switch (bitWidth) {
+  case 2:
+    kSize = 64;
+    break;
+  case 4:
+    kSize = 64;
+    break;
+  case 8:
+    kSize = 32;
+    break;
+  case 16:
+    kSize = 16;
+    break;
+  case 32:
+    kSize = 8;
+    break;
+  default:
+    llvm_unreachable("Invalid int or float");
+  }
+  return {kSize};
+}
+
+inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) {
+  return {16};
+}
+
 #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index ed9e842c56b62..7c76157fa2a1a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -199,6 +199,49 @@ struct SharedMemory {
   // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
 };
 
+//===----------------------------------------------------------------------===//
+// Interfaces
+//===----------------------------------------------------------------------===//
+enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
+struct MMAInstructionInterface {
+  // Get supported Matrix shapes
+  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
+  // @TODO: This method takes an context object as a parameter, this is to
+  // create the Type objects from the same context. Since type objects are
+  // uniqued in a specific context, to do things like "aType == bType" (where
+  // aType and bType are both same type) kind of checks, the both types should
+  // be from the same context.
+  //
+  // One alternative to this is to create enum to represent each types, but this
+  // adds an extra burden to user to convert these enums to specific types. In
+  // fact the utility that would convert enumToType() and vice versa would still
+  // have to use the context object.
+  //
+  // Untill we have a better solution, we stick to passing context object to
+  // this method.
+  virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
+                                              MMAOpndKind matrixType) = 0;
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape, Type AType,
+                               Type BType, Type CType, Type DType) = 0;
+  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                                   Type DType) = 0;
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, Type AType,
+                        Type BType, Type CType, Type DType) = 0;
+  virtual std::vector<uint32_t> getSupportedM(Type type) = 0;
+  virtual std::vector<uint32_t> getSupportedK(Type type) = 0;
+  virtual std::vector<uint32_t> getSupportedN(Type type) = 0;
+
+  virtual ~MMAInstructionInterface() = default;
+};
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
deleted file mode 100644
index 087313c357476..0000000000000
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
+++ /dev/null
@@ -1,74 +0,0 @@
-//===--- uArchInterfaces.h ---*- C++-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-/// \file
-/// Defines the utility interfaces that are implemented by individual
-/// instructions.
-///
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
-#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
-
-#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include <map>
-#include <string>
-#include <vector>
-
-namespace mlir {
-namespace xegpu {
-namespace uArch {
-
-enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
-struct MMAInstructionInterface {
-  // Get supported Matrix shapes
-  virtual std::vector<std::pair<uint32_t, uint32_t>>
-  getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) = 0;
-  // @TODO: This method takes an context object as a parameter, this is to
-  // create the mlir::Type objects from the same context. Since type objects are
-  // uniqued in a specific context, to do things like "aType == bType" (where
-  // aType and bType are both same type) kind of checks, the both types should
-  // be from the same context.
-  //
-  // One alternative to this is to create enum to represent each types, but this
-  // adds an extra burden to user to convert these enums to specific types. In
-  // fact the utility that would convert enumToType() and vice versa would still
-  // have to use the context object.
-  //
-  // Untill we have a better solution, we stick to passing context object to
-  // this method.
-  virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
-                                                    MMAOpndKind matrixType) = 0;
-  virtual bool
-  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
-                               std::pair<uint32_t, uint32_t> BShape,
-                               std::pair<uint32_t, uint32_t> CShape,
-                               std::pair<uint32_t, uint32_t> DShape,
-                               mlir::Type AType, mlir::Type BType,
-                               mlir::Type CType, mlir::Type DType) = 0;
-  virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
-                                   mlir::Type CType, mlir::Type DType) = 0;
-  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
-                        std::pair<uint32_t, uint32_t> BShape,
-                        std::pair<uint32_t, uint32_t> CShape,
-                        std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
-                        mlir::Type BType, mlir::Type CType,
-                        mlir::Type DType) = 0;
-  virtual std::vector<uint32_t> getSupportedM(mlir::Type type) = 0;
-  virtual std::vector<uint32_t> getSupportedK(mlir::Type type) = 0;
-  virtual std::vector<uint32_t> getSupportedN(mlir::Type type) = 0;
-
-  virtual ~MMAInstructionInterface() = default;
-};
-
-} // namespace uArch
-} // namespace xegpu
-} // namespace mlir
-#endif // MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 9e40d41de1c73..ec581ac7277e3 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -129,6 +129,5 @@ add_mlir_dialect_library(MLIRXeVMDialect
   MLIRDialectUtils
   MLIRIR
   MLIRLLVMDialect
-  MLIRXeGPUuArch
   MLIRSideEffectInterfaces
 )
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 6af69bb549bea..04e8836c00359 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -8,7 +8,6 @@
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 9079df050ab2b..31167e6af908b 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,4 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
-add_subdirectory(uArch)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index e1c51deefbee3..7869a28dfed57 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -20,7 +20,6 @@ add_mlir_dialect_library(MLIRXeGPUDialect
   MLIRGPUDialect
   MLIRXeVMDialect
   MLIRIR
-  MLIRXeGPUuArch
   MLIRViewLikeInterface
   MLIRVectorDialect
 )
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e56fe3a1e193d..e6f76067094ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -23,7 +23,6 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRTransforms
   MLIRGPUDialect
   MLIRXeGPUUtils
-  MLIRXeGPUuArch
   MLIRGPUUtils
   MLIRVectorTransforms
 )
diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
deleted file mode 100644
index c7f691cb6dda7..0000000000000
--- a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-add_mlir_dialect_library(MLIRXeGPUuArch
-  IntelGpuXe2.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRDialectUtils
-)
-
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
deleted file mode 100644
index e19df1e5235ba..0000000000000
--- a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
+++ /dev/null
@@ -1,172 +0,0 @@
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "llvm/Support/DebugLog.h"
-#include <algorithm>
-#include <vector>
-
-#define DEBUG_TYPE "xegpu-uarch"
-
-using namespace mlir::xegpu::uArch;
-using namespace mlir::xegpu::uArch::Xe2Plus;
-
-namespace mlir {
-namespace xegpu {
-namespace uArch {
-namespace Xe2Plus {
-
-std::vector<std::pair<uint32_t, uint32_t>>
-DPASInstruction::getSupportedShapes(mlir::Type dataType,
-                                    MMAOpndKind matrixType) {
-  auto combineVectors = [](const std::vector<uint32_t> &a,
-                           const std::vector<uint32_t> &b)
-      -> std::vector<std::pair<uint32_t, uint32_t>> {
-    std::vector<std::pair<uint32_t, uint32_t>> result;
-    for (unsigned x : a) {
-      for (unsigned y : b) {
-        result.emplace_back(x, y);
-      }
-    }
-    return result;
-  };
-
-  auto M = getSupportedM(dataType);
-  auto K = getSupportedK(dataType);
-  auto N = getSupportedN(dataType);
-  std::vector<std::pair<unsigned, unsigned>> resultMatrix;
-
-  switch (matrixType) {
-  case MMAOpndKind::MatrixA:
-    resultMatrix = combineVectors(M, K);
-    break;
-  case MMAOpndKind::MatrixB:
-    resultMatrix = combineVectors(K, N);
-    break;
-  case MMAOpndKind::MatrixC:
-    resultMatrix = combineVectors(M, N);
-    break;
-  case MMAOpndKind::MatrixD:
-    resultMatrix = combineVectors(M, N);
-    break;
-  }
-  return resultMatrix;
-}
-
-std::vector<mlir::Type>
-DPASInstruction::getSupportedTypes(MLIRContext &context,
-                                   MMAOpndKind matrixType) {
-  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
-  mlir::Type f16Type = mlir::Float16Type::get(&context);
-  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
-  mlir::Type f32Type = mlir::Float32Type::get(&context);
-
-  switch (matrixType) {
-  case MMAOpndKind::MatrixA:
-    return {bf16Type, f16Type, tf32Type};
-    break;
-  case MMAOpndKind::MatrixB:
-    return {bf16Type, f16Type, tf32Type};
-    break;
-  case MMAOpndKind::MatrixC:
-    return {bf16Type, f16Type, f32Type};
-    break;
-  case MMAOpndKind::MatrixD:
-    return {bf16Type, f16Type, f32Type};
-    break;
-  }
-}
-
-bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
-                                          mlir::Type CType, mlir::Type DType) {
-  if (AType.isF16() || BType.isF16()) {
-    if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
-        (!DType.isF32() && !DType.isF16())) {
-      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
-      return false;
-    }
-  } else if (AType.isBF16() || BType.isBF16()) {
-    if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
-        (!DType.isF32() && !DType.isBF16())) {
-      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
-      return false;
-    }
-  } else if (AType.isTF32() || BType.isTF32()) {
-    if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
-        (!DType.isF32())) {
-      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
-      return false;
-    }
-  } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
-               AType.isInteger(8)) &&
-             !(BType.isInteger(2) || BType.isInteger(4) ||
-               BType.isInteger(8))) {
-    LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
-    return false;
-  }
-
-  return true;
-}
-
-bool DPASInstruction::checkSupportedShapesAndTypes(
-    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
-    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
-    mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
-  auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
-  auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
-  auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
-  auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
-  return llvm::is_contained(supportedAShapes, AShape) &&
-         llvm::is_contained(supportedBShapes, BShape) &&
-         llvm::is_contained(supportedCShapes, CShape) &&
-         llvm::is_contained(supportedDShapes, DShape) &&
-         checkSupportedTypes(AType, BType, CType, DType);
-}
-
-bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
-                               std::pair<uint32_t, uint32_t> BShape,
-                               std::pair<uint32_t, uint32_t> CShape,
-                               std::pair<uint32_t, uint32_t> DShape,
-                               mlir::Type AType, mlir::Type BType,
-                               mlir::Type CType, mlir::Type DType) {
-  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
-                                      BType, CType, DType);
-}
-
-std::vector<uint32_t> DPASInstruction::getSupportedM(mlir::Type type) {
-  return {1, 2, 3, 4, 5, 6, 7, 8};
-}
-
-std::vector<uint32_t> DPASInstruction::getSupportedK(mlir::Type type) {
-  // assert if data type is not int or float type
-  assert(type.isIntOrFloat() && "Matrix type must be int or float");
-  auto bitWidth = type.getIntOrFloatBitWidth();
-  uint32_t kSize = 0;
-  switch (bitWidth) {
-  case 2:
-    kSize = 64;
-    break;
-  case 4:
-    kSize = 64;
-    break;
-  case 8:
-    kSize = 32;
-    break;
-  case 16:
-    kSize = 16;
-    break;
-  case 32:
-    kSize = 8;
-    break;
-  default:
-    llvm_unreachable("Invalid int or float");
-  }
-  return {kSize};
-}
-
-std::vector<uint32_t> DPASInstruction::getSupportedN(mlir::Type type) {
-  return {16};
-}
-
-} // namespace Xe2Plus
-} // namespace uArch
-} // namespace xegpu
-} // namespace mlir

>From 6098788ed6671769487026d38dbaf8ebd6ff6a3e Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 25 Sep 2025 15:49:17 +0000
Subject: [PATCH 06/13] Address review comments.

Use LLVM data structures whenever possible.
---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 73 ++++++++--------
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 85 ++++++++++++-------
 2 files changed, 88 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 7b1381e9efc2d..453dc57a72020 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -18,10 +18,10 @@
 #include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/DebugLog.h"
 #include <map>
 #include <string>
-#include <vector>
 
 #define DEBUG_TYPE "xegpu-uarch"
 
@@ -47,12 +47,12 @@ struct XeCoreInfo {
 
 struct Xe2Plus : public uArch {
   XeCoreInfo xe_core;
-  Xe2Plus(
-      const std::string &archName, const std::string &archDescription,
-      const XeCoreInfo &xeCore,
-      const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
-      const std::vector<CacheInfo> &cacheInfo = {},
-      const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+  Xe2Plus(const std::string &archName, const std::string &archDescription,
+          const XeCoreInfo &xeCore,
+          const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
+          const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+          const std::map<InstructionKind, std::shared_ptr<Instruction>>
+              &instrs = {})
       : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
         xe_core(xeCore) {}
 };
@@ -60,15 +60,16 @@ struct Xe2Plus : public uArch {
 // struct to represent DPAS instruction
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   DPASInstruction()
-      : Instruction("dpas",                   // name
-                    "Dot Product Accumulate") // description
+      : Instruction(InstructionKind::DPAS,      // kind
+                    "Dot Product Accumulate",   // description
+                    InstructionScope::Subgroup) // scope
   {}
 
   // Override all virtuals from MatrixOpInterface
-  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
   getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
-  virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
-                                              MMAOpndKind matrixType) override;
+  virtual llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
   virtual bool
   checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
                                std::pair<uint32_t, uint32_t> BShape,
@@ -82,14 +83,14 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
                         std::pair<uint32_t, uint32_t> CShape,
                         std::pair<uint32_t, uint32_t> DShape, Type AType,
                         Type BType, Type CType, Type DType) override;
-  virtual std::vector<uint32_t> getSupportedM(Type type) override;
-  virtual std::vector<uint32_t> getSupportedK(Type type) override;
-  virtual std::vector<uint32_t> getSupportedN(Type type) override;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
 };
 
 struct PVCuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
-  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
   PVCuArch()
       : Xe2Plus("pvc",                        // archName
                 "Ponte Vecchio Architecture", // archDescription
@@ -115,17 +116,16 @@ struct PVCuArch : public Xe2Plus {
     this->cacheInfo.push_back(
         CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
 
-    // Add the instructions
+    // Add the instructions-
     auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getName(), dpas);
-    // instructions[dpas->name] = dpas.get();
+    instructions.emplace(dpas->getInstructionKind(), dpas);
     owned_instructions.push_back(dpas);
   }
 };
 
 struct BMGuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
-  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
   BMGuArch()
       : Xe2Plus("bmg",                     // archName
                 "Battlemage Architecture", // archDescription
@@ -151,8 +151,7 @@ struct BMGuArch : public Xe2Plus {
 
     // Add the instructions
     auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getName(), dpas);
-    // instructions[dpas->name] = dpas.get();
+    instructions.emplace(dpas->getInstructionKind(), dpas);
     owned_instructions.push_back(dpas);
   }
 };
@@ -160,12 +159,12 @@ struct BMGuArch : public Xe2Plus {
 } // namespace xegpu
 } // namespace mlir
 
-inline std::vector<std::pair<uint32_t, uint32_t>>
+inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
 DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
-  auto combineVectors = [](const std::vector<uint32_t> &a,
-                           const std::vector<uint32_t> &b)
-      -> std::vector<std::pair<uint32_t, uint32_t>> {
-    std::vector<std::pair<uint32_t, uint32_t>> result;
+  auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
+                           const llvm::SmallVector<uint32_t, 8> &b)
+      -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
+    llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> result;
     for (unsigned x : a) {
       for (unsigned y : b) {
         result.emplace_back(x, y);
@@ -177,7 +176,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
   auto M = getSupportedM(dataType);
   auto K = getSupportedK(dataType);
   auto N = getSupportedN(dataType);
-  std::vector<std::pair<unsigned, unsigned>> resultMatrix;
+  llvm::SmallVector<std::pair<unsigned, unsigned>, 16> resultMatrix;
 
   switch (matrixType) {
   case MMAOpndKind::MatrixA:
@@ -196,7 +195,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
   return resultMatrix;
 }
 
-inline std::vector<Type>
+inline llvm::SmallVector<Type, 8>
 DPASInstruction::getSupportedTypes(MLIRContext &context,
                                    MMAOpndKind matrixType) {
   Type bf16Type = BFloat16Type::get(&context);
@@ -207,17 +206,14 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
   switch (matrixType) {
   case MMAOpndKind::MatrixA:
     return {bf16Type, f16Type, tf32Type};
-    break;
   case MMAOpndKind::MatrixB:
     return {bf16Type, f16Type, tf32Type};
-    break;
   case MMAOpndKind::MatrixC:
     return {bf16Type, f16Type, f32Type};
-    break;
   case MMAOpndKind::MatrixD:
     return {bf16Type, f16Type, f32Type};
-    break;
   }
+  return {};
 }
 
 inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
@@ -276,11 +272,13 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
                                       BType, CType, DType);
 }
 
-inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) {
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedM(Type type) {
   return {1, 2, 3, 4, 5, 6, 7, 8};
 }
 
-inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedK(Type type) {
   // assert if data type is not int or float type
   assert(type.isIntOrFloat() && "Matrix type must be int or float");
   auto bitWidth = type.getIntOrFloatBitWidth();
@@ -307,8 +305,9 @@ inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
   return {kSize};
 }
 
-inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) {
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedN(Type type) {
   return {16};
 }
 
-#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 7c76157fa2a1a..fe0e0f6528042 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -23,6 +23,7 @@
 #include <tuple>
 
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 namespace xegpu {
@@ -31,12 +32,26 @@ namespace uArch {
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
 
-enum class InstructionName {
-  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add
-        // operation
+enum class InstructionKind {
+  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
+        // multiply-add operation
   // Add more instructions as needed
 };
 
+inline llvm::StringRef toString(InstructionKind name) { // header-defined
+  switch (name) {
+  case InstructionKind::DPAS:
+    return "dpas"; // Same spelling parseInstructionKind accepts.
+  }
+  llvm_unreachable("Unknown InstructionKind");
+}
+
+inline std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
+  if (str.equals_insensitive("dpas")) // Case-insensitive match.
+    return InstructionKind::DPAS;
+  return std::nullopt; // Not a known instruction name.
+}
+
 // A struct to represent basic information about an instruction
 // This struct is used to represent the information about an instruction in the
 // uArch The information includes:
@@ -56,17 +71,17 @@ enum class InstructionName {
 
 struct Instruction {
   // @TODO: Add more fields as needed
-  Instruction(std::string name, std::string desc)
-      : name(std::move(name)), description(std::move(desc)) {}
+  Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
+      : instKind(kind), description(std::move(desc)), scope(scope) {}
 
   virtual ~Instruction() = default;
   // Get methods
-  std::string getName() { return name; }
+  InstructionKind getInstructionKind() { return instKind; }
   std::string getDescription() { return description; }
   InstructionScope getScope() { return scope; }
 
 protected:
-  std::string name;
+  InstructionKind instKind;
   std::string description;
   InstructionScope scope;
 };
@@ -78,23 +93,25 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
 struct RegisterFileInfo {
   // Constructor
   RegisterFileInfo() = default;
-  RegisterFileInfo(uint32_t size, const std::vector<RegisterFileMode> &mode,
-                   const std::vector<uint32_t> &numRegs)
+  RegisterFileInfo(uint32_t size,
+                   const llvm::SmallVector<RegisterFileMode, 4> &mode,
+                   const llvm::SmallVector<uint32_t, 4> &numRegs)
       : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
 
   uint32_t getSize() const { return size; }
-  const std::vector<RegisterFileMode> &getModes() const { return mode; }
-  const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+  const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
+    return mode;
+  }
+  const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
     return numRegsPerThreadPerMode;
   }
 
 protected:
-  uint32_t size;                      // size per register in bits
-  std::vector<RegisterFileMode> mode; // e.g., "small", "large" GRF modes
-  std::vector<uint32_t>
+  uint32_t size; // size per register in bits
+  llvm::SmallVector<RegisterFileMode, 4>
+      mode; // e.g., "small", "large" GRF modes
+  llvm::SmallVector<uint32_t, 4>
       numRegsPerThreadPerMode; // number of registers per thread per mode
-  // TODO: Add more fields as needed (e.g., num_banks, bank_size, num_ports,
-  // port_width, bank_conflicts)
 };
 
 enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
@@ -136,8 +153,8 @@ struct uArch {
   uArch(const std::string &name, const std::string &description,
         const std::map<RegisterFileType, RegisterFileInfo> &register_file_info =
             {},
-        const std::vector<CacheInfo> &cache_info = {},
-        const std::map<std::string, std::shared_ptr<Instruction>>
+        const llvm::SmallVector<CacheInfo, 4> &cache_info = {},
+        const std::map<InstructionKind, std::shared_ptr<Instruction>>
             &instructions = {})
       : name(name), description(description),
         registerFileInfo(register_file_info), cacheInfo(cache_info),
@@ -153,34 +170,36 @@ struct uArch {
     return registerFileInfo;
   }
 
-  const std::vector<CacheInfo> &getCacheInfo() const { return cacheInfo; }
+  const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
+    return cacheInfo;
+  }
 
-  const std::map<std::string, std::shared_ptr<Instruction>> &
+  const std::map<InstructionKind, std::shared_ptr<Instruction>> &
   getInstructions() const {
     return instructions;
   }
 
   // Get the name of the supported instruction names for that
   // architecture. It returns the names of the instructions added to the uArch.
-  std::vector<std::string> getSupportedInstructionNames() const {
-    std::vector<std::string> instructionNames;
+  llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
+    llvm::SmallVector<StringRef, 8> instructionNames;
     for (const auto &inst : instructions) {
-      instructionNames.push_back(inst.first);
+      instructionNames.push_back(toString(inst.first));
     }
     return instructionNames;
   }
 
   // Checks if an instruction is supported in this uArch
-  bool checkSupportedInstruction(const std::string &instructionName) const {
-    return instructions.find(instructionName) != instructions.end();
+  bool checkSupportedInstruction(InstructionKind instr) const {
+    return instructions.find(instr) != instructions.end();
   }
 
 protected:
   std::string name; // Similar to target triple
   std::string description;
   std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
-  std::vector<CacheInfo> cacheInfo;
-  std::map<std::string, std::shared_ptr<Instruction>> instructions;
+  llvm::SmallVector<CacheInfo, 4> cacheInfo;
+  std::map<InstructionKind, std::shared_ptr<Instruction>> instructions;
 };
 
 // A struct to represent shared memory information
@@ -205,7 +224,7 @@ struct SharedMemory {
 enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
 struct MMAInstructionInterface {
   // Get supported Matrix shapes
-  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
   getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
   // @TODO: This method takes an context object as a parameter, this is to
   // create the Type objects from the same context. Since type objects are
@@ -220,8 +239,8 @@ struct MMAInstructionInterface {
   //
   // Untill we have a better solution, we stick to passing context object to
   // this method.
-  virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
-                                              MMAOpndKind matrixType) = 0;
+  virtual llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
   virtual bool
   checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
                                std::pair<uint32_t, uint32_t> BShape,
@@ -235,9 +254,9 @@ struct MMAInstructionInterface {
                         std::pair<uint32_t, uint32_t> CShape,
                         std::pair<uint32_t, uint32_t> DShape, Type AType,
                         Type BType, Type CType, Type DType) = 0;
-  virtual std::vector<uint32_t> getSupportedM(Type type) = 0;
-  virtual std::vector<uint32_t> getSupportedK(Type type) = 0;
-  virtual std::vector<uint32_t> getSupportedN(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
 
   virtual ~MMAInstructionInterface() = default;
 };

>From 8fb4f1794d309b8d9385d2fb468148a38f6161a0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 25 Sep 2025 15:58:40 +0000
Subject: [PATCH 07/13] Add/Remove some spacings.

---
 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index fe0e0f6528042..8f38b2afede4d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -31,7 +31,6 @@ namespace uArch {
 
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
-
 enum class InstructionKind {
   DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
         // multiply-add operation
@@ -68,7 +67,6 @@ std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
 // represent information about an instruction and to use this information to
 // generate the uArch. Specifc instruction in a uArch can inherit from this
 // struct and add more fields as needed
-
 struct Instruction {
   // @TODO: Add more fields as needed
   Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
@@ -115,6 +113,7 @@ struct RegisterFileInfo {
 };
 
 enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
+
 // A struct to represent cache information
 struct CacheInfo {
   // Constructor

>From 31b93d6c2daee68aab6011fa56ff657b500c2e4a Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 25 Sep 2025 16:38:48 +0000
Subject: [PATCH 08/13] Address review comments.

---
 mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 6 +-----
 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h   | 7 +++----
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 453dc57a72020..19c9138055b27 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -37,7 +37,6 @@ struct XeCoreInfo {
   uint32_t num_vector_units;
   uint32_t num_matrix_units;
 
-  // Constructor
   XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
              uint32_t num_vector_units, uint32_t num_matrix_units)
       : num_threads(num_threads), shared_memory(shared_memory),
@@ -60,10 +59,7 @@ struct Xe2Plus : public uArch {
 // struct to represent DPAS instruction
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   DPASInstruction()
-      : Instruction(InstructionKind::DPAS, // name
-                    "Dot Product Accumulate",
-                    InstructionScope::Subgroup) // description
-  {}
+      : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
 
   // Override all virtuals from MatrixOpInterface
   virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 8f38b2afede4d..69db5d5b5367f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -69,13 +69,12 @@ std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
 // struct and add more fields as needed
 struct Instruction {
   // @TODO: Add more fields as needed
-  Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
-      : instKind(kind), description(std::move(desc)), scope(scope) {}
+  Instruction(InstructionKind kind, InstructionScope scope)
+      : instKind(kind), scope(scope) {}
 
   virtual ~Instruction() = default;
   // Get methods
   InstructionKind getInstructionKind() { return instKind; }
-  std::string getDescription() { return description; }
   InstructionScope getScope() { return scope; }
 
 protected:
@@ -143,7 +142,7 @@ struct CacheInfo {
 // - the name of the uArch,
 // - the description of the uArch,
 // - uArch hierarchy
-// - Rgister File information
+// - Register File information
 // - Cache information
 // - the set of instructions supported by the uArch,
 struct uArch {

>From fa6c4bcfb31746dac05a15be0fb7071e15585639 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 17:09:33 +0000
Subject: [PATCH 09/13] Address review comments.

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 28 +++-----
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 70 +++++++++----------
 2 files changed, 41 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 19c9138055b27..0519f7b2e277d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -31,21 +31,9 @@ using namespace mlir::xegpu::uArch;
 namespace mlir {
 namespace xegpu {
 namespace uArch {
-struct XeCoreInfo {
-  uint32_t num_threads;
-  SharedMemory shared_memory;
-  uint32_t num_vector_units;
-  uint32_t num_matrix_units;
-
-  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
-             uint32_t num_vector_units, uint32_t num_matrix_units)
-      : num_threads(num_threads), shared_memory(shared_memory),
-        num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
-  }
-};
 
 struct Xe2Plus : public uArch {
-  XeCoreInfo xe_core;
+  XeCoreInfo xeCore;
   Xe2Plus(const std::string &archName, const std::string &archDescription,
           const XeCoreInfo &xeCore,
           const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
@@ -53,7 +41,7 @@ struct Xe2Plus : public uArch {
           const std::map<InstructionKind, std::shared_ptr<Instruction>>
               &instrs = {})
       : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
-        xe_core(xeCore) {}
+        xeCore(xeCore) {}
 };
 
 // struct to represent DPAS instruction
@@ -91,9 +79,9 @@ struct PVCuArch : public Xe2Plus {
       : Xe2Plus("pvc",                        // archName
                 "Ponte Vecchio Architecture", // archDescription
                 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
-                {/* register_file_info */}, // Optional: empty
-                {/* cache_info */},         // Optional: empty
-                {/* instructions */}        // Optional: empty
+                {/* registerFileInfo */}, // Optional: empty
+                {/* cacheInfo */},        // Optional: empty
+                {/* instructions */}      // Optional: empty
         ) {
     // Intialize register file info
     // GRF
@@ -126,9 +114,9 @@ struct BMGuArch : public Xe2Plus {
       : Xe2Plus("bmg",                     // archName
                 "Battlemage Architecture", // archDescription
                 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
-                {/* register_file_info */}, // Optional: empty
-                {/* cache_info */},         // Optional: empty
-                {/* instructions */}        // Optional: empty)
+                {/* registerFileInfo */}, // Optional: empty
+                {/* cacheInfo */},        // Optional: empty
+                {/* instructions */}      // Optional: empty
         ) {
     // Intialize register file info
     // GRF
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 69db5d5b5367f..bb47a3e138e7b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -1,4 +1,4 @@
-//===--- uArch.h ------------------------------------------------*- C++ -*-===//
+//===- uArch.h --------------------------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -34,7 +34,7 @@ enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
   DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
         // multiply-add operation
-  // Add more instructions as needed
+  // @TODO: Add more instructions as needed
 };
 
 llvm::StringRef toString(InstructionKind name) {
@@ -51,24 +51,12 @@ std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
   return std::nullopt;
 }
 
-// A struct to represent basic information about an instruction
-// This struct is used to represent the information about an instruction in the
-// uArch The information includes:
-// - the name of the instruction,
-// - the description of the instruction
-// - the scope of the instruction,
-//
-// The information is represented as strings
-// For example, the information about an instruction can be represented as:
-// Instruction instr = {"dpas", "Dot Product Accumulate Systolic  (DPAS) is a
-// matrix multiply-add operation", "subgroup"};
-
+// A struct to represent basic information about an instruction.
 // The primary purpose of the Instruction struct is to provide a generic way to
 // represent information about an instruction and to use this information to
 // generate the uArch. Specifc instruction in a uArch can inherit from this
-// struct and add more fields as needed
+// struct and add more fields as needed.
 struct Instruction {
-  // @TODO: Add more fields as needed
   Instruction(InstructionKind kind, InstructionScope scope)
       : instKind(kind), scope(scope) {}
 
@@ -79,8 +67,8 @@ struct Instruction {
 
 protected:
   InstructionKind instKind;
-  std::string description;
   InstructionScope scope;
+  // @TODO: Add more fields as needed
 };
 
 enum class RegisterFileMode : uint8_t { Small, Large };
@@ -89,16 +77,18 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
 // A struct to represent register file information
 struct RegisterFileInfo {
   // Constructor
-  RegisterFileInfo() = default;
   RegisterFileInfo(uint32_t size,
                    const llvm::SmallVector<RegisterFileMode, 4> &mode,
                    const llvm::SmallVector<uint32_t, 4> &numRegs)
       : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
 
+  // Get methods
   uint32_t getSize() const { return size; }
+
   const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
     return mode;
   }
+
   const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
     return numRegsPerThreadPerMode;
   }
@@ -137,25 +127,17 @@ struct CacheInfo {
 };
 
 // A struct to represent the uArch
-// This struct is used to represent the microarchitecture of a target device
-// The uArch includes:
-// - the name of the uArch,
-// - the description of the uArch,
-// - uArch hierarchy
-// - Register File information
-// - Cache information
-// - the set of instructions supported by the uArch,
+// This struct is used to represent the microarchitecture of a target device.
 struct uArch {
   // Constructor
-  uArch() = default;
-  uArch(const std::string &name, const std::string &description,
-        const std::map<RegisterFileType, RegisterFileInfo> &register_file_info =
-            {},
-        const llvm::SmallVector<CacheInfo, 4> &cache_info = {},
-        const std::map<InstructionKind, std::shared_ptr<Instruction>>
-            &instructions = {})
+  uArch(
+      const std::string &name, const std::string &description,
+      const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
+      const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+      const std::map<InstructionKind, std::shared_ptr<Instruction>>
+          &instructions = {})
       : name(name), description(description),
-        registerFileInfo(register_file_info), cacheInfo(cache_info),
+        registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
         instructions(instructions) {}
 
   // Get methods
@@ -193,11 +175,12 @@ struct uArch {
   }
 
 protected:
-  std::string name; // Similar to target triple
+  std::string name; // Name of the uArch, similar to target triple
   std::string description;
   std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
   llvm::SmallVector<CacheInfo, 4> cacheInfo;
-  std::map<InstructionKind, std::shared_ptr<Instruction>> instructions;
+  std::map<InstructionKind, std::shared_ptr<Instruction>>
+      instructions; // set of instructions supported by the uArch
 };
 
 // A struct to represent shared memory information
@@ -206,7 +189,7 @@ struct SharedMemory {
   SharedMemory(uint32_t size, uint32_t alignment)
       : size(size), alignment(alignment) {}
 
-  // Getters
+  // Get methods
   uint32_t getSize() const { return size; }
   uint32_t getAlignment() const { return alignment; }
 
@@ -216,6 +199,19 @@ struct SharedMemory {
   // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
 };
 
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+             uint32_t num_vector_units, uint32_t num_matrix_units)
+      : num_threads(num_threads), shared_memory(shared_memory),
+        num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Interfaces
 //===----------------------------------------------------------------------===//

>From 2fc129a8444fb698b7e79071dc7844d9723d8218 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 17:13:37 +0000
Subject: [PATCH 10/13] Address review comments.

---
 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index bb47a3e138e7b..64aa25ffd543f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -30,7 +30,7 @@ namespace xegpu {
 namespace uArch {
 
 // An enum class to represent the scope of an instruction
-enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
+enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
   DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
         // multiply-add operation
@@ -66,8 +66,9 @@ struct Instruction {
   InstructionScope getScope() { return scope; }
 
 protected:
-  InstructionKind instKind;
-  InstructionScope scope;
+  InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+  InstructionScope scope;   // scope of the instruction (e.g., lane, subgroup,
+                            // workgroup, cluster)
   // @TODO: Add more fields as needed
 };
 

>From 28903cb8d738be3a0c570f5cc60e1fd4ca57966b Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 17:53:10 +0000
Subject: [PATCH 11/13] Fix a small compile error.

---
 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 64aa25ffd543f..48d2302994592 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -78,6 +78,7 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
 // A struct to represent register file information
 struct RegisterFileInfo {
   // Constructor
+  RegisterFileInfo() = default;
   RegisterFileInfo(uint32_t size,
                    const llvm::SmallVector<RegisterFileMode, 4> &mode,
                    const llvm::SmallVector<uint32_t, 4> &numRegs)
@@ -107,6 +108,7 @@ enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
 // A struct to represent cache information
 struct CacheInfo {
   // Constructor
+  CacheInfo() = default;
   CacheInfo(uint32_t size, uint32_t line_size,
             CacheHierarchyLevel hierarchy_level)
       : size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}

>From 6ef33aec53a11f1eaa3191be3d5fa26d1099232c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 8 Oct 2025 12:50:31 +0000
Subject: [PATCH 12/13] Address review comments.

---
 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 48d2302994592..7341bbb0ed638 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -37,7 +37,7 @@ enum class InstructionKind {
   // @TODO: Add more instructions as needed
 };
 
-llvm::StringRef toString(InstructionKind name) {
+static llvm::StringRef toString(InstructionKind name) {
   switch (name) {
   case InstructionKind::DPAS:
     return "dpas";
@@ -45,7 +45,8 @@ llvm::StringRef toString(InstructionKind name) {
   llvm_unreachable("Unknown InstructionKind");
 }
 
-std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
+static std::optional<InstructionKind>
+parseInstructionKind(llvm::StringRef str) {
   if (str.equals_insensitive("dpas"))
     return InstructionKind::DPAS;
   return std::nullopt;

>From 1ffe77eb18c25070a4994701be8fa4365b26548c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 8 Oct 2025 13:25:32 +0000
Subject: [PATCH 13/13] Address review comments.

Make some helper funcs static methods.
---
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 31 +++++++++----------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 7341bbb0ed638..955994ea5ecf5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -37,21 +37,6 @@ enum class InstructionKind {
   // @TODO: Add more instructions as needed
 };
 
-static llvm::StringRef toString(InstructionKind name) {
-  switch (name) {
-  case InstructionKind::DPAS:
-    return "dpas";
-  }
-  llvm_unreachable("Unknown InstructionKind");
-}
-
-static std::optional<InstructionKind>
-parseInstructionKind(llvm::StringRef str) {
-  if (str.equals_insensitive("dpas"))
-    return InstructionKind::DPAS;
-  return std::nullopt;
-}
-
 // A struct to represent basic information about an instruction.
 // The primary purpose of the Instruction struct is to provide a generic way to
 // represent information about an instruction and to use this information to
@@ -65,6 +50,20 @@ struct Instruction {
   // Get methods
   InstructionKind getInstructionKind() { return instKind; }
   InstructionScope getScope() { return scope; }
+  static llvm::StringRef toString(InstructionKind instKind) {
+    switch (instKind) {
+    case InstructionKind::DPAS:
+      return "dpas";
+    }
+    llvm_unreachable("Unknown InstructionKind");
+  }
+
+  static std::optional<InstructionKind>
+  parseInstructionKind(llvm::StringRef str) {
+    if (str.equals_insensitive("dpas"))
+      return InstructionKind::DPAS;
+    return std::nullopt;
+  }
 
 protected:
   InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
@@ -168,7 +167,7 @@ struct uArch {
   llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
     llvm::SmallVector<StringRef, 8> instructionNames;
     for (const auto &inst : instructions) {
-      instructionNames.push_back(toString(inst.first));
+      instructionNames.push_back(Instruction::toString(inst.first));
     }
     return instructionNames;
   }



More information about the Mlir-commits mailing list