[Mlir-commits] [mlir] [uArch][XeGPU] Add XeGPU uArch definition. (PR #153706)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Wed Oct 8 06:28:43 PDT 2025
https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/153706
>From b41756194fb488e9cbf7664233c2361fbd60ed8e Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Fri, 28 Mar 2025 13:04:30 +0000
Subject: [PATCH 01/13] [uArch][XeGPU] Add XeGPU uArch definition.
The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary components
(e.g., instructions, register-files, caches).
- A set of utility interfaces that are common to a family of ops
(e.g., mma ops, 2DBlockIO ops). The implementations of these interfaces
are provided by the specific instructions. Each family of ops provides
these 5 common APIs; some families of ops may also provide additional
utility APIs. The 5 common APIs are:
- getSupportedShapes
- getSupportedTypes
- checkSupportedShapesAndTypes
- checkSupportedTypes
- validate
Add support for PVC and BMG architectures.
Add support for DPAS instruction.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 182 ++++++++++++
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 266 ++++++++++++++++++
.../Dialect/XeGPU/uArch/uArchInterfaces.h | 75 +++++
mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1 +
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 9 +
mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 9 +
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt | 11 +
mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp | 197 +++++++++++++
11 files changed, 753 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
create mode 100644 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..9179838f8c148
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,182 @@
+//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Xe2 uArch definition.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+/// Resource description of a single XeCore, stored by value inside Xe2Plus.
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  /// Copies each argument into the member of the same name.
+  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+             uint32_t num_vector_units, uint32_t num_matrix_units)
+      : num_threads{num_threads}, shared_memory{shared_memory},
+        num_vector_units{num_vector_units},
+        num_matrix_units{num_matrix_units} {}
+};
+
+/// Common base of the Xe2-family uArch descriptions: the generic uArch data
+/// plus the XeCore resource description.
+struct Xe2Plus : public uArch {
+  XeCoreInfo xe_core;
+  /// Forwards the generic fields to the uArch base and stores a copy of
+  /// \p xeCore. Every container argument defaults to empty.
+  Xe2Plus(
+      const std::string &archName, const std::string &archDescription,
+      const XeCoreInfo &xeCore,
+      const std::vector<uArchHierarchyComponent> &hierarchy = {},
+      const std::map<std::string, RegisterFileInfo> &regInfo = {},
+      const std::vector<CacheInfo> &cacheInfo = {},
+      const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+      : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+        xe_core(xeCore) {}
+};
+
+/// Description of the DPAS ("Dot Product Accumulate") instruction. It is an
+/// Instruction that also answers the MMAInstructionInterface queries; the
+/// member functions are defined in IntelGpuXe2.cpp.
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+  DPASInstruction()
+      : Instruction("dpas",                   // name
+                    "Dot Product Accumulate") // description
+  {
+    // The two-argument base constructor does not set `scope`, so set it here
+    // to keep getScope() from reading an indeterminate value. The Instruction
+    // documentation in uArchBase.h describes dpas as a subgroup instruction.
+    scope = InstructionScopeEnum::Subgroup;
+  }
+
+  // Overrides of the MMAInstructionInterface virtuals. `override` already
+  // implies `virtual`, so the keyword is not repeated.
+  std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+  std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
+                                            MMAOpndEnum matrixType) override;
+  bool checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                                    std::pair<uint32_t, uint32_t> BShape,
+                                    std::pair<uint32_t, uint32_t> CShape,
+                                    std::pair<uint32_t, uint32_t> DShape,
+                                    mlir::Type AType, mlir::Type BType,
+                                    mlir::Type CType,
+                                    mlir::Type DType) override;
+  bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                           mlir::Type CType, mlir::Type DType) override;
+  bool validate(std::pair<uint32_t, uint32_t> AShape,
+                std::pair<uint32_t, uint32_t> BShape,
+                std::pair<uint32_t, uint32_t> CShape,
+                std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+                mlir::Type BType, mlir::Type CType,
+                mlir::Type DType) override;
+  std::vector<uint32_t> getSupportedM(mlir::Type type) override;
+  std::vector<uint32_t> getSupportedK(mlir::Type type) override;
+  std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+};
+
+namespace PVCuArch {
+/// Ponte Vecchio (PVC) uArch description. The base class is constructed with
+/// empty containers; hierarchy, register-file, cache and instruction tables
+/// are then filled in by the constructor body.
+struct PVCuArch : public Xe2Plus {
+  // Holds an additional owning reference to every instruction registered in
+  // `instructions` by this uArch.
+  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  PVCuArch()
+      : Xe2Plus("pvc",                        // archName
+                "Ponte Vecchio Architecture", // archDescription
+                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
+                {/* hierarchy */}, // empty, populated below
+                {/* regInfo */},   // empty, populated below
+                {/* cacheInfo */}  // empty, populated below
+        ) {
+    // uArch hierarchy, pushed in this order; the cache entries below refer
+    // to levels by their index in this vector.
+    uArch_hierarchy.emplace_back("thread", 0);
+    uArch_hierarchy.emplace_back("XeCore", 8);
+    uArch_hierarchy.emplace_back("XeSlice", 16);
+    uArch_hierarchy.emplace_back("XeStack", 4);
+    uArch_hierarchy.emplace_back("gpu", 2);
+
+    // Register file info: GRF.
+    RegisterFileInfo grf(64 * 1024,          // size in bits
+                         {"small", "large"}, // GRF modes
+                         {128, 256},         // registers per thread per mode
+                         0,                  // number of banks
+                         0                   // bank size
+    );
+    register_file_info.emplace("GRF", grf);
+
+    // Cache info.
+    // L1 cache, XeCore level (hierarchy index 1).
+    cache_info.emplace_back(512 * 1024, 64, uArch_hierarchy[1]);
+    // L3 cache, XeStack level (hierarchy index 3).
+    cache_info.emplace_back(512 * 1024, 64, uArch_hierarchy[3]);
+
+    // Register the supported instructions, keyed by instruction name.
+    auto dpasInst = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpasInst->getName(), dpasInst);
+    owned_instructions.push_back(dpasInst);
+  }
+};
+} // namespace PVCuArch
+
+namespace BMGuArch {
+/// Battlemage (BMG) uArch description. The base class is constructed with
+/// empty containers; hierarchy, register-file, cache and instruction tables
+/// are then filled in by the constructor body.
+struct BMGuArch : public Xe2Plus {
+  // Holds an additional owning reference to every instruction registered in
+  // `instructions` by this uArch.
+  std::vector<std::shared_ptr<Instruction>> owned_instructions;
+  BMGuArch()
+      : Xe2Plus("bmg",                     // archName
+                "Battlemage Architecture", // archDescription
+                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
+                {/* hierarchy */}, // empty, populated below
+                {/* regInfo */},   // empty, populated below
+                {/* cacheInfo */}, // empty, populated below
+                {/* instrs */}     // empty, populated below
+        ) {
+    // Initialize uArchHierarchy. The cache entries below refer to levels by
+    // their index in this vector.
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
+    this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
+    // Initialize register file info.
+    // GRF. Uses emplace (as PVCuArch does) rather than operator[] so that no
+    // default-constructed RegisterFileInfo is created and then overwritten.
+    this->register_file_info.emplace(
+        "GRF",
+        RegisterFileInfo(64 * 1024,          // size in bits
+                         {"small", "large"}, // GRF modes
+                         {128, 256},         // registers per thread per mode
+                         0,                  // number of banks
+                         0                   // bank size
+                         ));
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cache_info.push_back(
+        CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
+    // L3 cache, XeStack level
+    this->cache_info.push_back(
+        CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+
+    // Add the instructions, keyed by instruction name.
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getName(), dpas);
+    owned_instructions.push_back(dpas);
+  }
+};
+} // namespace BMGuArch
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..9bda86df2aff9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,266 @@
+//===--- uArch.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Base uArch definition for different architectures.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+
+#include <any>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <mutex>
+#include <shared_mutex>
+#include <tuple>
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+// One level of the architecture's HW component hierarchy (thread, core,
+// socket, ...).
+struct uArchHierarchyComponent {
+  // Optional name of the hierarchy component.
+  std::string name;
+  // Number of lower-hierarchy components this level contains; e.g. a PVC
+  // XeCore contains 8 threads, so no_of_component = 8.
+  uint32_t no_of_component;
+
+  // Stores both arguments in the members of the same name.
+  uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
+      : name(name), no_of_component{no_of_component} {}
+};
+
+// Scope of an instruction, as returned by Instruction::getScope().
+enum class InstructionScopeEnum {
+  WorkItem,
+  Subgroup,
+  Workgroup,
+  Cluster
+};
+
+// Basic, architecture-independent information about an instruction in the
+// uArch:
+// - the name of the instruction,
+// - the description of the instruction,
+// - the scope of the instruction (an InstructionScopeEnum value).
+//
+// For example:
+//   Instruction instr("dpas",
+//                     "Dot Product Accumulate Systolic (DPAS) is a matrix "
+//                     "multiply-add operation",
+//                     InstructionScopeEnum::Subgroup);
+//
+// The primary purpose of the Instruction struct is to provide a generic way to
+// represent information about an instruction and to use this information to
+// generate the uArch. A specific instruction in a uArch can inherit from this
+// struct and add more fields as needed.
+struct Instruction {
+  // @TODO: Add more fields as needed
+  // `scope` is optional so existing two-argument callers keep compiling. It
+  // defaults to the first enumerator so that getScope() never reads an
+  // indeterminate value; instructions with a different scope should pass it.
+  Instruction(std::string name, std::string desc,
+              InstructionScopeEnum scope = InstructionScopeEnum::WorkItem)
+      : name(std::move(name)), description(std::move(desc)), scope(scope) {}
+
+  virtual ~Instruction() = default;
+  // Get methods. They are const so they can be called through const
+  // references; the strings are returned by reference, matching uArch's
+  // getters.
+  const std::string &getName() const { return name; }
+  const std::string &getDescription() const { return description; }
+  InstructionScopeEnum getScope() const { return scope; }
+
+protected:
+  std::string name;
+  std::string description;
+  InstructionScopeEnum scope;
+};
+
+// A struct to represent register file information.
+struct RegisterFileInfo {
+  // A default-constructed object has zero size, no modes and no banks; the
+  // in-class initializers below make sure no scalar member is left
+  // indeterminate.
+  RegisterFileInfo() = default;
+  // Stores each argument in the corresponding member.
+  RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
+                   const std::vector<uint32_t> &numRegs, uint32_t num_banks,
+                   uint32_t bank_size)
+      : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
+        num_banks(num_banks), bank_size(bank_size) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+
+  const std::vector<std::string> &getModes() const { return mode; }
+
+  const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+    return num_regs_per_thread_per_mode;
+  }
+
+  uint32_t getNumBanks() const { return num_banks; }
+
+  uint32_t getBankSize() const { return bank_size; }
+
+protected:
+  uint32_t size = 0;             // size per register in bits
+  std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+  std::vector<uint32_t>
+      num_regs_per_thread_per_mode; // number of registers per thread per mode
+  uint32_t num_banks = 0;
+  uint32_t bank_size = 0;
+};
+
+// Description of one cache: its size, its line size, and the hierarchy
+// component at whose level the cache is shared.
+struct CacheInfo {
+  // Stores the two sizes and a copy of `component`.
+  CacheInfo(uint32_t size, uint32_t line_size,
+            const uArchHierarchyComponent &component)
+      : size{size}, line_size{line_size}, component(component) {}
+
+  virtual ~CacheInfo() = default;
+
+  // Read-only accessors.
+  uint32_t getSize() const { return size; }
+  uint32_t getLineSize() const { return line_size; }
+  const uArchHierarchyComponent &getComponent() const { return component; }
+
+protected:
+  uint32_t size;
+  uint32_t line_size;
+  // Component level at which the cache is shared.
+  uArchHierarchyComponent component;
+
+  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
+  // bank_size, num_ports, port_width, bank_conflicts)
+};
+
+// A struct to represent the uArch.
+// This struct is used to represent the microarchitecture of a target device.
+// The uArch includes:
+// - the name of the uArch,
+// - the description of the uArch,
+// - uArch hierarchy
+// - Register File information
+// - Cache information
+// - the set of instructions supported by the uArch,
+struct uArch {
+  // Constructors. Every container argument defaults to empty so derived
+  // classes can populate the tables after construction.
+  uArch() = default;
+  uArch(const std::string &name, const std::string &description,
+        const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
+        const std::map<std::string, RegisterFileInfo> &register_file_info = {},
+        const std::vector<CacheInfo> &cache_info = {},
+        const std::map<std::string, std::shared_ptr<Instruction>>
+            &instructions = {})
+      : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
+        register_file_info(register_file_info), cache_info(cache_info),
+        instructions(instructions) {}
+
+  // Get methods
+  const std::string &getName() const { return name; }
+
+  const std::string &getDescription() const { return description; }
+
+  const std::vector<uArchHierarchyComponent> &getHierarchy() const {
+    return uArch_hierarchy;
+  }
+
+  const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
+    return register_file_info;
+  }
+
+  const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+
+  const std::map<std::string, std::shared_ptr<Instruction>> &
+  getInstructions() const {
+    return instructions;
+  }
+
+  // Returns the names (map keys) of the instructions added to this uArch, in
+  // the map's key order.
+  std::vector<std::string> getSupportedInstructionNames() const {
+    std::vector<std::string> instructionNames;
+    instructionNames.reserve(instructions.size());
+    for (const auto &inst : instructions)
+      instructionNames.push_back(inst.first);
+    return instructionNames;
+  }
+
+  // Checks if an instruction with the given name was added to this uArch.
+  bool checkSupportedInstruction(const std::string &instructionName) const {
+    return instructions.find(instructionName) != instructions.end();
+  }
+
+protected:
+  std::string name; // Similar to target triple
+  std::string description;
+  std::vector<uArchHierarchyComponent> uArch_hierarchy;
+  std::map<std::string, RegisterFileInfo> register_file_info;
+  std::vector<CacheInfo> cache_info;
+  std::map<std::string, std::shared_ptr<Instruction>> instructions;
+};
+
+// Shared memory description: total size and alignment, both in bytes.
+struct SharedMemory {
+  // Stores both arguments in the members of the same name.
+  SharedMemory(uint32_t size, uint32_t alignment)
+      : size{size}, alignment{alignment} {}
+
+  // Read-only accessors.
+  uint32_t getSize() const { return size; }
+  uint32_t getAlignment() const { return alignment; }
+
+protected:
+  uint32_t size;      // in bytes
+  uint32_t alignment; // in bytes
+  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
+};
+
+// Registry mapping a string key to a shared uArch description. Access goes
+// through the single instance returned by instance(); every member function
+// takes the internal shared_mutex (shared for reads, exclusive for writes).
+struct uArchMap {
+public:
+ // Returns the single instance, a function-local static created on first use.
+ static uArchMap &instance() {
+ static uArchMap instance;
+ return instance;
+ }
+
+ // Inserts `value` under `key` if the key is not present yet. An existing
+ // entry is NOT replaced: std::map::emplace leaves the map unchanged when the
+ // key already exists. Use erase() first to replace an entry.
+ void insert(const std::string &key, std::shared_ptr<uArch> value) {
+ std::unique_lock<std::shared_mutex> lock(mutex_);
+ // Earlier variant that did overwrite: map_[key] = std::move(value);
+ map_.emplace(key, std::move(value)); // keeps the existing entry, if any
+ }
+
+ // Returns the value stored under `key`, or nullptr if the key is absent.
+ // Takes a shared lock, so concurrent readers do not block each other.
+ std::shared_ptr<uArch> get(const std::string &key) const {
+ std::shared_lock<std::shared_mutex> lock(mutex_);
+ auto it = map_.find(key);
+ if (it != map_.end())
+ return it->second;
+ return nullptr;
+ }
+
+ // Returns true if `key` is present (shared lock).
+ bool contains(const std::string &key) const {
+ std::shared_lock<std::shared_mutex> lock(mutex_);
+ return map_.find(key) != map_.end();
+ }
+
+ // Removes `key` under an exclusive lock; returns true if an entry was
+ // removed.
+ bool erase(const std::string &key) {
+ std::unique_lock<std::shared_mutex> lock(mutex_);
+ return map_.erase(key) > 0;
+ }
+
+private:
+ // Construction is private and copying is deleted, so instance() is the only
+ // way to obtain a uArchMap.
+ uArchMap() = default;
+ uArchMap(const uArchMap &) = delete;
+ uArchMap &operator=(const uArchMap &) = delete;
+
+ // `mutable` so the const readers above can lock it.
+ mutable std::shared_mutex mutex_;
+ std::map<std::string, std::shared_ptr<uArch>> map_;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
new file mode 100644
index 0000000000000..27d44c38317a1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -0,0 +1,75 @@
+//===--- uArchInterfaces.h ---*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Defines the utility interfaces that are implemented by individual
+/// instructions.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+// Selects one of the four matrix operands (A, B, C, D) of an MMA instruction.
+enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+// Pure-virtual query interface implemented by matrix-multiply-accumulate
+// instructions (e.g. DPASInstruction). Shapes are passed as
+// std::pair<uint32_t, uint32_t>.
+struct MMAInstructionInterface {
+ // Get the supported matrix shapes for the operand `matrixType` holding
+ // elements of type `dataType`.
+ virtual std::vector<std::pair<uint32_t, uint32_t>>
+ getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
+
+ // @TODO: This method takes a context object as a parameter; this is to
+ // create the mlir::Type objects from the same context. Since type objects are
+ // uniqued in a specific context, checks like "aType == bType" (where aType
+ // and bType are the same type) only work when both types come from the same
+ // context.
+ //
+ // One alternative is to create an enum to represent each type, but this
+ // adds an extra burden on the user to convert these enums to specific types.
+ // In fact, the utility that would convert enumToType() and vice versa would
+ // still have to use the context object.
+ //
+ // Until we have a better solution, we stick to passing the context object to
+ // this method.
+ virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
+ MMAOpndEnum matrixType) = 0;
+ // Checks the four operand shapes together with the four operand types.
+ virtual bool
+ checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape,
+ mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) = 0;
+ // Checks only the combination of the four operand types.
+ virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) = 0;
+ // Validates the four operand shapes and types.
+ virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+ mlir::Type BType, mlir::Type CType,
+ mlir::Type DType) = 0;
+ // Supported sizes of the M, K and N dimensions for element type `type`.
+ virtual std::vector<uint32_t> getSupportedM(mlir::Type type) = 0;
+ virtual std::vector<uint32_t> getSupportedK(mlir::Type type) = 0;
+ virtual std::vector<uint32_t> getSupportedN(mlir::Type type) = 0;
+
+ virtual ~MMAInstructionInterface() = default;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac7277e3..9e40d41de1c73 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -129,5 +129,6 @@ add_mlir_dialect_library(MLIRXeVMDialect
MLIRDialectUtils
MLIRIR
MLIRLLVMDialect
+ MLIRXeGPUuArch
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 04e8836c00359..37ab1fcdd1c0e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -380,6 +381,14 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
}
void XeVMDialect::initialize() {
+ // Populate the uArchMap with the supported target devices
+ auto pvcuArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
+ auto bmguArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..9079df050ab2b 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(uArch)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 7869a28dfed57..e1c51deefbee3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect
MLIRGPUDialect
MLIRXeVMDialect
MLIRIR
+ MLIRXeGPUuArch
MLIRViewLikeInterface
MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..359517432d82e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -35,6 +36,14 @@ void XeGPUDialect::initialize() {
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
>();
+
+ // Populate the uArchMap with the supported target devices
+ auto pvcuArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
+ auto bmguArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
}
/// Generates instructions to compute offsets for a subgroup identified by
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..e56fe3a1e193d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRTransforms
MLIRGPUDialect
MLIRXeGPUUtils
+ MLIRXeGPUuArch
MLIRGPUUtils
MLIRVectorTransforms
)
diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
new file mode 100644
index 0000000000000..c7f691cb6dda7
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRXeGPUuArch
+ IntelGpuXe2.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRDialectUtils
+)
+
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
new file mode 100644
index 0000000000000..d80d1439692e8
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -0,0 +1,197 @@
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/YAMLTraits.h"
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace mlir::xegpu::uArch;
+using namespace mlir::xegpu::uArch::Xe2Plus;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+
+/// Returns every (rows, cols) shape supported for the operand selected by
+/// \p matrixType with element type \p dataType: A is M x K, B is K x N, and
+/// C and D are M x N. Returns an empty vector for an unknown operand kind.
+std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(mlir::Type dataType,
+                                    MMAOpndEnum matrixType) {
+  // Cartesian product of two dimension lists, iterating `rows` outermost.
+  auto combineVectors = [](const std::vector<uint32_t> &rows,
+                           const std::vector<uint32_t> &cols) {
+    std::vector<std::pair<uint32_t, uint32_t>> result;
+    result.reserve(rows.size() * cols.size());
+    for (uint32_t r : rows)
+      for (uint32_t c : cols)
+        result.emplace_back(r, c);
+    return result;
+  };
+
+  // Query only the two dimensions the operand needs. getSupportedK() aborts
+  // on bit widths it does not know, and that must not affect C/D queries,
+  // which do not involve K at all.
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+    return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndEnum::MatrixB:
+    return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
+  }
+  return {};
+}
+
+/// Returns the element types supported for the operand selected by
+/// \p matrixType, created in \p context: bf16/f16/tf32 for A and B,
+/// bf16/f16/f32 for C and D. Returns an empty vector for an unknown operand
+/// kind.
+std::vector<mlir::Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndEnum matrixType) {
+  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
+  mlir::Type f16Type = mlir::Float16Type::get(&context);
+  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
+  mlir::Type f32Type = mlir::Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+  case MMAOpndEnum::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return {bf16Type, f16Type, f32Type};
+  }
+  // Reached only for a value outside the enumerators. Returning here keeps
+  // control from flowing off the end of a non-void function (undefined
+  // behavior) and matches getSupportedShapes(), which returns an empty list.
+  return {};
+}
+
+/// Checks whether the A/B/Acc(C)/Dst(D) element-type combination is accepted.
+/// A and B select the family:
+///   - f16 : A == B, Acc (if present) and Dst are f32 or f16
+///   - bf16: A == B, Acc (if present) and Dst are f32 or bf16
+///   - tf32: A == B, Acc (if present) and Dst are f32
+///   - otherwise A and B must each be a 2-, 4- or 8-bit integer
+/// \p CType may be null (no accumulator). On rejection a diagnostic is
+/// written to llvm::errs() and false is returned.
+bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                          mlir::Type CType, mlir::Type DType) {
+  // Emits the diagnostic shared by all rejection paths; `header` and `row`
+  // are the two table lines that differ between families. Always returns
+  // false so callers can `return reject(...)`.
+  auto reject = [&](const char *header, const char *row) {
+    llvm::errs()
+        << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+        << "Supported types are:\n"
+        << header << row << "AType: " << AType << " BType: " << BType
+        << " CType: " << CType << " DType: " << DType;
+    return false;
+  };
+  auto isSupportedInt = [](mlir::Type type) {
+    return type.isInteger(2) || type.isInteger(4) || type.isInteger(8);
+  };
+
+  if (AType.isF16() || BType.isF16()) {
+    bool accOk = !CType || CType.isF32() || CType.isF16();
+    bool dstOk = DType.isF32() || DType.isF16();
+    if (AType != BType || !accOk || !dstOk)
+      return reject(" Dst | Acc | A | B \n", " f, hf | f, hf | hf | hf \n");
+  } else if (AType.isBF16() || BType.isBF16()) {
+    bool accOk = !CType || CType.isF32() || CType.isBF16();
+    bool dstOk = DType.isF32() || DType.isBF16();
+    if (AType != BType || !accOk || !dstOk)
+      return reject(" Dst | Acc | A | B \n", " f, bf | f, bf | bf | bf \n");
+  } else if (AType.isTF32() || BType.isTF32()) {
+    // The accumulator itself must be f32. The previous condition tested
+    // DType in the accumulator clause, so a non-f32 accumulator slipped
+    // through whenever Dst was f32.
+    bool accOk = !CType || CType.isF32();
+    bool dstOk = DType.isF32();
+    if (AType != BType || !accOk || !dstOk)
+      return reject(" Dst | Acc | A | B \n", " f | f | tf32 | tf32 \n");
+  } else if (!isSupportedInt(AType) || !isSupportedInt(BType)) {
+    // Both operands must be supported integers. The previous `&&` rejected
+    // only when neither was, accepting e.g. an i8 A with an f32 B.
+    return reject(" Dst | Acc | A | B "
+                  " \n",
+                  " ud, d | ud,d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 ");
+  }
+
+  return true;
+}
+
+/// Returns true when the type combination is accepted by
+/// checkSupportedTypes() and every operand shape is one of the shapes
+/// reported by getSupportedShapes() for that operand's element type.
+bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
+  // Validate the types first. The shape queries below go through
+  // getSupportedK(), which asserts on non-int/float types and aborts on
+  // unknown bit widths, so they must only see types that were accepted.
+  if (!checkSupportedTypes(AType, BType, CType, DType))
+    return false;
+
+  if (!llvm::is_contained(getSupportedShapes(AType, MMAOpndEnum::MatrixA),
+                          AShape) ||
+      !llvm::is_contained(getSupportedShapes(BType, MMAOpndEnum::MatrixB),
+                          BShape))
+    return false;
+
+  // checkSupportedTypes() treats a null CType as "no accumulator"; in that
+  // case there is no element type to query shapes with, so the C shape check
+  // is skipped instead of querying a null type.
+  if (CType &&
+      !llvm::is_contained(getSupportedShapes(CType, MMAOpndEnum::MatrixC),
+                          CShape))
+    return false;
+
+  return llvm::is_contained(getSupportedShapes(DType, MMAOpndEnum::MatrixD),
+                            DShape);
+}
+
+/// Validation entry point; forwards all arguments unchanged to
+/// checkSupportedShapesAndTypes() and returns its result.
+bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) {
+  const bool supported = checkSupportedShapesAndTypes(
+      AShape, BShape, CShape, DShape, AType, BType, CType, DType);
+  return supported;
+}
+
+/// Supported M sizes: every value from 1 to 8, independent of \p type.
+std::vector<uint32_t> DPASInstruction::getSupportedM(mlir::Type type) {
+  std::vector<uint32_t> sizes;
+  for (uint32_t m = 1; m <= 8; ++m)
+    sizes.push_back(m);
+  return sizes;
+}
+
+/// Returns the single supported K size for element type \p type, chosen by
+/// bit width: 2 and 4 bits -> 64, 8 -> 32, 16 -> 16, 32 -> 8. TF32 is handled
+/// explicitly and yields 8. Any other width is a programming error.
+std::vector<uint32_t> DPASInstruction::getSupportedK(mlir::Type type) {
+  // assert if data type is not int or float type
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  // getSupportedTypes() lists tf32 as a valid A/B type, but MLIR reports a
+  // 19-bit width for TF32, which none of the cases below matches. Treat it
+  // like a 32-bit element so tf32 does not hit llvm_unreachable.
+  // NOTE(review): K = 8 for tf32 mirrors the 32-bit case; confirm against the
+  // hardware spec.
+  if (type.isTF32())
+    return {8};
+
+  uint32_t kSize = 0;
+  switch (type.getIntOrFloatBitWidth()) {
+  case 2:
+  case 4:
+    kSize = 64;
+    break;
+  case 8:
+    kSize = 32;
+    break;
+  case 16:
+    kSize = 16;
+    break;
+  case 32:
+    kSize = 8;
+    break;
+  default:
+    llvm_unreachable("Invalid int or float");
+  }
+  return {kSize};
+}
+
+/// Supported N sizes: the single value 16, independent of \p type.
+std::vector<uint32_t> DPASInstruction::getSupportedN(mlir::Type type) {
+  return std::vector<uint32_t>(1, 16);
+}
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
>From 6dace4bcb469f5813b512aba5cef3df48ebb7a6d Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 26 Aug 2025 00:08:11 +0000
Subject: [PATCH 02/13] Address review comments.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 15 +++----
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 14 +++----
mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp | 39 ++++---------------
3 files changed, 22 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 9179838f8c148..e49fe69c47b2e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -1,4 +1,4 @@
-//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,14 @@
//
//===----------------------------------------------------------------------===//
//
-/// \file
-/// Xe2 uArch definition.
-///
+// \file
+// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
+// This file defines the uArch details for Xe2 and its derived architectures.
+// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
-#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -179,4 +180,4 @@ struct BMGuArch : public Xe2Plus {
} // namespace xegpu
} // namespace mlir
-#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 9bda86df2aff9..4ef02f32913f7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -1,4 +1,4 @@
-//===--- uArch.h ---------------------------------------*- C++ -*-===//
+//===--- uArchBase.h --------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
-/// \file
-/// Base uArch definition for different architectures.
-///
+// \file
+// Base uArch definition for different architectures.
+//
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
-#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
#include <any>
#include <functional>
@@ -263,4 +263,4 @@ struct uArchMap {
} // namespace xegpu
} // namespace mlir
-#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
index d80d1439692e8..4db4300028c46 100644
--- a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -1,11 +1,11 @@
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "llvm/Support/YAMLTraits.h"
+#include "llvm/Support/DebugLog.h"
#include <algorithm>
-#include <iostream>
-#include <string>
#include <vector>
+#define DEBUG_TYPE "xegpu-uarch"
+
using namespace mlir::xegpu::uArch;
using namespace mlir::xegpu::uArch::Xe2Plus;
@@ -80,51 +80,26 @@ bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
if (AType.isF16() || BType.isF16()) {
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
(!DType.isF32() && !DType.isF16())) {
- llvm::errs()
- << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
- << "Supported types are:\n"
- << " Dst | Acc | A | B \n"
- << " f, hf | f, hf | hf | hf \n"
- << "AType: " << AType << " BType: " << BType << " CType: " << CType
- << " DType: " << DType;
+ LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
return false;
}
} else if (AType.isBF16() || BType.isBF16()) {
if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
(!DType.isF32() && !DType.isBF16())) {
- llvm::errs()
- << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
- << "Supported types are:\n"
- << " Dst | Acc | A | B \n"
- << " f, bf | f, bf | bf | bf \n"
- << "AType: " << AType << " BType: " << BType << " CType: " << CType
- << " DType: " << DType;
+ LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
return false;
}
} else if (AType.isTF32() || BType.isTF32()) {
if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
(!DType.isF32())) {
- llvm::errs()
- << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
- << "Supported types are:\n"
- << " Dst | Acc | A | B \n"
- << " f | f | tf32 | tf32 \n"
- << "AType: " << AType << " BType: " << BType << " CType: " << CType
- << " DType: " << DType;
+ LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
return false;
}
} else if (!(AType.isInteger(2) || AType.isInteger(4) ||
AType.isInteger(8)) &&
!(BType.isInteger(2) || BType.isInteger(4) ||
BType.isInteger(8))) {
- llvm::errs()
- << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
- << "Supported types are:\n"
- << " Dst | Acc | A | B "
- " \n"
- << " ud, d | ud,d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 "
- << "AType: " << AType << " BType: " << BType << " CType: " << CType
- << " DType: " << DType;
+ LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
return false;
}
>From f33b7f73840a3d0fd2c3c1c725a72da5374c6c56 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 24 Sep 2025 17:10:46 +0000
Subject: [PATCH 03/13] Address review comments.
Simplify the design:
- Remove uArchHierarchyComponent
Rename identifiers to follow LLVM naming conventions.
Replace string usage with enums wherever possible.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 71 ++++++--------
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 93 ++++++++-----------
.../Dialect/XeGPU/uArch/uArchInterfaces.h | 7 +-
mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp | 28 +++---
4 files changed, 83 insertions(+), 116 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index e49fe69c47b2e..a7ae0bdf4378b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -45,11 +45,10 @@ struct Xe2Plus : public uArch {
Xe2Plus(
const std::string &archName, const std::string &archDescription,
const XeCoreInfo &xeCore,
- const std::vector<uArchHierarchyComponent> &hierarchy = {},
- const std::map<std::string, RegisterFileInfo> ®Info = {},
+ const std::map<RegisterFileType, RegisterFileInfo> ®Info = {},
const std::vector<CacheInfo> &cacheInfo = {},
const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
- : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+ : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xe_core(xeCore) {}
};
@@ -62,9 +61,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
// Override all virtuals from MatrixOpInterface
virtual std::vector<std::pair<uint32_t, uint32_t>>
- getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+ getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override;
virtual std::vector<mlir::Type>
- getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
+ getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
virtual bool
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
std::pair<uint32_t, uint32_t> BShape,
@@ -97,29 +96,22 @@ struct PVCuArch : public Xe2Plus {
{/* cache_info */}, // Optional: empty
{/* instructions */} // Optional: empty
) {
- // Initialize uArchHierarchy
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
// Intialize register file info
// GRF
- this->register_file_info.emplace(
- "GRF",
- RegisterFileInfo(64 * 1024, // size in bits
- {"small", "large"}, // GRF modes
- {128, 256}, // registers per thread per mode
- 0, // number of banks
- 0 // bank size
- ));
+ this->registerFileInfo.emplace(
+ RegisterFileType::GRF,
+ RegisterFileInfo(
+ 64 * 1024, // size in bits
+ {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
+ {128, 256} // registers per thread per mode
+ ));
// Initialize cache info
// L1 cache, XeCore level
- this->cache_info.push_back(
- CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
- // L3 cache, XeStack level
- this->cache_info.push_back(
- CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
+ this->cacheInfo.push_back(
+ CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
+ // L2 cache, XeStack level
+ this->cacheInfo.push_back(
+ CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
// Add the instructions
auto dpas = std::make_shared<DPASInstruction>();
@@ -140,31 +132,22 @@ struct BMGuArch : public Xe2Plus {
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
{/* register_file_info */}, // Optional: empty
{/* cache_info */}, // Optional: empty
- {/* instructions */}, // Optional: empty
- {/* restrictions */} // Optional: empty
+ {/* instructions */} // Optional: empty
) {
- // Initialize uArchHierarchy
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
- this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
// Intialize register file info
// GRF
- this->register_file_info["GRF"] =
- RegisterFileInfo(64 * 1024, // size in bits
- {"small", "large"}, // GRF modes
- {128, 256}, // registers per thread per mode
- 0, // number of banks
- 0 // bank size
- );
+ this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
+ 64 * 1024, // size in bits
+ {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
+ {128, 256} // registers per thread per mode
+ );
// Initialize cache info
// L1 cache, XeCore level
- this->cache_info.push_back(
- CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
- // L3 cache, XeStack level
- this->cache_info.push_back(
- CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+ this->cacheInfo.push_back(
+ CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
+ // L2 cache, XeStack level
+ this->cacheInfo.push_back(
+ CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
// Add the instructions
auto dpas = std::make_shared<DPASInstruction>();
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 4ef02f32913f7..2416b173e505a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -27,19 +27,15 @@
namespace mlir {
namespace xegpu {
namespace uArch {
-// Architecture HW component hierarchy to present thread, core, socket ...
-struct uArchHierarchyComponent {
- std::string name = ""; // optional name of the hierarchy component
- // no. of lower hierarchy component it contains, e.g., for PVC XeCore it
- // contains 8 threads, so no_of_component=8
- uint32_t no_of_component;
- // Constructor
- uArchHierarchyComponent(const std::string &name, uint32_t no_of_component)
- : name(name), no_of_component(no_of_component) {}
-};
// An enum class to represent the scope of an instruction
-enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
+enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
+
+enum class InstructionName {
+ DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add
+ // operation
+ // Add more instructions as needed
+};
// A struct to represent basic information about an instruction
// This struct is used to represent the information about an instruction in the
@@ -67,69 +63,62 @@ struct Instruction {
// Get methods
std::string getName() { return name; }
std::string getDescription() { return description; }
- InstructionScopeEnum getScope() { return scope; }
+ InstructionScope getScope() { return scope; }
protected:
std::string name;
std::string description;
- InstructionScopeEnum scope;
+ InstructionScope scope;
};
+enum class RegisterFileMode : uint8_t { Small, Large };
+enum class RegisterFileType : uint8_t { GRF, ARF };
+
// A struct to represent register file information
struct RegisterFileInfo {
// Constructor
RegisterFileInfo() = default;
- RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
- const std::vector<uint32_t> &numRegs, uint32_t num_banks,
- uint32_t bank_size)
- : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
- num_banks(num_banks), bank_size(bank_size) {}
+ RegisterFileInfo(uint32_t size, const std::vector<RegisterFileMode> &mode,
+ const std::vector<uint32_t> &numRegs)
+ : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
- // Get methods
uint32_t getSize() const { return size; }
-
- const std::vector<std::string> &getModes() const { return mode; }
-
+ const std::vector<RegisterFileMode> &getModes() const { return mode; }
const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
- return num_regs_per_thread_per_mode;
+ return numRegsPerThreadPerMode;
}
- uint32_t getNumBanks() const { return num_banks; }
-
- uint32_t getBankSize() const { return bank_size; }
-
protected:
- uint32_t size; // size per register in bits
- std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+ uint32_t size; // size per register in bits
+ std::vector<RegisterFileMode> mode; // e.g., "small", "large" GRF modes
std::vector<uint32_t>
- num_regs_per_thread_per_mode; // number of registers per thread per mode
- uint32_t num_banks;
- uint32_t bank_size;
+ numRegsPerThreadPerMode; // number of registers per thread per mode
+ // TODO: Add more fields as needed (e.g., num_banks, bank_size, num_ports,
+ // port_width, bank_conflicts)
};
+enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
// A struct to represent cache information
-
struct CacheInfo {
// Constructor
CacheInfo(uint32_t size, uint32_t line_size,
- const uArchHierarchyComponent &component)
- : size(size), line_size(line_size), component(component) {}
+ CacheHierarchyLevel hierarchy_level)
+ : size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}
virtual ~CacheInfo() = default;
// Get methods
uint32_t getSize() const { return size; }
uint32_t getLineSize() const { return line_size; }
- const uArchHierarchyComponent &getComponent() const { return component; }
+ CacheHierarchyLevel getHierarchyLevel() const { return hierarchy_level; }
protected:
uint32_t size;
uint32_t line_size;
- // At which component level the cache is shared
- uArchHierarchyComponent component;
-
+ CacheHierarchyLevel hierarchy_level;
// @TODO: Add more fields as needed (e.g., associativity, num_banks,
- // bank_size, num_ports, port_width, bank_conflicts)
+ // bank_size, num_ports, port_width, bank_conflicts, latency, throughput,
+ // bandwidth)
};
// A struct to represent the uArch
@@ -145,13 +134,13 @@ struct uArch {
// Constructor
uArch() = default;
uArch(const std::string &name, const std::string &description,
- const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
- const std::map<std::string, RegisterFileInfo> ®ister_file_info = {},
+ const std::map<RegisterFileType, RegisterFileInfo> ®ister_file_info =
+ {},
const std::vector<CacheInfo> &cache_info = {},
const std::map<std::string, std::shared_ptr<Instruction>>
&instructions = {})
- : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
- register_file_info(register_file_info), cache_info(cache_info),
+ : name(name), description(description),
+ registerFileInfo(register_file_info), cacheInfo(cache_info),
instructions(instructions) {}
// Get methods
@@ -159,15 +148,12 @@ struct uArch {
const std::string &getDescription() const { return description; }
- const std::vector<uArchHierarchyComponent> &getHierarchy() const {
- return uArch_hierarchy;
- }
-
- const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
- return register_file_info;
+ const std::map<RegisterFileType, RegisterFileInfo> &
+ getRegisterFileInfo() const {
+ return registerFileInfo;
}
- const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+ const std::vector<CacheInfo> &getCacheInfo() const { return cacheInfo; }
const std::map<std::string, std::shared_ptr<Instruction>> &
getInstructions() const {
@@ -192,9 +178,8 @@ struct uArch {
protected:
std::string name; // Similar to target triple
std::string description;
- std::vector<uArchHierarchyComponent> uArch_hierarchy;
- std::map<std::string, RegisterFileInfo> register_file_info;
- std::vector<CacheInfo> cache_info;
+ std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
+ std::vector<CacheInfo> cacheInfo;
std::map<std::string, std::shared_ptr<Instruction>> instructions;
};
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
index 27d44c38317a1..087313c357476 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -26,12 +26,11 @@ namespace mlir {
namespace xegpu {
namespace uArch {
-enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
struct MMAInstructionInterface {
// Get supported Matrix shapes
virtual std::vector<std::pair<uint32_t, uint32_t>>
- getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
-
+ getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) = 0;
// @TODO: This method takes an context object as a parameter, this is to
// create the mlir::Type objects from the same context. Since type objects are
// uniqued in a specific context, to do things like "aType == bType" (where
@@ -46,7 +45,7 @@ struct MMAInstructionInterface {
// Untill we have a better solution, we stick to passing context object to
// this method.
virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
- MMAOpndEnum matrixType) = 0;
+ MMAOpndKind matrixType) = 0;
virtual bool
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
std::pair<uint32_t, uint32_t> BShape,
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
index 4db4300028c46..e19df1e5235ba 100644
--- a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -16,7 +16,7 @@ namespace Xe2Plus {
std::vector<std::pair<uint32_t, uint32_t>>
DPASInstruction::getSupportedShapes(mlir::Type dataType,
- MMAOpndEnum matrixType) {
+ MMAOpndKind matrixType) {
auto combineVectors = [](const std::vector<uint32_t> &a,
const std::vector<uint32_t> &b)
-> std::vector<std::pair<uint32_t, uint32_t>> {
@@ -35,16 +35,16 @@ DPASInstruction::getSupportedShapes(mlir::Type dataType,
std::vector<std::pair<unsigned, unsigned>> resultMatrix;
switch (matrixType) {
- case MMAOpndEnum::MatrixA:
+ case MMAOpndKind::MatrixA:
resultMatrix = combineVectors(M, K);
break;
- case MMAOpndEnum::MatrixB:
+ case MMAOpndKind::MatrixB:
resultMatrix = combineVectors(K, N);
break;
- case MMAOpndEnum::MatrixC:
+ case MMAOpndKind::MatrixC:
resultMatrix = combineVectors(M, N);
break;
- case MMAOpndEnum::MatrixD:
+ case MMAOpndKind::MatrixD:
resultMatrix = combineVectors(M, N);
break;
}
@@ -53,23 +53,23 @@ DPASInstruction::getSupportedShapes(mlir::Type dataType,
std::vector<mlir::Type>
DPASInstruction::getSupportedTypes(MLIRContext &context,
- MMAOpndEnum matrixType) {
+ MMAOpndKind matrixType) {
mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
mlir::Type f16Type = mlir::Float16Type::get(&context);
mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
mlir::Type f32Type = mlir::Float32Type::get(&context);
switch (matrixType) {
- case MMAOpndEnum::MatrixA:
+ case MMAOpndKind::MatrixA:
return {bf16Type, f16Type, tf32Type};
break;
- case MMAOpndEnum::MatrixB:
+ case MMAOpndKind::MatrixB:
return {bf16Type, f16Type, tf32Type};
break;
- case MMAOpndEnum::MatrixC:
+ case MMAOpndKind::MatrixC:
return {bf16Type, f16Type, f32Type};
break;
- case MMAOpndEnum::MatrixD:
+ case MMAOpndKind::MatrixD:
return {bf16Type, f16Type, f32Type};
break;
}
@@ -110,10 +110,10 @@ bool DPASInstruction::checkSupportedShapesAndTypes(
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
- auto supportedAShapes = getSupportedShapes(AType, MMAOpndEnum::MatrixA);
- auto supportedBShapes = getSupportedShapes(BType, MMAOpndEnum::MatrixB);
- auto supportedCShapes = getSupportedShapes(CType, MMAOpndEnum::MatrixC);
- auto supportedDShapes = getSupportedShapes(DType, MMAOpndEnum::MatrixD);
+ auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
+ auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
+ auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
+ auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
return llvm::is_contained(supportedAShapes, AShape) &&
llvm::is_contained(supportedBShapes, BShape) &&
llvm::is_contained(supportedCShapes, CShape) &&
>From bba25dd7c6cf23042b8d19f19a8b8d420e03ce5d Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 24 Sep 2025 17:43:59 +0000
Subject: [PATCH 04/13] Address review comments.
Simplify the design:
- Remove the uArch registration from dialect initialization, along with the uArchMap singleton that supported it.
---
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 45 -------------------
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 8 ----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 8 ----
3 files changed, 61 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 2416b173e505a..ed9e842c56b62 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -199,51 +199,6 @@ struct SharedMemory {
// @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
};
-struct uArchMap {
-public:
- // Singleton instance
- static uArchMap &instance() {
- static uArchMap instance;
- return instance;
- }
-
- // Insert or update a key-value pair
- void insert(const std::string &key, std::shared_ptr<uArch> value) {
- std::unique_lock<std::shared_mutex> lock(mutex_);
- // map_[key] = std::move(value); // safe to overwrite
- map_.emplace(key, std::move(value)); // safe to overwrite
- }
-
- // Get a value by key (concurrent safe read)
- std::shared_ptr<uArch> get(const std::string &key) const {
- std::shared_lock<std::shared_mutex> lock(mutex_);
- auto it = map_.find(key);
- if (it != map_.end())
- return it->second;
- return nullptr;
- }
-
- // Check if a key exists
- bool contains(const std::string &key) const {
- std::shared_lock<std::shared_mutex> lock(mutex_);
- return map_.find(key) != map_.end();
- }
-
- // Remove a key
- bool erase(const std::string &key) {
- std::unique_lock<std::shared_mutex> lock(mutex_);
- return map_.erase(key) > 0;
- }
-
-private:
- uArchMap() = default;
- uArchMap(const uArchMap &) = delete;
- uArchMap &operator=(const uArchMap &) = delete;
-
- mutable std::shared_mutex mutex_;
- std::map<std::string, std::shared_ptr<uArch>> map_;
-};
-
} // namespace uArch
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 37ab1fcdd1c0e..6af69bb549bea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -381,14 +381,6 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
}
void XeVMDialect::initialize() {
- // Populate the uArchMap with the supported target devices
- auto pvcuArch =
- std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
- mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
- auto bmguArch =
- std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
- mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
-
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 359517432d82e..9beb22d517473 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -36,14 +36,6 @@ void XeGPUDialect::initialize() {
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
>();
-
- // Populate the uArchMap with the supported target devices
- auto pvcuArch =
- std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
- mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
- auto bmguArch =
- std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
- mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
}
/// Generates instructions to compute offsets for a subgroup identified by
>From b0e6f3400e265c35491c7b4931c9e34191cffde6 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 24 Sep 2025 23:22:18 +0000
Subject: [PATCH 05/13] Address review comments.
Move all the implementation to the .h file.
Move uArchInterfaces to uArchBase.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 192 ++++++++++++++++--
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 43 ++++
.../Dialect/XeGPU/uArch/uArchInterfaces.h | 74 -------
mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1 -
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 1 -
mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1 -
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 1 -
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 -
mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt | 11 -
mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp | 172 ----------------
10 files changed, 213 insertions(+), 284 deletions(-)
delete mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
delete mode 100644 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
delete mode 100644 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index a7ae0bdf4378b..7b1381e9efc2d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -15,17 +15,22 @@
#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
-#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/DebugLog.h"
#include <map>
#include <string>
#include <vector>
+#define DEBUG_TYPE "xegpu-uarch"
+
+using namespace mlir;
+using namespace mlir::xegpu::uArch;
+
namespace mlir {
namespace xegpu {
namespace uArch {
-namespace Xe2Plus {
struct XeCoreInfo {
uint32_t num_threads;
SharedMemory shared_memory;
@@ -61,30 +66,27 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
// Override all virtuals from MatrixOpInterface
virtual std::vector<std::pair<uint32_t, uint32_t>>
- getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) override;
- virtual std::vector<mlir::Type>
- getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
+ getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
+ virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
+ MMAOpndKind matrixType) override;
virtual bool
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
std::pair<uint32_t, uint32_t> BShape,
std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape,
- mlir::Type AType, mlir::Type BType,
- mlir::Type CType, mlir::Type DType) override;
- virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
- mlir::Type CType, mlir::Type DType) override;
+ std::pair<uint32_t, uint32_t> DShape, Type AType,
+ Type BType, Type CType, Type DType) override;
+ virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+ Type DType) override;
virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
std::pair<uint32_t, uint32_t> BShape,
std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
- mlir::Type BType, mlir::Type CType,
- mlir::Type DType) override;
- virtual std::vector<uint32_t> getSupportedM(mlir::Type type) override;
- virtual std::vector<uint32_t> getSupportedK(mlir::Type type) override;
- virtual std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+ std::pair<uint32_t, uint32_t> DShape, Type AType,
+ Type BType, Type CType, Type DType) override;
+ virtual std::vector<uint32_t> getSupportedM(Type type) override;
+ virtual std::vector<uint32_t> getSupportedK(Type type) override;
+ virtual std::vector<uint32_t> getSupportedN(Type type) override;
};
-namespace PVCuArch {
struct PVCuArch : public Xe2Plus {
// Maintaines ownership of the instructions owned by PVUarch
std::vector<std::shared_ptr<Instruction>> owned_instructions;
@@ -120,9 +122,7 @@ struct PVCuArch : public Xe2Plus {
owned_instructions.push_back(dpas);
}
};
-} // namespace PVCuArch
-namespace BMGuArch {
struct BMGuArch : public Xe2Plus {
// Maintaines ownership of the instructions owned by PVUarch
std::vector<std::shared_ptr<Instruction>> owned_instructions;
@@ -156,11 +156,159 @@ struct BMGuArch : public Xe2Plus {
owned_instructions.push_back(dpas);
}
};
-} // namespace BMGuArch
-
-} // namespace Xe2Plus
} // namespace uArch
} // namespace xegpu
} // namespace mlir
+/// Returns every (first, second) dimension pair DPAS accepts for one operand.
+///
+/// The result is the cross product of two of the per-dimension size lists
+/// reported by getSupportedM/K/N for `dataType`:
+///   - MatrixA:           M x K
+///   - MatrixB:           K x N
+///   - MatrixC / MatrixD: M x N
+///
+/// Only the two lists an operand actually needs are queried. In particular
+/// getSupportedK, which asserts that its argument is an int or float type, is
+/// not invoked for the C and D operands because their shapes do not depend on
+/// K.
+///
+/// \param dataType   Element type of the operand.
+/// \param matrixType Which operand (A, B, C or D) the shapes are for.
+/// \returns The supported shapes; empty for an out-of-range `matrixType`.
+inline std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+  // Cross product of two dimension lists; the first list varies slowest.
+  auto combineVectors = [](const std::vector<uint32_t> &first,
+                           const std::vector<uint32_t> &second) {
+    std::vector<std::pair<uint32_t, uint32_t>> result;
+    result.reserve(first.size() * second.size());
+    for (uint32_t x : first)
+      for (uint32_t y : second)
+        result.emplace_back(x, y);
+    return result;
+  };
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndKind::MatrixB:
+    return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
+  }
+  // Out-of-range enumerator: nothing is supported.
+  return {};
+}
+
+/// Returns the floating-point element types listed for a DPAS operand.
+///
+/// A and B share one list (bf16, f16, tf32); C and D share another
+/// (bf16, f16, f32). All returned types are created in `context`.
+///
+/// NOTE(review): checkSupportedTypes also has a branch for 2/4/8-bit integer
+/// A/B operands, which are not listed here - confirm whether that is
+/// intentional.
+///
+/// \param context    Context in which the returned Type objects are uniqued.
+/// \param matrixType Which operand (A, B, C or D) the types are for.
+/// \returns The listed types; empty for an out-of-range `matrixType`.
+inline std::vector<Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndKind matrixType) {
+  Type bf16Type = BFloat16Type::get(&context);
+  Type f16Type = Float16Type::get(&context);
+  Type tf32Type = FloatTF32Type::get(&context);
+  Type f32Type = Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+  case MMAOpndKind::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    return {bf16Type, f16Type, f32Type};
+  }
+  // Every enumerator returns above; this keeps control from flowing off the
+  // end of a non-void function if an out-of-range value is ever passed.
+  return {};
+}
+
+/// Checks whether the element types of A, B, C and D form a combination that
+/// this DPAS model accepts.
+///
+/// Rules, selected by the A/B element type:
+///   - f16:  A == B; C (if present) is f32 or f16;  D is f32 or f16.
+///   - bf16: A == B; C (if present) is f32 or bf16; D is f32 or bf16.
+///   - tf32: A == B; C (if present) is f32;         D is f32.
+///   - otherwise A and B must each be a 2-, 4- or 8-bit integer (they are not
+///     required to have the same width); C and D are not inspected.
+///
+/// A null `CType` is skipped by the floating-point rules. `DType` is
+/// dereferenced unconditionally in those branches.
+///
+/// \returns true if the combination is supported; otherwise logs a debug
+///          message and returns false.
+inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
+                                                 Type CType, Type DType) {
+  const auto isNarrowInt = [](Type type) {
+    return type.isInteger(2) || type.isInteger(4) || type.isInteger(8);
+  };
+
+  bool supported = false;
+  if (AType.isF16() || BType.isF16()) {
+    supported = AType == BType &&
+                (!CType || CType.isF32() || CType.isF16()) &&
+                (DType.isF32() || DType.isF16());
+  } else if (AType.isBF16() || BType.isBF16()) {
+    supported = AType == BType &&
+                (!CType || CType.isF32() || CType.isBF16()) &&
+                (DType.isF32() || DType.isBF16());
+  } else if (AType.isTF32() || BType.isTF32()) {
+    // The accumulator check must look at CType itself; testing DType here
+    // would let a non-f32 accumulator through whenever D is f32.
+    supported = AType == BType && (!CType || CType.isF32()) && DType.isF32();
+  } else {
+    // Both operands have to be narrow integers; accepting the pair when only
+    // one of them qualifies would admit mixes such as i8 x f32.
+    supported = isNarrowInt(AType) && isNarrowInt(BType);
+  }
+
+  if (!supported) {
+    LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+  }
+  return supported;
+}
+
+/// Checks that both the element types and the shapes of A, B, C and D are
+/// supported by DPAS.
+///
+/// The type combination is validated first. getSupportedShapes may consult
+/// getSupportedK, which asserts on non int/float types and hits
+/// llvm_unreachable for bit widths it does not know, so shape tables are only
+/// built once checkSupportedTypes has accepted the operands, and only as far
+/// as the short-circuit evaluation needs them.
+///
+/// \returns true iff the types are a supported combination and every shape is
+///          contained in the corresponding getSupportedShapes result.
+inline bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
+  if (!checkSupportedTypes(AType, BType, CType, DType))
+    return false;
+
+  // True if `shape` is one of the shapes supported for operand `kind`.
+  const auto isSupportedShape = [this](Type type, MMAOpndKind kind,
+                                       std::pair<uint32_t, uint32_t> shape) {
+    return llvm::is_contained(getSupportedShapes(type, kind), shape);
+  };
+
+  return isSupportedShape(AType, MMAOpndKind::MatrixA, AShape) &&
+         isSupportedShape(BType, MMAOpndKind::MatrixB, BShape) &&
+         isSupportedShape(CType, MMAOpndKind::MatrixC, CShape) &&
+         isSupportedShape(DType, MMAOpndKind::MatrixD, DShape);
+}
+
+/// Validates a DPAS operand configuration.
+///
+/// Currently a plain forwarder: the result is exactly what
+/// checkSupportedShapesAndTypes returns for the same arguments.
+inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
+                                      std::pair<uint32_t, uint32_t> BShape,
+                                      std::pair<uint32_t, uint32_t> CShape,
+                                      std::pair<uint32_t, uint32_t> DShape,
+                                      Type AType, Type BType, Type CType,
+                                      Type DType) {
+  const bool isSupported = checkSupportedShapesAndTypes(
+      AShape, BShape, CShape, DShape, AType, BType, CType, DType);
+  return isSupported;
+}
+
+/// Supported sizes of the M dimension: 1 through 8 inclusive.
+/// The answer does not depend on `type`.
+inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) {
+  std::vector<uint32_t> mSizes;
+  for (uint32_t m = 1; m <= 8; ++m)
+    mSizes.push_back(m);
+  return mSizes;
+}
+
+/// Supported size of the K dimension, selected by the element bit width:
+///   2 or 4 bits -> 64, 8 bits -> 32, 16 bits -> 16, 32 bits -> 8.
+///
+/// `type` must be an int or float type (asserted); any other bit width is
+/// treated as unreachable.
+inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  switch (type.getIntOrFloatBitWidth()) {
+  case 2:
+  case 4:
+    return {64};
+  case 8:
+    return {32};
+  case 16:
+    return {16};
+  case 32:
+    return {8};
+  default:
+    break;
+  }
+  llvm_unreachable("Invalid int or float");
+}
+
+/// Supported size of the N dimension: always the single value 16.
+/// The answer does not depend on `type`.
+inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) {
+  return std::vector<uint32_t>(1, 16);
+}
+
#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index ed9e842c56b62..7c76157fa2a1a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -199,6 +199,49 @@ struct SharedMemory {
// @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
};
+//===----------------------------------------------------------------------===//
+// Interfaces
+//===----------------------------------------------------------------------===//
+// Selects one of the four matrix operands (A, B, C, D) of a matrix
+// multiply-accumulate instruction when querying MMAInstructionInterface.
+// NOTE(review): C is presumably the accumulator input and D the result -
+// confirm against the instruction implementations.
+enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
+// Abstract query interface implemented by matrix multiply-accumulate (MMA)
+// instructions. Every method is pure virtual; implementations describe which
+// operand shapes and element types the instruction supports.
+struct MMAInstructionInterface {
+ // Get the supported matrix shapes, as pairs of dimension sizes, for the
+ // operand `matrixType` with element type `dataType`.
+ virtual std::vector<std::pair<uint32_t, uint32_t>>
+ getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
+ // Get the supported element types for the operand `matrixType`.
+ //
+ // @TODO: This method takes a context object as a parameter; this is needed
+ // to create the Type objects in that context. Since type objects are
+ // uniqued in a specific context, checks such as "aType == bType" (where
+ // aType and bType are the same type) only work when both types come from
+ // the same context.
+ //
+ // One alternative is to create an enum to represent each type, but this
+ // adds an extra burden on the user to convert these enums to specific
+ // types. In fact, a utility that converts enumToType() and vice versa would
+ // still have to use the context object.
+ //
+ // Until we have a better solution, we stick to passing the context object
+ // to this method.
+ virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
+ MMAOpndKind matrixType) = 0;
+ // Check whether the given operand shapes and element types are supported.
+ virtual bool
+ checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape, Type AType,
+ Type BType, Type CType, Type DType) = 0;
+ // Check whether the combination of operand element types is supported.
+ virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+ Type DType) = 0;
+ // Validate a complete operand configuration (shapes and element types).
+ virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape, Type AType,
+ Type BType, Type CType, Type DType) = 0;
+ // Get the supported sizes of the M, K and N dimensions for element type
+ // `type`.
+ virtual std::vector<uint32_t> getSupportedM(Type type) = 0;
+ virtual std::vector<uint32_t> getSupportedK(Type type) = 0;
+ virtual std::vector<uint32_t> getSupportedN(Type type) = 0;
+
+ // Virtual destructor so implementations can be destroyed through a pointer
+ // to this interface.
+ virtual ~MMAInstructionInterface() = default;
+};
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
deleted file mode 100644
index 087313c357476..0000000000000
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
+++ /dev/null
@@ -1,74 +0,0 @@
-//===--- uArchInterfaces.h ---*- C++-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-/// \file
-/// Defines the utility interfaces that are implemented by individual
-/// instructions.
-///
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
-#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
-
-#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/TypeUtilities.h"
-#include <map>
-#include <string>
-#include <vector>
-
-namespace mlir {
-namespace xegpu {
-namespace uArch {
-
-enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
-struct MMAInstructionInterface {
- // Get supported Matrix shapes
- virtual std::vector<std::pair<uint32_t, uint32_t>>
- getSupportedShapes(mlir::Type dataType, MMAOpndKind matrixType) = 0;
- // @TODO: This method takes an context object as a parameter, this is to
- // create the mlir::Type objects from the same context. Since type objects are
- // uniqued in a specific context, to do things like "aType == bType" (where
- // aType and bType are both same type) kind of checks, the both types should
- // be from the same context.
- //
- // One alternative to this is to create enum to represent each types, but this
- // adds an extra burden to user to convert these enums to specific types. In
- // fact the utility that would convert enumToType() and vice versa would still
- // have to use the context object.
- //
- // Untill we have a better solution, we stick to passing context object to
- // this method.
- virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
- MMAOpndKind matrixType) = 0;
- virtual bool
- checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
- std::pair<uint32_t, uint32_t> BShape,
- std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape,
- mlir::Type AType, mlir::Type BType,
- mlir::Type CType, mlir::Type DType) = 0;
- virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
- mlir::Type CType, mlir::Type DType) = 0;
- virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
- std::pair<uint32_t, uint32_t> BShape,
- std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
- mlir::Type BType, mlir::Type CType,
- mlir::Type DType) = 0;
- virtual std::vector<uint32_t> getSupportedM(mlir::Type type) = 0;
- virtual std::vector<uint32_t> getSupportedK(mlir::Type type) = 0;
- virtual std::vector<uint32_t> getSupportedN(mlir::Type type) = 0;
-
- virtual ~MMAInstructionInterface() = default;
-};
-
-} // namespace uArch
-} // namespace xegpu
-} // namespace mlir
-#endif // MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 9e40d41de1c73..ec581ac7277e3 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -129,6 +129,5 @@ add_mlir_dialect_library(MLIRXeVMDialect
MLIRDialectUtils
MLIRIR
MLIRLLVMDialect
- MLIRXeGPUuArch
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 6af69bb549bea..04e8836c00359 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -8,7 +8,6 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 9079df050ab2b..31167e6af908b 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,4 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
-add_subdirectory(uArch)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index e1c51deefbee3..7869a28dfed57 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -20,7 +20,6 @@ add_mlir_dialect_library(MLIRXeGPUDialect
MLIRGPUDialect
MLIRXeVMDialect
MLIRIR
- MLIRXeGPUuArch
MLIRViewLikeInterface
MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e56fe3a1e193d..e6f76067094ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -23,7 +23,6 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRTransforms
MLIRGPUDialect
MLIRXeGPUUtils
- MLIRXeGPUuArch
MLIRGPUUtils
MLIRVectorTransforms
)
diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
deleted file mode 100644
index c7f691cb6dda7..0000000000000
--- a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-add_mlir_dialect_library(MLIRXeGPUuArch
- IntelGpuXe2.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRDialectUtils
-)
-
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
deleted file mode 100644
index e19df1e5235ba..0000000000000
--- a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
+++ /dev/null
@@ -1,172 +0,0 @@
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "llvm/Support/DebugLog.h"
-#include <algorithm>
-#include <vector>
-
-#define DEBUG_TYPE "xegpu-uarch"
-
-using namespace mlir::xegpu::uArch;
-using namespace mlir::xegpu::uArch::Xe2Plus;
-
-namespace mlir {
-namespace xegpu {
-namespace uArch {
-namespace Xe2Plus {
-
-std::vector<std::pair<uint32_t, uint32_t>>
-DPASInstruction::getSupportedShapes(mlir::Type dataType,
- MMAOpndKind matrixType) {
- auto combineVectors = [](const std::vector<uint32_t> &a,
- const std::vector<uint32_t> &b)
- -> std::vector<std::pair<uint32_t, uint32_t>> {
- std::vector<std::pair<uint32_t, uint32_t>> result;
- for (unsigned x : a) {
- for (unsigned y : b) {
- result.emplace_back(x, y);
- }
- }
- return result;
- };
-
- auto M = getSupportedM(dataType);
- auto K = getSupportedK(dataType);
- auto N = getSupportedN(dataType);
- std::vector<std::pair<unsigned, unsigned>> resultMatrix;
-
- switch (matrixType) {
- case MMAOpndKind::MatrixA:
- resultMatrix = combineVectors(M, K);
- break;
- case MMAOpndKind::MatrixB:
- resultMatrix = combineVectors(K, N);
- break;
- case MMAOpndKind::MatrixC:
- resultMatrix = combineVectors(M, N);
- break;
- case MMAOpndKind::MatrixD:
- resultMatrix = combineVectors(M, N);
- break;
- }
- return resultMatrix;
-}
-
-std::vector<mlir::Type>
-DPASInstruction::getSupportedTypes(MLIRContext &context,
- MMAOpndKind matrixType) {
- mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
- mlir::Type f16Type = mlir::Float16Type::get(&context);
- mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
- mlir::Type f32Type = mlir::Float32Type::get(&context);
-
- switch (matrixType) {
- case MMAOpndKind::MatrixA:
- return {bf16Type, f16Type, tf32Type};
- break;
- case MMAOpndKind::MatrixB:
- return {bf16Type, f16Type, tf32Type};
- break;
- case MMAOpndKind::MatrixC:
- return {bf16Type, f16Type, f32Type};
- break;
- case MMAOpndKind::MatrixD:
- return {bf16Type, f16Type, f32Type};
- break;
- }
-}
-
-bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
- mlir::Type CType, mlir::Type DType) {
- if (AType.isF16() || BType.isF16()) {
- if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
- (!DType.isF32() && !DType.isF16())) {
- LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
- return false;
- }
- } else if (AType.isBF16() || BType.isBF16()) {
- if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
- (!DType.isF32() && !DType.isBF16())) {
- LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
- return false;
- }
- } else if (AType.isTF32() || BType.isTF32()) {
- if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
- (!DType.isF32())) {
- LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
- return false;
- }
- } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
- AType.isInteger(8)) &&
- !(BType.isInteger(2) || BType.isInteger(4) ||
- BType.isInteger(8))) {
- LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
- return false;
- }
-
- return true;
-}
-
-bool DPASInstruction::checkSupportedShapesAndTypes(
- std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
- std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
- mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
- auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
- auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
- auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
- auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
- return llvm::is_contained(supportedAShapes, AShape) &&
- llvm::is_contained(supportedBShapes, BShape) &&
- llvm::is_contained(supportedCShapes, CShape) &&
- llvm::is_contained(supportedDShapes, DShape) &&
- checkSupportedTypes(AType, BType, CType, DType);
-}
-
-bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
- std::pair<uint32_t, uint32_t> BShape,
- std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape,
- mlir::Type AType, mlir::Type BType,
- mlir::Type CType, mlir::Type DType) {
- return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
- BType, CType, DType);
-}
-
-std::vector<uint32_t> DPASInstruction::getSupportedM(mlir::Type type) {
- return {1, 2, 3, 4, 5, 6, 7, 8};
-}
-
-std::vector<uint32_t> DPASInstruction::getSupportedK(mlir::Type type) {
- // assert if data type is not int or float type
- assert(type.isIntOrFloat() && "Matrix type must be int or float");
- auto bitWidth = type.getIntOrFloatBitWidth();
- uint32_t kSize = 0;
- switch (bitWidth) {
- case 2:
- kSize = 64;
- break;
- case 4:
- kSize = 64;
- break;
- case 8:
- kSize = 32;
- break;
- case 16:
- kSize = 16;
- break;
- case 32:
- kSize = 8;
- break;
- default:
- llvm_unreachable("Invalid int or float");
- }
- return {kSize};
-}
-
-std::vector<uint32_t> DPASInstruction::getSupportedN(mlir::Type type) {
- return {16};
-}
-
-} // namespace Xe2Plus
-} // namespace uArch
-} // namespace xegpu
-} // namespace mlir
>From 6098788ed6671769487026d38dbaf8ebd6ff6a3e Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 25 Sep 2025 15:49:17 +0000
Subject: [PATCH 06/13] Address review comments.
Use LLVM data structures whenever possible.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 73 ++++++++--------
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 85 ++++++++++++-------
2 files changed, 88 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 7b1381e9efc2d..453dc57a72020 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -18,10 +18,10 @@
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#include <map>
#include <string>
-#include <vector>
#define DEBUG_TYPE "xegpu-uarch"
@@ -47,12 +47,12 @@ struct XeCoreInfo {
struct Xe2Plus : public uArch {
XeCoreInfo xe_core;
- Xe2Plus(
- const std::string &archName, const std::string &archDescription,
- const XeCoreInfo &xeCore,
- const std::map<RegisterFileType, RegisterFileInfo> ®Info = {},
- const std::vector<CacheInfo> &cacheInfo = {},
- const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+ Xe2Plus(const std::string &archName, const std::string &archDescription,
+ const XeCoreInfo &xeCore,
+ const std::map<RegisterFileType, RegisterFileInfo> ®Info = {},
+ const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+ const std::map<InstructionKind, std::shared_ptr<Instruction>>
+ &instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xe_core(xeCore) {}
};
@@ -60,15 +60,16 @@ struct Xe2Plus : public uArch {
// struct to represent DPAS instruction
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
DPASInstruction()
- : Instruction("dpas", // name
- "Dot Product Accumulate") // description
+ : Instruction(InstructionKind::DPAS, // name
+ "Dot Product Accumulate",
+ InstructionScope::Subgroup) // description
{}
// Override all virtuals from MatrixOpInterface
- virtual std::vector<std::pair<uint32_t, uint32_t>>
+ virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
- virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
- MMAOpndKind matrixType) override;
+ virtual llvm::SmallVector<Type, 8>
+ getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
virtual bool
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
std::pair<uint32_t, uint32_t> BShape,
@@ -82,14 +83,14 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) override;
- virtual std::vector<uint32_t> getSupportedM(Type type) override;
- virtual std::vector<uint32_t> getSupportedK(Type type) override;
- virtual std::vector<uint32_t> getSupportedN(Type type) override;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
};
struct PVCuArch : public Xe2Plus {
// Maintaines ownership of the instructions owned by PVUarch
- std::vector<std::shared_ptr<Instruction>> owned_instructions;
+ llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
PVCuArch()
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
@@ -115,17 +116,16 @@ struct PVCuArch : public Xe2Plus {
this->cacheInfo.push_back(
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
- // Add the instructions
+ // Add the instructions
auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getName(), dpas);
- // instructions[dpas->name] = dpas.get();
+ instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
}
};
struct BMGuArch : public Xe2Plus {
// Maintaines ownership of the instructions owned by PVUarch
- std::vector<std::shared_ptr<Instruction>> owned_instructions;
+ llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
BMGuArch()
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
@@ -151,8 +151,7 @@ struct BMGuArch : public Xe2Plus {
// Add the instructions
auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getName(), dpas);
- // instructions[dpas->name] = dpas.get();
+ instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
}
};
@@ -160,12 +159,12 @@ struct BMGuArch : public Xe2Plus {
} // namespace xegpu
} // namespace mlir
-inline std::vector<std::pair<uint32_t, uint32_t>>
+inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
- auto combineVectors = [](const std::vector<uint32_t> &a,
- const std::vector<uint32_t> &b)
- -> std::vector<std::pair<uint32_t, uint32_t>> {
- std::vector<std::pair<uint32_t, uint32_t>> result;
+ auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
+ const llvm::SmallVector<uint32_t, 8> &b)
+ -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
+ llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> result;
for (unsigned x : a) {
for (unsigned y : b) {
result.emplace_back(x, y);
@@ -177,7 +176,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
auto M = getSupportedM(dataType);
auto K = getSupportedK(dataType);
auto N = getSupportedN(dataType);
- std::vector<std::pair<unsigned, unsigned>> resultMatrix;
+ llvm::SmallVector<std::pair<unsigned, unsigned>, 16> resultMatrix;
switch (matrixType) {
case MMAOpndKind::MatrixA:
@@ -196,7 +195,7 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
return resultMatrix;
}
-inline std::vector<Type>
+inline llvm::SmallVector<Type, 8>
DPASInstruction::getSupportedTypes(MLIRContext &context,
MMAOpndKind matrixType) {
Type bf16Type = BFloat16Type::get(&context);
@@ -207,17 +206,14 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
switch (matrixType) {
case MMAOpndKind::MatrixA:
return {bf16Type, f16Type, tf32Type};
- break;
case MMAOpndKind::MatrixB:
return {bf16Type, f16Type, tf32Type};
- break;
case MMAOpndKind::MatrixC:
return {bf16Type, f16Type, f32Type};
- break;
case MMAOpndKind::MatrixD:
return {bf16Type, f16Type, f32Type};
- break;
}
+ return {};
}
inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
@@ -276,11 +272,13 @@ inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
BType, CType, DType);
}
-inline std::vector<uint32_t> DPASInstruction::getSupportedM(Type type) {
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedM(Type type) {
return {1, 2, 3, 4, 5, 6, 7, 8};
}
-inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedK(Type type) {
// assert if data type is not int or float type
assert(type.isIntOrFloat() && "Matrix type must be int or float");
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -307,8 +305,9 @@ inline std::vector<uint32_t> DPASInstruction::getSupportedK(Type type) {
return {kSize};
}
-inline std::vector<uint32_t> DPASInstruction::getSupportedN(Type type) {
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedN(Type type) {
return {16};
}
-#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2H
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 7c76157fa2a1a..fe0e0f6528042 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -23,6 +23,7 @@
#include <tuple>
#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace xegpu {
@@ -31,12 +32,26 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
-enum class InstructionName {
- DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix multiply-add
- // operation
+enum class InstructionKind {
+ DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
+ // multiply-add operation
// Add more instructions as needed
};
+/// Returns the lower-case mnemonic of an instruction kind (e.g. "dpas").
+///
+/// Declared `inline` because it is defined in a header: without it every
+/// translation unit that includes this file would emit its own external
+/// definition, violating the one-definition rule at link time.
+inline llvm::StringRef toString(InstructionKind name) {
+  switch (name) {
+  case InstructionKind::DPAS:
+    return "dpas";
+  }
+  llvm_unreachable("Unknown InstructionKind");
+}
+
+/// Maps an instruction mnemonic to its InstructionKind.
+///
+/// The comparison ignores case. Returns std::nullopt when `str` does not name
+/// a known instruction.
+///
+/// Declared `inline` because it is defined in a header: without it every
+/// translation unit that includes this file would emit its own external
+/// definition, violating the one-definition rule at link time.
+inline std::optional<InstructionKind>
+parseInstructionKind(llvm::StringRef str) {
+  if (str.equals_insensitive("dpas"))
+    return InstructionKind::DPAS;
+  return std::nullopt;
+}
+
// A struct to represent basic information about an instruction
// This struct is used to represent the information about an instruction in the
// uArch The information includes:
@@ -56,17 +71,17 @@ enum class InstructionName {
struct Instruction {
// @TODO: Add more fields as needed
- Instruction(std::string name, std::string desc)
- : name(std::move(name)), description(std::move(desc)) {}
+ Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
+ : instKind(kind), description(std::move(desc)), scope(scope) {}
virtual ~Instruction() = default;
// Get methods
- std::string getName() { return name; }
+ InstructionKind getInstructionKind() { return instKind; }
std::string getDescription() { return description; }
InstructionScope getScope() { return scope; }
protected:
- std::string name;
+ InstructionKind instKind;
std::string description;
InstructionScope scope;
};
@@ -78,23 +93,25 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
struct RegisterFileInfo {
// Constructor
RegisterFileInfo() = default;
- RegisterFileInfo(uint32_t size, const std::vector<RegisterFileMode> &mode,
- const std::vector<uint32_t> &numRegs)
+ RegisterFileInfo(uint32_t size,
+ const llvm::SmallVector<RegisterFileMode, 4> &mode,
+ const llvm::SmallVector<uint32_t, 4> &numRegs)
: size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
uint32_t getSize() const { return size; }
- const std::vector<RegisterFileMode> &getModes() const { return mode; }
- const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+ const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
+ return mode;
+ }
+ const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
return numRegsPerThreadPerMode;
}
protected:
- uint32_t size; // size per register in bits
- std::vector<RegisterFileMode> mode; // e.g., "small", "large" GRF modes
- std::vector<uint32_t>
+ uint32_t size; // size per register in bits
+ llvm::SmallVector<RegisterFileMode, 4>
+ mode; // e.g., "small", "large" GRF modes
+ llvm::SmallVector<uint32_t, 4>
numRegsPerThreadPerMode; // number of registers per thread per mode
- // TODO: Add more fields as needed (e.g., num_banks, bank_size, num_ports,
- // port_width, bank_conflicts)
};
enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
@@ -136,8 +153,8 @@ struct uArch {
uArch(const std::string &name, const std::string &description,
const std::map<RegisterFileType, RegisterFileInfo> ®ister_file_info =
{},
- const std::vector<CacheInfo> &cache_info = {},
- const std::map<std::string, std::shared_ptr<Instruction>>
+ const llvm::SmallVector<CacheInfo, 4> &cache_info = {},
+ const std::map<InstructionKind, std::shared_ptr<Instruction>>
&instructions = {})
: name(name), description(description),
registerFileInfo(register_file_info), cacheInfo(cache_info),
@@ -153,34 +170,36 @@ struct uArch {
return registerFileInfo;
}
- const std::vector<CacheInfo> &getCacheInfo() const { return cacheInfo; }
+ const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
+ return cacheInfo;
+ }
- const std::map<std::string, std::shared_ptr<Instruction>> &
+ const std::map<InstructionKind, std::shared_ptr<Instruction>> &
getInstructions() const {
return instructions;
}
// Get the name of the supported instruction names for that
// architecture. It returns the names of the instructions added to the uArch.
- std::vector<std::string> getSupportedInstructionNames() const {
- std::vector<std::string> instructionNames;
+ llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
+ llvm::SmallVector<StringRef, 8> instructionNames;
for (const auto &inst : instructions) {
- instructionNames.push_back(inst.first);
+ instructionNames.push_back(toString(inst.first));
}
return instructionNames;
}
// Checks if an instruction is supported in this uArch
- bool checkSupportedInstruction(const std::string &instructionName) const {
- return instructions.find(instructionName) != instructions.end();
+ bool checkSupportedInstruction(InstructionKind instr) const {
+ return instructions.find(instr) != instructions.end();
}
protected:
std::string name; // Similar to target triple
std::string description;
std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
- std::vector<CacheInfo> cacheInfo;
- std::map<std::string, std::shared_ptr<Instruction>> instructions;
+ llvm::SmallVector<CacheInfo, 4> cacheInfo;
+ std::map<InstructionKind, std::shared_ptr<Instruction>> instructions;
};
// A struct to represent shared memory information
@@ -205,7 +224,7 @@ struct SharedMemory {
enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
struct MMAInstructionInterface {
// Get supported Matrix shapes
- virtual std::vector<std::pair<uint32_t, uint32_t>>
+ virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
// @TODO: This method takes an context object as a parameter, this is to
// create the Type objects from the same context. Since type objects are
@@ -220,8 +239,8 @@ struct MMAInstructionInterface {
//
// Untill we have a better solution, we stick to passing context object to
// this method.
- virtual std::vector<Type> getSupportedTypes(MLIRContext &context,
- MMAOpndKind matrixType) = 0;
+ virtual llvm::SmallVector<Type, 8>
+ getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
virtual bool
checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
std::pair<uint32_t, uint32_t> BShape,
@@ -235,9 +254,9 @@ struct MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) = 0;
- virtual std::vector<uint32_t> getSupportedM(Type type) = 0;
- virtual std::vector<uint32_t> getSupportedK(Type type) = 0;
- virtual std::vector<uint32_t> getSupportedN(Type type) = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
virtual ~MMAInstructionInterface() = default;
};
>From 8fb4f1794d309b8d9385d2fb468148a38f6161a0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 25 Sep 2025 15:58:40 +0000
Subject: [PATCH 07/13] Add/Remove some spacings.
---
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index fe0e0f6528042..8f38b2afede4d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -31,7 +31,6 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
-
enum class InstructionKind {
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
// multiply-add operation
@@ -68,7 +67,6 @@ std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
// represent information about an instruction and to use this information to
// generate the uArch. Specifc instruction in a uArch can inherit from this
// struct and add more fields as needed
-
struct Instruction {
// @TODO: Add more fields as needed
Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
@@ -115,6 +113,7 @@ struct RegisterFileInfo {
};
enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
+
// A struct to represent cache information
struct CacheInfo {
// Constructor
>From 31b93d6c2daee68aab6011fa56ff657b500c2e4a Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 25 Sep 2025 16:38:48 +0000
Subject: [PATCH 08/13] Address review comments.
---
mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 6 +-----
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 7 +++----
2 files changed, 4 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 453dc57a72020..19c9138055b27 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -37,7 +37,6 @@ struct XeCoreInfo {
uint32_t num_vector_units;
uint32_t num_matrix_units;
- // Constructor
XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
uint32_t num_vector_units, uint32_t num_matrix_units)
: num_threads(num_threads), shared_memory(shared_memory),
@@ -60,10 +59,7 @@ struct Xe2Plus : public uArch {
// struct to represent DPAS instruction
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
DPASInstruction()
- : Instruction(InstructionKind::DPAS, // name
- "Dot Product Accumulate",
- InstructionScope::Subgroup) // description
- {}
+ : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 8f38b2afede4d..69db5d5b5367f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -69,13 +69,12 @@ std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
// struct and add more fields as needed
struct Instruction {
// @TODO: Add more fields as needed
- Instruction(InstructionKind kind, std::string desc, InstructionScope scope)
- : instKind(kind), description(std::move(desc)), scope(scope) {}
+ Instruction(InstructionKind kind, InstructionScope scope)
+ : instKind(kind), scope(scope) {}
virtual ~Instruction() = default;
// Get methods
InstructionKind getInstructionKind() { return instKind; }
- std::string getDescription() { return description; }
InstructionScope getScope() { return scope; }
protected:
@@ -143,7 +142,7 @@ struct CacheInfo {
// - the name of the uArch,
// - the description of the uArch,
// - uArch hierarchy
-// - Rgister File information
+// - Register File information
// - Cache information
// - the set of instructions supported by the uArch,
struct uArch {
>From fa6c4bcfb31746dac05a15be0fb7071e15585639 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 17:09:33 +0000
Subject: [PATCH 09/13] Address review comments.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 28 +++-----
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 70 +++++++++----------
2 files changed, 41 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 19c9138055b27..0519f7b2e277d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -31,21 +31,9 @@ using namespace mlir::xegpu::uArch;
namespace mlir {
namespace xegpu {
namespace uArch {
-struct XeCoreInfo {
- uint32_t num_threads;
- SharedMemory shared_memory;
- uint32_t num_vector_units;
- uint32_t num_matrix_units;
-
- XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
- uint32_t num_vector_units, uint32_t num_matrix_units)
- : num_threads(num_threads), shared_memory(shared_memory),
- num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
- }
-};
struct Xe2Plus : public uArch {
- XeCoreInfo xe_core;
+ XeCoreInfo xeCore;
Xe2Plus(const std::string &archName, const std::string &archDescription,
const XeCoreInfo &xeCore,
const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
@@ -53,7 +41,7 @@ struct Xe2Plus : public uArch {
const std::map<InstructionKind, std::shared_ptr<Instruction>>
&instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
- xe_core(xeCore) {}
+ xeCore(xeCore) {}
};
// struct to represent DPAS instruction
@@ -91,9 +79,9 @@ struct PVCuArch : public Xe2Plus {
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
- {/* register_file_info */}, // Optional: empty
- {/* cache_info */}, // Optional: empty
- {/* instructions */} // Optional: empty
+ {/* registerFileInfo */}, // Optional: empty
+ {/* cacheInfo */}, // Optional: empty
+ {/* instructions */} // Optional: empty
) {
// Intialize register file info
// GRF
@@ -126,9 +114,9 @@ struct BMGuArch : public Xe2Plus {
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
- {/* register_file_info */}, // Optional: empty
- {/* cache_info */}, // Optional: empty
- {/* instructions */} // Optional: empty)
+ {/* registerFileInfo */}, // Optional: empty
+ {/* cacheInfo */}, // Optional: empty
+ {/* instructions */} // Optional: empty
) {
// Intialize register file info
// GRF
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 69db5d5b5367f..bb47a3e138e7b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -1,4 +1,4 @@
-//===--- uArch.h ------------------------------------------------*- C++ -*-===//
+//===- uArch.h --------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -34,7 +34,7 @@ enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
// multiply-add operation
- // Add more instructions as needed
+ // @TODO: Add more instructions as needed
};
llvm::StringRef toString(InstructionKind name) {
@@ -51,24 +51,12 @@ std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
return std::nullopt;
}
-// A struct to represent basic information about an instruction
-// This struct is used to represent the information about an instruction in the
-// uArch The information includes:
-// - the name of the instruction,
-// - the description of the instruction
-// - the scope of the instruction,
-//
-// The information is represented as strings
-// For example, the information about an instruction can be represented as:
-// Instruction instr = {"dpas", "Dot Product Accumulate Systolic (DPAS) is a
-// matrix multiply-add operation", "subgroup"};
-
+// A struct to represent basic information about an instruction.
// The primary purpose of the Instruction struct is to provide a generic way to
// represent information about an instruction and to use this information to
// generate the uArch. Specifc instruction in a uArch can inherit from this
-// struct and add more fields as needed
+// struct and add more fields as needed.
struct Instruction {
- // @TODO: Add more fields as needed
Instruction(InstructionKind kind, InstructionScope scope)
: instKind(kind), scope(scope) {}
@@ -79,8 +67,8 @@ struct Instruction {
protected:
InstructionKind instKind;
- std::string description;
InstructionScope scope;
+ // @TODO: Add more fields as needed
};
enum class RegisterFileMode : uint8_t { Small, Large };
@@ -89,16 +77,18 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
// A struct to represent register file information
struct RegisterFileInfo {
// Constructor
- RegisterFileInfo() = default;
RegisterFileInfo(uint32_t size,
const llvm::SmallVector<RegisterFileMode, 4> &mode,
const llvm::SmallVector<uint32_t, 4> &numRegs)
: size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
+ // Get methods
uint32_t getSize() const { return size; }
+
const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
return mode;
}
+
const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
return numRegsPerThreadPerMode;
}
@@ -137,25 +127,17 @@ struct CacheInfo {
};
// A struct to represent the uArch
-// This struct is used to represent the microarchitecture of a target device
-// The uArch includes:
-// - the name of the uArch,
-// - the description of the uArch,
-// - uArch hierarchy
-// - Register File information
-// - Cache information
-// - the set of instructions supported by the uArch,
+// This struct is used to represent the microarchitecture of a target device.
struct uArch {
// Constructor
- uArch() = default;
- uArch(const std::string &name, const std::string &description,
- const std::map<RegisterFileType, RegisterFileInfo> &register_file_info =
- {},
- const llvm::SmallVector<CacheInfo, 4> &cache_info = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instructions = {})
+ uArch(
+ const std::string &name, const std::string &description,
+ const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
+ const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+ const std::map<InstructionKind, std::shared_ptr<Instruction>>
+ &instructions = {})
: name(name), description(description),
- registerFileInfo(register_file_info), cacheInfo(cache_info),
+ registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
instructions(instructions) {}
// Get methods
@@ -193,11 +175,12 @@ struct uArch {
}
protected:
- std::string name; // Similar to target triple
+ std::string name; // Name of the uArch, similar to target triple
std::string description;
std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
llvm::SmallVector<CacheInfo, 4> cacheInfo;
- std::map<InstructionKind, std::shared_ptr<Instruction>> instructions;
+ std::map<InstructionKind, std::shared_ptr<Instruction>>
+ instructions; // set of instructions supported by the uArch
};
// A struct to represent shared memory information
@@ -206,7 +189,7 @@ struct SharedMemory {
SharedMemory(uint32_t size, uint32_t alignment)
: size(size), alignment(alignment) {}
- // Getters
+ // Get methods
uint32_t getSize() const { return size; }
uint32_t getAlignment() const { return alignment; }
@@ -216,6 +199,19 @@ struct SharedMemory {
// @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
};
+struct XeCoreInfo {
+ uint32_t num_threads;
+ SharedMemory shared_memory;
+ uint32_t num_vector_units;
+ uint32_t num_matrix_units;
+
+ XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+ uint32_t num_vector_units, uint32_t num_matrix_units)
+ : num_threads(num_threads), shared_memory(shared_memory),
+ num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
+ }
+};
+
//===----------------------------------------------------------------------===//
// Interfaces
//===----------------------------------------------------------------------===//
>From 2fc129a8444fb698b7e79071dc7844d9723d8218 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 17:13:37 +0000
Subject: [PATCH 10/13] Address review comments.
---
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index bb47a3e138e7b..64aa25ffd543f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -30,7 +30,7 @@ namespace xegpu {
namespace uArch {
// An enum class to represent the scope of an instruction
-enum class InstructionScope { WorkItem, Subgroup, Workgroup, Cluster };
+enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
// multiply-add operation
@@ -66,8 +66,9 @@ struct Instruction {
InstructionScope getScope() { return scope; }
protected:
- InstructionKind instKind;
- InstructionScope scope;
+ InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+ InstructionScope scope; // scope of the instruction (e.g., lane, subgroup,
+ // workgroup, cluster)
// @TODO: Add more fields as needed
};
>From 28903cb8d738be3a0c570f5cc60e1fd4ca57966b Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 17:53:10 +0000
Subject: [PATCH 11/13] Fix a small compile error.
---
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 64aa25ffd543f..48d2302994592 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -78,6 +78,7 @@ enum class RegisterFileType : uint8_t { GRF, ARF };
// A struct to represent register file information
struct RegisterFileInfo {
// Constructor
+ RegisterFileInfo() = default;
RegisterFileInfo(uint32_t size,
const llvm::SmallVector<RegisterFileMode, 4> &mode,
const llvm::SmallVector<uint32_t, 4> &numRegs)
@@ -107,6 +108,7 @@ enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
// A struct to represent cache information
struct CacheInfo {
// Constructor
+ CacheInfo() = default;
CacheInfo(uint32_t size, uint32_t line_size,
CacheHierarchyLevel hierarchy_level)
: size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}
>From 6ef33aec53a11f1eaa3191be3d5fa26d1099232c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 8 Oct 2025 12:50:31 +0000
Subject: [PATCH 12/13] Address review comments.
---
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 48d2302994592..7341bbb0ed638 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -37,7 +37,7 @@ enum class InstructionKind {
// @TODO: Add more instructions as needed
};
-llvm::StringRef toString(InstructionKind name) {
+static llvm::StringRef toString(InstructionKind name) {
switch (name) {
case InstructionKind::DPAS:
return "dpas";
@@ -45,7 +45,8 @@ llvm::StringRef toString(InstructionKind name) {
llvm_unreachable("Unknown InstructionKind");
}
-std::optional<InstructionKind> parseInstructionKind(llvm::StringRef str) {
+static std::optional<InstructionKind>
+parseInstructionKind(llvm::StringRef str) {
if (str.equals_insensitive("dpas"))
return InstructionKind::DPAS;
return std::nullopt;
>From 1ffe77eb18c25070a4994701be8fa4365b26548c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 8 Oct 2025 13:25:32 +0000
Subject: [PATCH 13/13] Address review comments.
Make some helper funcs static methods.
---
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 31 +++++++++----------
1 file changed, 15 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 7341bbb0ed638..955994ea5ecf5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -37,21 +37,6 @@ enum class InstructionKind {
// @TODO: Add more instructions as needed
};
-static llvm::StringRef toString(InstructionKind name) {
- switch (name) {
- case InstructionKind::DPAS:
- return "dpas";
- }
- llvm_unreachable("Unknown InstructionKind");
-}
-
-static std::optional<InstructionKind>
-parseInstructionKind(llvm::StringRef str) {
- if (str.equals_insensitive("dpas"))
- return InstructionKind::DPAS;
- return std::nullopt;
-}
-
// A struct to represent basic information about an instruction.
// The primary purpose of the Instruction struct is to provide a generic way to
// represent information about an instruction and to use this information to
@@ -65,6 +50,20 @@ struct Instruction {
// Get methods
InstructionKind getInstructionKind() { return instKind; }
InstructionScope getScope() { return scope; }
+ static llvm::StringRef toString(InstructionKind instKind) {
+ switch (instKind) {
+ case InstructionKind::DPAS:
+ return "dpas";
+ }
+ llvm_unreachable("Unknown InstructionKind");
+ }
+
+ static std::optional<InstructionKind>
+ parseInstructionKind(llvm::StringRef str) {
+ if (str.equals_insensitive("dpas"))
+ return InstructionKind::DPAS;
+ return std::nullopt;
+ }
protected:
InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
@@ -168,7 +167,7 @@ struct uArch {
llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
llvm::SmallVector<StringRef, 8> instructionNames;
for (const auto &inst : instructions) {
- instructionNames.push_back(toString(inst.first));
+ instructionNames.push_back(Instruction::toString(inst.first));
}
return instructionNames;
}
More information about the Mlir-commits
mailing list