[Mlir-commits] [mlir] [uArch][XeGPU] Add XeGPU uArch definition. (PR #153706)

llvmlistbot at llvm.org
Thu Aug 14 15:35:27 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Md Abdullah Shahneous Bari (mshahneo)

<details>
<summary>Changes</summary>

The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary components (e.g., instructions, register files, caches).
- A set of utility interfaces that are common to a family of ops (e.g., mma ops, 2DBlockIO ops). The implementations of these interfaces are provided by the specific instructions. Each family of ops provides the following five common APIs, and some families may provide additional utility APIs:
	- getSupportedShapes
	- getSupportedTypes
	- checkSupportedShapesAndTypes
	- checkSupportedTypes
	- validate
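To make the shape of such an op-family interface concrete, here is a minimal, self-contained sketch of the five common APIs for an MMA-style family. It deliberately avoids MLIR so it compiles on its own: `ElemType` stands in for `mlir::Type`, `Shape` for `std::pair<uint32_t, uint32_t>`, and `MMAFamily`, `ToyMMA`, and every shape and element type `ToyMMA` reports are hypothetical placeholders, not the DPAS limits defined in the patch. It also assumes `validate` reduces to the combined shape-and-type check, which the truncated patch does not show.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-ins for mlir::Type and the (rows, cols) shape pairs.
enum class ElemType { F16, BF16, F32, I8, I32 };
using Shape = std::pair<uint32_t, uint32_t>;
enum class MMAOpnd { MatrixA, MatrixB, MatrixC, MatrixD };

// Mirrors the five common APIs of an op-family interface (simplified).
struct MMAFamily {
  virtual ~MMAFamily() = default;
  virtual std::vector<Shape> getSupportedShapes(ElemType t, MMAOpnd m) = 0;
  virtual std::vector<ElemType> getSupportedTypes(MMAOpnd m) = 0;
  virtual bool checkSupportedTypes(ElemType a, ElemType b, ElemType c,
                                   ElemType d) = 0;
  virtual bool checkSupportedShapesAndTypes(Shape a, Shape b, Shape c, Shape d,
                                            ElemType ta, ElemType tb,
                                            ElemType tc, ElemType td) = 0;
  virtual bool validate(Shape a, Shape b, Shape c, Shape d, ElemType ta,
                        ElemType tb, ElemType tc, ElemType td) = 0;
};

// Hypothetical instruction: A is MxK, B is KxN, C and D are MxN.
// All numbers below are illustrative placeholders only.
struct ToyMMA : MMAFamily {
  static constexpr uint32_t M = 8, K = 16, N = 16;

  std::vector<Shape> getSupportedShapes(ElemType, MMAOpnd m) override {
    switch (m) {
    case MMAOpnd::MatrixA:
      return {{M, K}};
    case MMAOpnd::MatrixB:
      return {{K, N}};
    default: // C and D share the accumulator shape.
      return {{M, N}};
    }
  }

  std::vector<ElemType> getSupportedTypes(MMAOpnd m) override {
    if (m == MMAOpnd::MatrixA || m == MMAOpnd::MatrixB)
      return {ElemType::F16, ElemType::BF16};
    return {ElemType::F32};
  }

  bool checkSupportedTypes(ElemType a, ElemType b, ElemType c,
                           ElemType d) override {
    auto in = [](const std::vector<ElemType> &v, ElemType t) {
      return std::find(v.begin(), v.end(), t) != v.end();
    };
    // A and B must share a supported input type; C and D must be
    // supported accumulator types.
    return a == b && in(getSupportedTypes(MMAOpnd::MatrixA), a) &&
           in(getSupportedTypes(MMAOpnd::MatrixC), c) &&
           in(getSupportedTypes(MMAOpnd::MatrixD), d);
  }

  bool checkSupportedShapesAndTypes(Shape a, Shape b, Shape c, Shape d,
                                    ElemType ta, ElemType tb, ElemType tc,
                                    ElemType td) override {
    auto ok = [&](Shape s, ElemType t, MMAOpnd m) {
      auto v = getSupportedShapes(t, m);
      return std::find(v.begin(), v.end(), s) != v.end();
    };
    return checkSupportedTypes(ta, tb, tc, td) &&
           ok(a, ta, MMAOpnd::MatrixA) && ok(b, tb, MMAOpnd::MatrixB) &&
           ok(c, tc, MMAOpnd::MatrixC) && ok(d, td, MMAOpnd::MatrixD);
  }

  bool validate(Shape a, Shape b, Shape c, Shape d, ElemType ta, ElemType tb,
                ElemType tc, ElemType td) override {
    // Assumption: validation is just the combined shape + type check.
    return checkSupportedShapesAndTypes(a, b, c, d, ta, tb, tc, td);
  }
};
```

One way a caller might use this: query `getSupportedShapes`/`getSupportedTypes` to pick a legal tile, then call `validate` before emitting the instruction. The per-operand `MMAOpnd` argument is what lets a single interface describe A, B, C, and D separately.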

Add support for PVC and BMG architectures.
Add support for DPAS instruction.
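The PVC and BMG definitions describe their hardware through a list of `uArchHierarchyComponent` entries, where each entry stores how many instances of the next-lower level it contains and the leaf `thread` level stores 0. The sketch below multiplies those counts, exactly as written in the patch's constructors, to get device-wide totals; `Level` and `instancesOf` are hypothetical helpers that mirror the struct and are not part of the patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Mirrors uArchHierarchyComponent: a level name plus how many components of
// the next-lower level it contains ("thread" is the leaf and stores 0).
struct Level {
  std::string name;
  uint32_t no_of_component;
};

// Hypothetical helper: number of instances of hierarchy[levelIndex] in one
// full device, i.e. the product of the counts stored at every level above it.
uint64_t instancesOf(const std::vector<Level> &hierarchy, size_t levelIndex) {
  uint64_t total = 1;
  for (size_t i = levelIndex + 1; i < hierarchy.size(); ++i)
    total *= hierarchy[i].no_of_component;
  return total;
}

// Counts as written in the PVC and BMG constructors of this patch.
const std::vector<Level> kPVC = {
    {"thread", 0}, {"XeCore", 8}, {"XeSlice", 16}, {"XeStack", 4}, {"gpu", 2}};
const std::vector<Level> kBMG = {
    {"thread", 0}, {"XeCore", 8}, {"XeSlice", 4}, {"XeStack", 5}, {"gpu", 1}};
```

Under this reading, the encoded counts give 128 XeCores (1024 threads) for PVC and 20 XeCores (160 threads) for BMG. The leaf's 0 is skipped on purpose; multiplying it in would zero every total.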

---

Patch is 32.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153706.diff


12 Files Affected:

- (added) mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h (+182) 
- (added) mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h (+266) 
- (added) mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h (+75) 
- (modified) mlir/lib/Dialect/LLVMIR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp (+9) 
- (modified) mlir/lib/Dialect/XeGPU/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt (+2-1) 
- (added) mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt (+11) 
- (added) mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp (+197) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..9179838f8c148
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,182 @@
+//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Xe2 uArch definition.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  // Constructor
+  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+             uint32_t num_vector_units, uint32_t num_matrix_units)
+      : num_threads(num_threads), shared_memory(shared_memory),
+        num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
+  }
+};
+
+struct Xe2Plus : public uArch {
+  XeCoreInfo xe_core;
+  Xe2Plus(
+      const std::string &archName, const std::string &archDescription,
+      const XeCoreInfo &xeCore,
+      const std::vector<uArchHierarchyComponent> &hierarchy = {},
+      const std::map<std::string, RegisterFileInfo> &regInfo = {},
+      const std::vector<CacheInfo> &cacheInfo = {},
+      const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+      : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+        xe_core(xeCore) {}
+};
+
+// struct to represent DPAS instruction
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+  DPASInstruction()
+      : Instruction("dpas",                   // name
+                    "Dot Product Accumulate") // description
+  {}
+
+  // Override all virtuals from MMAInstructionInterface
+  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+  virtual std::vector<mlir::Type>
+  getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) override;
+  virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                   mlir::Type CType, mlir::Type DType) override;
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+                        mlir::Type BType, mlir::Type CType,
+                        mlir::Type DType) override;
+  virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override;
+  virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override;
+  virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+};
+
+namespace PVCuArch {
+struct PVCuArch : public Xe2Plus {
+  // Maintains ownership of the instructions owned by PVCuArch
+  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  PVCuArch()
+      : Xe2Plus("pvc",                        // archName
+                "Ponte Vecchio Architecture", // archDescription
+                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
+                {/* uArch_hierarchy */},    // Optional: empty
+                {/* register_file_info */}, // Optional: empty
+                {/* cache_info */}          // Optional: empty
+        ) {
+    // Initialize uArchHierarchy
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
+    // Initialize register file info
+    // GRF
+    this->register_file_info.emplace(
+        "GRF",
+        RegisterFileInfo(64 * 1024,          // size in bits
+                         {"small", "large"}, // GRF modes
+                         {128, 256},         // registers per thread per mode
+                         0,                  // number of banks
+                         0                   // bank size
+                         ));
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cache_info.push_back(
+        CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
+    // L3 cache, XeStack level
+    this->cache_info.push_back(
+        CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
+
+    // Add the instructions
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getName(), dpas);
+    // instructions[dpas->name] = dpas.get();
+    owned_instructions.push_back(dpas);
+  }
+};
+} // namespace PVCuArch
+
+namespace BMGuArch {
+struct BMGuArch : public Xe2Plus {
+  // Maintains ownership of the instructions owned by BMGuArch
+  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  BMGuArch()
+      : Xe2Plus("bmg",                     // archName
+                "Battlemage Architecture", // archDescription
+                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
+                {/* uArch_hierarchy */},    // Optional: empty
+                {/* register_file_info */}, // Optional: empty
+                {/* cache_info */},         // Optional: empty
+                {/* instructions */}        // Optional: empty
+        ) {
+    // Initialize uArchHierarchy
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
+    // Initialize register file info
+    // GRF
+    this->register_file_info["GRF"] =
+        RegisterFileInfo(64 * 1024,          // size in bits
+                         {"small", "large"}, // GRF modes
+                         {128, 256},         // registers per thread per mode
+                         0,                  // number of banks
+                         0                   // bank size
+        );
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cache_info.push_back(
+        CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
+    // L3 cache, XeStack level
+    this->cache_info.push_back(
+        CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+
+    // Add the instructions
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getName(), dpas);
+    // instructions[dpas->name] = dpas.get();
+    owned_instructions.push_back(dpas);
+  }
+};
+} // namespace BMGuArch
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..9bda86df2aff9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,266 @@
+//===--- uArchBase.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Base uArch definition for different architectures.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <shared_mutex>
+#include <string>
+#include <vector>
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+// Architecture HW component hierarchy to represent thread, core, socket, ...
+struct uArchHierarchyComponent {
+  std::string name = ""; // optional name of the hierarchy component
+  // Number of lower-level hierarchy components it contains; e.g., a PVC
+  // XeCore contains 8 threads, so no_of_component = 8.
+  uint32_t no_of_component;
+  // Constructor
+  uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
+      : name(name), no_of_component(no_of_component) {}
+};
+
+// An enum class to represent the scope of an instruction
+enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
+
+// A struct to represent basic information about an instruction.
+// This struct is used to represent the information about an instruction in the
+// uArch. The information includes:
+// - the name of the instruction,
+// - the description of the instruction,
+// - the scope of the instruction.
+//
+// The name and description are represented as strings.
+// For example, the information about an instruction can be represented as:
+// Instruction instr("dpas", "Dot Product Accumulate Systolic (DPAS) is a
+// matrix multiply-add operation");
+
+// The primary purpose of the Instruction struct is to provide a generic way to
+// represent information about an instruction and to use this information to
+// generate the uArch. Specific instructions in a uArch can inherit from this
+// struct and add more fields as needed.
+
+struct Instruction {
+  // @TODO: Add more fields as needed
+  Instruction(std::string name, std::string desc)
+      : name(std::move(name)), description(std::move(desc)) {}
+
+  virtual ~Instruction() = default;
+  // Get methods
+  std::string getName() { return name; }
+  std::string getDescription() { return description; }
+  InstructionScopeEnum getScope() { return scope; }
+
+protected:
+  std::string name;
+  std::string description;
+  InstructionScopeEnum scope;
+};
+
+// A struct to represent register file information
+struct RegisterFileInfo {
+  // Constructor
+  RegisterFileInfo() = default;
+  RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
+                   const std::vector<uint32_t> &numRegs, uint32_t num_banks,
+                   uint32_t bank_size)
+      : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
+        num_banks(num_banks), bank_size(bank_size) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+
+  const std::vector<std::string> &getModes() const { return mode; }
+
+  const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+    return num_regs_per_thread_per_mode;
+  }
+
+  uint32_t getNumBanks() const { return num_banks; }
+
+  uint32_t getBankSize() const { return bank_size; }
+
+protected:
+  uint32_t size;                 // size per register in bits
+  std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+  std::vector<uint32_t>
+      num_regs_per_thread_per_mode; // number of registers per thread per mode
+  uint32_t num_banks;
+  uint32_t bank_size;
+};
+
+// A struct to represent cache information
+
+struct CacheInfo {
+  // Constructor
+  CacheInfo(uint32_t size, uint32_t line_size,
+            const uArchHierarchyComponent &component)
+      : size(size), line_size(line_size), component(component) {}
+
+  virtual ~CacheInfo() = default;
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+  uint32_t getLineSize() const { return line_size; }
+  const uArchHierarchyComponent &getComponent() const { return component; }
+
+protected:
+  uint32_t size;
+  uint32_t line_size;
+  // At which component level the cache is shared
+  uArchHierarchyComponent component;
+
+  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
+  // bank_size, num_ports, port_width, bank_conflicts)
+};
+
+// A struct to represent the uArch.
+// This struct is used to represent the microarchitecture of a target device.
+// The uArch includes:
+// - the name of the uArch,
+// - the description of the uArch,
+// - the uArch hierarchy,
+// - register file information,
+// - cache information,
+// - the set of instructions supported by the uArch.
+struct uArch {
+  // Constructor
+  uArch() = default;
+  uArch(const std::string &name, const std::string &description,
+        const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
+        const std::map<std::string, RegisterFileInfo> &register_file_info = {},
+        const std::vector<CacheInfo> &cache_info = {},
+        const std::map<std::string, std::shared_ptr<Instruction>>
+            &instructions = {})
+      : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
+        register_file_info(register_file_info), cache_info(cache_info),
+        instructions(instructions) {}
+
+  // Get methods
+  const std::string &getName() const { return name; }
+
+  const std::string &getDescription() const { return description; }
+
+  const std::vector<uArchHierarchyComponent> &getHierarchy() const {
+    return uArch_hierarchy;
+  }
+
+  const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
+    return register_file_info;
+  }
+
+  const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+
+  const std::map<std::string, std::shared_ptr<Instruction>> &
+  getInstructions() const {
+    return instructions;
+  }
+
+  // Get the names of the instructions supported by this architecture, i.e.,
+  // the names of the instructions added to the uArch.
+  std::vector<std::string> getSupportedInstructionNames() const {
+    std::vector<std::string> instructionNames;
+    for (const auto &inst : instructions) {
+      instructionNames.push_back(inst.first);
+    }
+    return instructionNames;
+  }
+
+  // Checks if an instruction is supported in this uArch
+  bool checkSupportedInstruction(const std::string &instructionName) const {
+    return instructions.find(instructionName) != instructions.end();
+  }
+
+protected:
+  std::string name; // Similar to target triple
+  std::string description;
+  std::vector<uArchHierarchyComponent> uArch_hierarchy;
+  std::map<std::string, RegisterFileInfo> register_file_info;
+  std::vector<CacheInfo> cache_info;
+  std::map<std::string, std::shared_ptr<Instruction>> instructions;
+};
+
+// A struct to represent shared memory information
+struct SharedMemory {
+  // Constructor
+  SharedMemory(uint32_t size, uint32_t alignment)
+      : size(size), alignment(alignment) {}
+
+  // Getters
+  uint32_t getSize() const { return size; }
+  uint32_t getAlignment() const { return alignment; }
+
+protected:
+  uint32_t size;      // in bytes
+  uint32_t alignment; // in bytes
+  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
+};
+
+struct uArchMap {
+public:
+  // Singleton instance
+  static uArchMap &instance() {
+    static uArchMap instance;
+    return instance;
+  }
+
+  // Insert or update a key-value pair
+  void insert(const std::string &key, std::shared_ptr<uArch> value) {
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    // insert_or_assign overwrites any existing entry for `key`.
+    map_.insert_or_assign(key, std::move(value));
+  }
+
+  // Get a value by key (concurrent safe read)
+  std::shared_ptr<uArch> get(const std::string &key) const {
+    std::shared_lock<std::shared_mutex> lock(mutex_);
+    auto it = map_.find(key);
+    if (it != map_.end())
+      return it->second;
+    return nullptr;
+  }
+
+  // Check if a key exists
+  bool contains(const std::string &key) const {
+    std::shared_lock<std::shared_mutex> lock(mutex_);
+    return map_.find(key) != map_.end();
+  }
+
+  // Remove a key
+  bool erase(const std::string &key) {
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    return map_.erase(key) > 0;
+  }
+
+private:
+  uArchMap() = default;
+  uArchMap(const uArchMap &) = delete;
+  uArchMap &operator=(const uArchMap &) = delete;
+
+  mutable std::shared_mutex mutex_;
+  std::map<std::string, std::shared_ptr<uArch>> map_;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
new file mode 100644
index 0000000000000..27d44c38317a1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -0,0 +1,75 @@
+//===--- uArchInterfaces.h ---*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Defines the utility interfaces that are implemented by individual
+/// instructions.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+struct MMAInstructionInterface {
+  // Get supported Matrix shapes
+  virtual std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
+
+  // @TODO: This method takes a context object as a parameter so that the
+  // returned mlir::Type objects are created in the caller's context. Since
+  // types are uniqued per context, a check such as "aType == bType" (where
+  // aType and bType denote the same type) only succeeds if both types come
+  // from the same context.
+  //
+  // One alternative is to create an enum to represent each type, but this
+  // adds an extra burden on the user to convert these enums to specific types.
+  // In fact, a utility that converts enumToType() and vice versa would still
+  // have to use the context object.
+  //
+  // Until we have a better solution, we stick to passing the context object
+  // to this method.
+  virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
+                                                    MMAOpndEnum matrixType) = 0;
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) = 0;
+  virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                   mlir::Type CType, mlir::Type DType) = 0;
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                   ...
[truncated]

``````````
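`uArchMap` in the patch is a process-wide singleton that guards a `std::map` with a `std::shared_mutex`, so lookups can proceed concurrently while inserts and erases take an exclusive lock. The following is a minimal, self-contained sketch of the same pattern; `Registry` and `Arch` are hypothetical stand-ins for `uArchMap` and `uArch`. It uses `std::map::insert_or_assign`, which, unlike `emplace`, replaces an existing entry, so the "insert or update" contract holds, and it reports whether a new key was added.

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <utility>

// Hypothetical stand-in for uArch: only the name matters here.
struct Arch {
  std::string name;
};

// Process-wide registry keyed by architecture name (e.g. "pvc", "bmg").
class Registry {
public:
  static Registry &instance() {
    static Registry r; // initialization is thread-safe since C++11
    return r;
  }

  // Insert or overwrite under an exclusive lock; true if a new key was added.
  bool insert(const std::string &key, std::shared_ptr<Arch> value) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    return map_.insert_or_assign(key, std::move(value)).second;
  }

  // Readers share the lock; returns nullptr if the key is absent.
  std::shared_ptr<Arch> get(const std::string &key) const {
    std::shared_lock<std::shared_mutex> lock(mutex_);
    auto it = map_.find(key);
    if (it != map_.end())
      return it->second;
    return nullptr;
  }

  bool erase(const std::string &key) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    return map_.erase(key) > 0;
  }

private:
  Registry() = default;
  Registry(const Registry &) = delete;
  Registry &operator=(const Registry &) = delete;

  mutable std::shared_mutex mutex_;
  std::map<std::string, std::shared_ptr<Arch>> map_;
};
```

Returning the `bool` from `insert_or_assign` lets a caller detect an accidental double registration of the same architecture name instead of silently keeping or replacing the old entry.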

</details>
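The `@TODO` on `getSupportedTypes` rests on the fact that MLIR types are uniqued per `MLIRContext` and compare by their storage pointer, so the "same" type built in two contexts compares unequal. The toy model below illustrates just that property; `ToyContext`, `ToyTypeStorage`, and `ToyType` are hypothetical and are not MLIR's API.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical per-type storage, allocated once per context.
struct ToyTypeStorage {
  std::string name;
};

// Toy context that interns one storage object per type name.
class ToyContext {
public:
  // Returns the unique storage for `name` within this context.
  const ToyTypeStorage *getType(const std::string &name) {
    auto &slot = types_[name];
    if (!slot)
      slot = std::make_unique<ToyTypeStorage>(ToyTypeStorage{name});
    return slot.get();
  }

private:
  std::unordered_map<std::string, std::unique_ptr<ToyTypeStorage>> types_;
};

// Like mlir::Type, a lightweight handle that compares by storage pointer.
struct ToyType {
  const ToyTypeStorage *impl;
  bool operator==(const ToyType &other) const { return impl == other.impl; }
};
```

This is why the interface threads an `MLIRContext &` through `getSupportedTypes`: the returned types must be interned in the caller's context for `==` checks against the caller's operand types to succeed.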


https://github.com/llvm/llvm-project/pull/153706

