[Mlir-commits] [mlir] [uArch][XeGPU] Add XeGPU uArch definition. (PR #153706)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Thu Aug 14 15:42:21 PDT 2025
https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/153706
>From b1d37c0db0c1a308feedf981218c04516815e6b4 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Fri, 28 Mar 2025 13:04:30 +0000
Subject: [PATCH] [uArch][XeGPU] Add XeGPU uArch definition.
The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary components
(e.g., instructions, register-files, caches).
- A set of utility interfaces that are common to a family of ops
(e.g., mma ops, 2DBlockIO ops). The implementation of these interfaces
is provided by the specific instructions. Each family of ops provides
these 5 common APIs. However, some families of ops may have more
utility APIs. The 5 common APIs are:
- getSupportedShapes
- getSupportedTypes
- checkSupportedShapesAndTypes
- checkSupportedTypes
- validate
Add support for PVC and BMG architectures.
Add support for DPAS instruction.
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 182 ++++++++++++
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 266 ++++++++++++++++++
.../Dialect/XeGPU/uArch/uArchInterfaces.h | 75 +++++
mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 1 +
mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 9 +
mlir/lib/Dialect/XeGPU/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 9 +
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt | 11 +
mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp | 197 +++++++++++++
11 files changed, 753 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
create mode 100644 mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
create mode 100644 mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..9179838f8c148
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,182 @@
+//===--- IntelGpuXe2.h ---------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Xe2 uArch definition.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  // Builds the XeCore description from its thread count, shared memory
+  // configuration and the number of vector and matrix units.
+  XeCoreInfo(uint32_t threads, const SharedMemory &sharedMem,
+             uint32_t vectorUnits, uint32_t matrixUnits)
+      : num_threads{threads}, shared_memory{sharedMem},
+        num_vector_units{vectorUnits}, num_matrix_units{matrixUnits} {}
+};
+
+struct Xe2Plus : public uArch {
+  XeCoreInfo xe_core; // XeCore-level info: threads, shared memory, units
+  Xe2Plus(
+      const std::string &archName, const std::string &archDescription,
+      const XeCoreInfo &xeCore,
+      const std::vector<uArchHierarchyComponent> &hierarchy = {},
+      const std::map<std::string, RegisterFileInfo> &regInfo = {},
+      const std::vector<CacheInfo> &cacheInfo = {},
+      const std::map<std::string, std::shared_ptr<Instruction>> &instrs = {})
+      : uArch(archName, archDescription, hierarchy, regInfo, cacheInfo, instrs),
+        xe_core(xeCore) {}
+};
+
+// Describes the DPAS instruction and answers MMA shape/type queries for it.
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+  DPASInstruction()
+      : Instruction(/*name=*/"dpas", /*desc=*/"Dot Product Accumulate") {}
+
+  // MMAInstructionInterface implementation: every pure-virtual query of the
+  // interface is overridden here; the definitions are provided out of line
+  // in IntelGpuXe2.cpp.
+  std::vector<std::pair<uint32_t, uint32_t>>
+  getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) override;
+  std::vector<mlir::Type>
+  getSupportedTypes(MLIRContext &context, MMAOpndEnum matrixType) override;
+  bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape,
+                               mlir::Type AType, mlir::Type BType,
+                               mlir::Type CType, mlir::Type DType) override;
+  bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                           mlir::Type CType, mlir::Type DType) override;
+  bool validate(std::pair<uint32_t, uint32_t> AShape,
+                std::pair<uint32_t, uint32_t> BShape,
+                std::pair<uint32_t, uint32_t> CShape,
+                std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+                mlir::Type BType, mlir::Type CType,
+                mlir::Type DType) override;
+  std::vector<uint32_t> getSupportedM(mlir::Type type) override;
+  std::vector<uint32_t> getSupportedK(mlir::Type type) override;
+  std::vector<uint32_t> getSupportedN(mlir::Type type) override;
+};
+
+namespace PVCuArch {
+struct PVCuArch : public Xe2Plus {
+ // Keeps an owning reference to every instruction registered by PVCuArch
+ std::vector<std::shared_ptr<Instruction>> owned_instructions;
+ PVCuArch()
+ : Xe2Plus("pvc", // archName
+ "Ponte Vecchio Architecture", // archDescription
+ XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
+ {/* hierarchy */}, // empty; filled in below
+ {/* register_file_info */}, // empty; filled in below
+ {/* cache_info */} // empty; filled in below
+ ) {
+ // Initialize uArchHierarchy, lowest level first
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 16));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 4));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 2));
+ // Initialize register file info
+ // GRF
+ this->register_file_info.emplace(
+ "GRF",
+ RegisterFileInfo(64 * 1024, // size in bits
+ {"small", "large"}, // GRF modes
+ {128, 256}, // registers per thread per mode
+ 0, // number of banks
+ 0 // bank size
+ ));
+ // Initialize cache info
+ // L1 cache, XeCore level (uArch_hierarchy[1] is the "XeCore" entry)
+ this->cache_info.push_back(
+ CacheInfo(512 * 1024, 64, this->uArch_hierarchy[1]));
+ // L3 cache, XeStack level (uArch_hierarchy[3] is the "XeStack" entry)
+ this->cache_info.push_back(
+ CacheInfo(512 * 1024, 64, this->uArch_hierarchy[3]));
+
+ // Add the instructions, keyed by instruction name
+ auto dpas = std::make_shared<DPASInstruction>();
+ instructions.emplace(dpas->getName(), dpas);
+ // The map shares ownership; owned_instructions holds a second reference.
+ owned_instructions.push_back(dpas);
+ }
+};
+} // namespace PVCuArch
+
+namespace BMGuArch {
+struct BMGuArch : public Xe2Plus {
+ // Keeps an owning reference to every instruction registered by BMGuArch
+ std::vector<std::shared_ptr<Instruction>> owned_instructions;
+ BMGuArch()
+ : Xe2Plus("bmg", // archName
+ "Battlemage Architecture", // archDescription
+ XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
+ {/* hierarchy */}, // empty; filled in below
+ {/* register_file_info */}, // empty; filled in below
+ {/* cache_info */}, // empty; filled in below
+ {/* instructions */} // empty; filled in below
+ ) {
+ // Initialize uArchHierarchy, lowest level first
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("thread", 0));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeCore", 8));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeSlice", 4));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("XeStack", 5));
+ this->uArch_hierarchy.push_back(uArchHierarchyComponent("gpu", 1));
+ // Initialize register file info
+ // GRF (operator[] default-constructs the entry, then it is assigned)
+ this->register_file_info["GRF"] =
+ RegisterFileInfo(64 * 1024, // size in bits
+ {"small", "large"}, // GRF modes
+ {128, 256}, // registers per thread per mode
+ 0, // number of banks
+ 0 // bank size
+ );
+ // Initialize cache info
+ // L1 cache, XeCore level (uArch_hierarchy[1] is the "XeCore" entry)
+ this->cache_info.push_back(
+ CacheInfo(256 * 1024, 64, this->uArch_hierarchy[1]));
+ // L3 cache, XeStack level (uArch_hierarchy[3] is the "XeStack" entry)
+ this->cache_info.push_back(
+ CacheInfo(18 * 1024 * 1024, 256, this->uArch_hierarchy[3]));
+
+ // Add the instructions, keyed by instruction name
+ auto dpas = std::make_shared<DPASInstruction>();
+ instructions.emplace(dpas->getName(), dpas);
+ // The map shares ownership; owned_instructions holds a second reference.
+ owned_instructions.push_back(dpas);
+ }
+};
+} // namespace BMGuArch
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTEL_GPU_XE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..9bda86df2aff9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,266 @@
+//===--- uArchBase.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Base uArch definition for different architectures.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_BASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_BASE_H
+
+#include <any>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <mutex>
+#include <shared_mutex>
+#include <tuple>
+
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+// One level of the hardware hierarchy of a uArch (e.g. thread, core, socket).
+struct uArchHierarchyComponent {
+  std::string name = ""; // optional label for this hierarchy level
+  // How many components of the next lower level this one contains; e.g. a PVC
+  // XeCore holds 8 threads, so its no_of_component is 8.
+  uint32_t no_of_component;
+  // Both fields are supplied by the caller.
+  uArchHierarchyComponent(const std::string &label, uint32_t count)
+      : name{label}, no_of_component{count} {}
+};
+
+// Enumerates the possible execution scopes of an instruction.
+enum class InstructionScopeEnum { WorkItem, Subgroup, Workgroup, Cluster };
+
+// Basic information about an instruction in the uArch:
+//   - the name of the instruction,
+//   - the description of the instruction,
+//   - the scope of the instruction.
+// Name and description are strings; the scope is an InstructionScopeEnum.
+// For example:
+//   Instruction instr("dpas",
+//                     "Dot Product Accumulate Systolic (DPAS) is a matrix "
+//                     "multiply-add operation",
+//                     InstructionScopeEnum::Subgroup);
+//
+// The struct provides a generic way to represent an instruction inside a
+// uArch. A specific instruction of a uArch can inherit from this struct and
+// add more fields as needed.
+struct Instruction {
+  // @TODO: Add more fields as needed
+  // `scope` is optional so existing two-argument callers keep working; it is
+  // always initialized, so getScope() never reads an indeterminate value.
+  Instruction(std::string name, std::string desc,
+              InstructionScopeEnum scope = InstructionScopeEnum::WorkItem)
+      : name(std::move(name)), description(std::move(desc)), scope(scope) {}
+
+  virtual ~Instruction() = default;
+  // Get methods
+  const std::string &getName() const { return name; }
+  const std::string &getDescription() const { return description; }
+  InstructionScopeEnum getScope() const { return scope; }
+
+protected:
+  std::string name;
+  std::string description;
+  InstructionScopeEnum scope;
+};
+
+// A struct to represent register file information
+struct RegisterFileInfo {
+  // Default: zero-sized register file with no modes (see member initializers).
+  RegisterFileInfo() = default;
+  RegisterFileInfo(uint32_t size, const std::vector<std::string> &mode,
+                   const std::vector<uint32_t> &numRegs, uint32_t num_banks,
+                   uint32_t bank_size)
+      : size(size), mode(mode), num_regs_per_thread_per_mode(numRegs),
+        num_banks(num_banks), bank_size(bank_size) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+
+  const std::vector<std::string> &getModes() const { return mode; }
+
+  const std::vector<uint32_t> &getNumRegsPerThreadPerMode() const {
+    return num_regs_per_thread_per_mode;
+  }
+
+  uint32_t getNumBanks() const { return num_banks; }
+
+  uint32_t getBankSize() const { return bank_size; }
+
+protected:
+  uint32_t size = 0;             // size per register in bits
+  std::vector<std::string> mode; // e.g., "small", "large" GRF modes
+  std::vector<uint32_t>
+      num_regs_per_thread_per_mode; // number of registers per thread per mode
+  uint32_t num_banks = 0;
+  uint32_t bank_size = 0;
+};
+
+// Describes one cache of the uArch: its size, its line size and the
+// hierarchy level at which it is shared.
+struct CacheInfo {
+  // Stores a copy of `sharedAt`.
+  CacheInfo(uint32_t cacheSize, uint32_t lineSize,
+            const uArchHierarchyComponent &sharedAt)
+      : size{cacheSize}, line_size{lineSize}, component{sharedAt} {}
+
+  virtual ~CacheInfo() = default;
+
+  // Read-only accessors
+  uint32_t getSize() const { return size; }
+  uint32_t getLineSize() const { return line_size; }
+  const uArchHierarchyComponent &getComponent() const { return component; }
+
+protected:
+  uint32_t size;
+  uint32_t line_size;
+  // Hierarchy level at which the cache is shared
+  uArchHierarchyComponent component;
+
+  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
+  // bank_size, num_ports, port_width, bank_conflicts)
+};
+
+// Describes the microarchitecture of a target device. A uArch consists of:
+//   - the name of the uArch,
+//   - the description of the uArch,
+//   - the uArch hierarchy,
+//   - register file information,
+//   - cache information,
+//   - the set of instructions supported by the uArch.
+// All container arguments default to empty; derived architectures can fill
+// the protected members directly after construction.
+struct uArch {
+  // Constructors
+  uArch() = default;
+  uArch(const std::string &name, const std::string &description,
+        const std::vector<uArchHierarchyComponent> &uArch_hierarchy = {},
+        const std::map<std::string, RegisterFileInfo> &register_file_info = {},
+        const std::vector<CacheInfo> &cache_info = {},
+        const std::map<std::string, std::shared_ptr<Instruction>>
+            &instructions = {})
+      : name(name), description(description), uArch_hierarchy(uArch_hierarchy),
+        register_file_info(register_file_info), cache_info(cache_info),
+        instructions(instructions) {}
+
+  // Get methods
+  const std::string &getName() const { return name; }
+
+  const std::string &getDescription() const { return description; }
+
+  const std::vector<uArchHierarchyComponent> &getHierarchy() const {
+    return uArch_hierarchy;
+  }
+
+  const std::map<std::string, RegisterFileInfo> &getRegisterFileInfo() const {
+    return register_file_info;
+  }
+
+  const std::vector<CacheInfo> &getCacheInfo() const { return cache_info; }
+
+  const std::map<std::string, std::shared_ptr<Instruction>> &
+  getInstructions() const {
+    return instructions;
+  }
+
+  // Returns the names of the instructions added to this uArch, in the key
+  // order of the instruction map.
+  std::vector<std::string> getSupportedInstructionNames() const {
+    std::vector<std::string> instructionNames;
+    instructionNames.reserve(instructions.size());
+    for (const auto &inst : instructions)
+      instructionNames.push_back(inst.first);
+    return instructionNames;
+  }
+
+  // Checks if an instruction is supported in this uArch
+  bool checkSupportedInstruction(const std::string &instructionName) const {
+    return instructions.find(instructionName) != instructions.end();
+  }
+
+protected:
+  std::string name; // Similar to target triple
+  std::string description;
+  std::vector<uArchHierarchyComponent> uArch_hierarchy;
+  std::map<std::string, RegisterFileInfo> register_file_info;
+  std::vector<CacheInfo> cache_info;
+  std::map<std::string, std::shared_ptr<Instruction>> instructions;
+};
+
+// A block of shared memory, described by its size and its alignment.
+struct SharedMemory {
+  // Both quantities are given in bytes.
+  SharedMemory(uint32_t sizeInBytes, uint32_t alignmentInBytes)
+      : size{sizeInBytes}, alignment{alignmentInBytes} {}
+
+  // Read-only accessors
+  uint32_t getSize() const { return size; }
+  uint32_t getAlignment() const { return alignment; }
+
+protected:
+  uint32_t size;      // in bytes
+  uint32_t alignment; // in bytes
+  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
+};
+
+struct uArchMap {
+public:
+ // Returns the single shared instance (a function-local static)
+ static uArchMap &instance() {
+ static uArchMap instance;
+ return instance;
+ }
+
+ // Insert a key-value pair; an already registered key is left unchanged
+ void insert(const std::string &key, std::shared_ptr<uArch> value) {
+ std::unique_lock<std::shared_mutex> lock(mutex_);
+ // NOTE: std::map::emplace does not overwrite, so the first insert wins
+ map_.emplace(key, std::move(value)); // no-op if key already exists
+ }
+
+ // Look up a value by key under a shared lock; returns nullptr if absent
+ std::shared_ptr<uArch> get(const std::string &key) const {
+ std::shared_lock<std::shared_mutex> lock(mutex_);
+ auto it = map_.find(key);
+ if (it != map_.end())
+ return it->second;
+ return nullptr;
+ }
+
+ // Check if a key exists (taken under a shared lock)
+ bool contains(const std::string &key) const {
+ std::shared_lock<std::shared_mutex> lock(mutex_);
+ return map_.find(key) != map_.end();
+ }
+
+ // Remove a key; returns true if an entry was erased
+ bool erase(const std::string &key) {
+ std::unique_lock<std::shared_mutex> lock(mutex_);
+ return map_.erase(key) > 0;
+ }
+
+private:
+ uArchMap() = default;
+ uArchMap(const uArchMap &) = delete;
+ uArchMap &operator=(const uArchMap &) = delete;
+
+ mutable std::shared_mutex mutex_;
+ std::map<std::string, std::shared_ptr<uArch>> map_;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_BASE_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
new file mode 100644
index 0000000000000..27d44c38317a1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchInterfaces.h
@@ -0,0 +1,75 @@
+//===--- uArchInterfaces.h ---*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Defines the utility interfaces that are implemented by individual
+/// instructions.
+///
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <map>
+#include <string>
+#include <vector>
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+enum class MMAOpndEnum { MatrixA, MatrixB, MatrixC, MatrixD };
+struct MMAInstructionInterface {
+ // Get the supported shapes of operand `matrixType` for element `dataType`
+ virtual std::vector<std::pair<uint32_t, uint32_t>>
+ getSupportedShapes(mlir::Type dataType, MMAOpndEnum matrixType) = 0;
+
+ // @TODO: This method takes a context object as a parameter so that the
+ // mlir::Type objects are created in that context. Since type objects are
+ // uniqued in a specific context, checks such as "aType == bType" (where
+ // aType and bType denote the same type) only work if both types come
+ // from the same context.
+ //
+ // One alternative is to create an enum to represent each type, but this
+ // adds an extra burden on the user to convert these enums to specific types.
+ // In fact, a utility that converts enumToType() and vice versa would still
+ // have to use the context object.
+ //
+ // Until we have a better solution, we stick to passing the context object
+ // to this method.
+ virtual std::vector<mlir::Type> getSupportedTypes(MLIRContext &context,
+ MMAOpndEnum matrixType) = 0;
+ virtual bool
+ checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape,
+ mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) = 0;
+ virtual bool checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) = 0;
+ virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+ std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape,
+ std::pair<uint32_t, uint32_t> DShape, mlir::Type AType,
+ mlir::Type BType, mlir::Type CType,
+ mlir::Type DType) = 0;
+ virtual std::vector<uint32_t> getSupportedM(mlir::Type type) = 0;
+ virtual std::vector<uint32_t> getSupportedK(mlir::Type type) = 0;
+ virtual std::vector<uint32_t> getSupportedN(mlir::Type type) = 0;
+
+ virtual ~MMAInstructionInterface() = default;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTERFACES_H
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ff55f17315cfd..2371dadb6c886 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -128,5 +128,6 @@ add_mlir_dialect_library(MLIRXeVMDialect
MLIRDialectUtils
MLIRIR
MLIRLLVMDialect
+ MLIRXeGPUuArch
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 24e6a9c284e26..9534897cd232a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FileSystem.h"
@@ -345,6 +346,14 @@ XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
}
void XeVMDialect::initialize() {
+ // Populate the uArchMap with the supported target devices
+ auto pvcuArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
+ auto bmguArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..9079df050ab2b 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(uArch)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 7c6a4f37db9af..ad4bbaec9fb6d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect
MLIRArithUtils
MLIRDialectUtils
MLIRIR
+ MLIRXeGPUuArch
MLIRViewLikeInterface
MLIRVectorDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d997296a22c20..fde0cf85caccc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -35,6 +36,14 @@ void XeGPUDialect::initialize() {
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
>();
+
+ // Populate the uArchMap with the supported target devices
+ auto pvcuArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::PVCuArch::PVCuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("pvc", pvcuArch);
+ auto bmguArch =
+ std::make_shared<mlir::xegpu::uArch::Xe2Plus::BMGuArch::BMGuArch>();
+ mlir::xegpu::uArch::uArchMap::instance().insert("bmg", bmguArch);
}
/// Generates instructions to compute offsets for a subgroup identified by
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 9c178d1d85642..63acd30646764 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRTransforms
MLIRGPUDialect
MLIRXeGPUUtils
+ MLIRXeGPUuArch
MLIRGPUUtils
MLIRVectorTransforms
)
diff --git a/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
new file mode 100644
index 0000000000000..c7f691cb6dda7
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRXeGPUuArch
+ IntelGpuXe2.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/uArch
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRDialectUtils
+)
+
diff --git a/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
new file mode 100644
index 0000000000000..d80d1439692e8
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/uArch/IntelGpuXe2.cpp
@@ -0,0 +1,197 @@
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/YAMLTraits.h"
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace mlir::xegpu::uArch;
+using namespace mlir::xegpu::uArch::Xe2Plus;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+
+// Returns every (rows, cols) shape DPAS supports for the given operand:
+//   - MatrixA: M x K
+//   - MatrixB: K x N
+//   - MatrixC / MatrixD: M x N
+// `dataType` is the element type of that operand.
+std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(mlir::Type dataType,
+                                    MMAOpndEnum matrixType) {
+  // Cartesian product of the supported sizes of two dimensions.
+  auto combineVectors = [](const std::vector<uint32_t> &a,
+                           const std::vector<uint32_t> &b)
+      -> std::vector<std::pair<uint32_t, uint32_t>> {
+    std::vector<std::pair<uint32_t, uint32_t>> result;
+    result.reserve(a.size() * b.size());
+    for (uint32_t x : a)
+      for (uint32_t y : b)
+        result.emplace_back(x, y);
+    return result;
+  };
+
+  // Query only the dimensions the requested operand has. In particular K is
+  // not evaluated for the M x N operands: getSupportedK() only understands
+  // int/float element types of specific bit widths, and C/D have no K.
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+    return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndEnum::MatrixB:
+    return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
+  }
+
+  // Every enumerator returns above.
+  llvm_unreachable("unknown MMA operand kind");
+}
+
+// Returns the element types supported for the given DPAS operand. The types
+// are created in `context`, so they compare equal to the caller's types that
+// live in the same context.
+// NOTE(review): checkSupportedTypes() also accepts 2/4/8-bit integer A/B
+// operands, which are not listed here - confirm whether that is intended.
+std::vector<mlir::Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndEnum matrixType) {
+  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
+  mlir::Type f16Type = mlir::Float16Type::get(&context);
+  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
+  mlir::Type f32Type = mlir::Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+  case MMAOpndEnum::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return {bf16Type, f16Type, f32Type};
+  }
+  llvm_unreachable("unknown MMA operand kind");
+}
+
+// Checks whether the element types of the four DPAS operands form a
+// supported combination. A and B are the multiplicands, C is the accumulator
+// (may be a null type, in which case it is not checked) and D is the result.
+// On failure the supported combinations of the matched type family and the
+// offending types are printed to llvm::errs() and false is returned.
+bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                          mlir::Type CType, mlir::Type DType) {
+  // Emits the diagnostic; `supported` is the table of accepted types.
+  auto reject = [&](const char *supported) {
+    llvm::errs()
+        << "Unsupported dpas combinations of Dst, Acc, A and B matrices, "
+        << "Supported types are:\n"
+        << supported << "AType: " << AType << " BType: " << BType
+        << " CType: " << CType << " DType: " << DType << "\n";
+    return false;
+  };
+  // A/B integer element types are 2, 4 or 8 bits wide.
+  auto isNarrowInt = [](mlir::Type type) {
+    return type.isInteger(2) || type.isInteger(4) || type.isInteger(8);
+  };
+
+  // f16 multiplicands: C and D may each be f32 or f16.
+  if (AType.isF16() || BType.isF16()) {
+    if (AType != BType || (CType && !CType.isF32() && !CType.isF16()) ||
+        (!DType.isF32() && !DType.isF16()))
+      return reject(" Dst | Acc | A | B \n"
+                    " f, hf | f, hf | hf | hf \n");
+    return true;
+  }
+
+  // bf16 multiplicands: C and D may each be f32 or bf16.
+  if (AType.isBF16() || BType.isBF16()) {
+    if (AType != BType || (CType && !CType.isF32() && !CType.isBF16()) ||
+        (!DType.isF32() && !DType.isBF16()))
+      return reject(" Dst | Acc | A | B \n"
+                    " f, bf | f, bf | bf | bf \n");
+    return true;
+  }
+
+  // tf32 multiplicands: the accumulator (if any) and the result must be f32.
+  if (AType.isTF32() || BType.isTF32()) {
+    if (AType != BType || (CType && !CType.isF32()) || !DType.isF32())
+      return reject(" Dst | Acc | A | B \n"
+                    " f | f | tf32 | tf32 \n");
+    return true;
+  }
+
+  // Integer multiplicands: both A and B must be narrow integers (they need
+  // not match); C and D are 32-bit integers as listed in the table below.
+  if (!isNarrowInt(AType) || !isNarrowInt(BType) ||
+      (CType && !CType.isInteger(32)) || !DType.isInteger(32))
+    return reject(" Dst | Acc | A | B \n"
+                  " ud, d | ud,d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
+  return true;
+}
+
+bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
+  // Check types first: the shape queries assume supported element types.
+  if (!checkSupportedTypes(AType, BType, CType, DType))
+    return false;
+  mlir::Type accTy = CType ? CType : DType; // null C: query M x N like D
+  using Opnd = MMAOpndEnum;
+  return llvm::is_contained(getSupportedShapes(AType, Opnd::MatrixA), AShape) &&
+         llvm::is_contained(getSupportedShapes(BType, Opnd::MatrixB), BShape) &&
+         llvm::is_contained(getSupportedShapes(accTy, Opnd::MatrixC), CShape) &&
+         llvm::is_contained(getSupportedShapes(DType, Opnd::MatrixD), DShape);
+}
+
+// Validates a DPAS configuration by running the shape-and-type check.
+bool DPASInstruction::validate(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    mlir::Type AType, mlir::Type BType, mlir::Type CType, mlir::Type DType) {
+  const bool supported = checkSupportedShapesAndTypes(
+      AShape, BShape, CShape, DShape, AType, BType, CType, DType);
+  return supported;
+}
+
+std::vector<uint32_t> DPASInstruction::getSupportedM(mlir::Type type) {
+ return {1, 2, 3, 4, 5, 6, 7, 8}; // M: 1 through 8; `type` is not consulted
+}
+
+// Returns the supported K sizes for an operand with element type `type`. K
+// is selected by the element bit width alone. An empty list is returned for
+// a null type, a type that is neither integer nor float, or an unsupported
+// bit width.
+std::vector<uint32_t> DPASInstruction::getSupportedK(mlir::Type type) {
+  // This is a query API and may be handed arbitrary caller-provided types,
+  // so unsupported input is reported through an empty result instead of an
+  // assertion failure or unreachable code.
+  if (!type || !type.isIntOrFloat())
+    return {};
+
+  switch (type.getIntOrFloatBitWidth()) {
+  // NOTE(review): 2-bit shares K = 64 with 4-bit - confirm this is intended.
+  case 2:
+  case 4:
+    return {64};
+  case 8:
+    return {32};
+  case 16:
+    return {16};
+  case 32:
+    return {8};
+  default:
+    return {};
+  }
+}
+
+std::vector<uint32_t> DPASInstruction::getSupportedN(mlir::Type type) {
+ return {16}; // N is always 16; `type` is not consulted
+}
+
+} // namespace Xe2Plus
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
More information about the Mlir-commits
mailing list