[Mlir-commits] [mlir] 83d3a2e - [uArch][XeGPU] Add XeGPU uArch definition. (#153706)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 10 07:21:31 PDT 2025
Author: Md Abdullah Shahneous Bari
Date: 2025-10-10T09:21:27-05:00
New Revision: 83d3a2efe462e16165ef2c9d58e30f8d3ad2f2b0
URL: https://github.com/llvm/llvm-project/commit/83d3a2efe462e16165ef2c9d58e30f8d3ad2f2b0
DIFF: https://github.com/llvm/llvm-project/commit/83d3a2efe462e16165ef2c9d58e30f8d3ad2f2b0.diff
LOG: [uArch][XeGPU] Add XeGPU uArch definition. (#153706)
The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary
components (e.g., instructions, register files, caches).
- A set of utility interfaces that are common to a family of ops (e.g.,
mma ops, 2DBlockIO ops). The implementation of these interfaces is
provided by the specific instructions. Each family of ops provides these
5 common APIs. However, some families of ops may have more utility APIs.
The common 5 APIs are:
- getSupportedShapes
- getSupportedTypes
- checkSupportedShapesAndTypes
- checkSupportedTypes
- validate
Add support for PVC and BMG architectures.
Add support for DPAS instruction.
Added:
mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
Modified:
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..0519f7b2e277d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,297 @@
+//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
+// This file defines the uArch details for Xe2 and its derived architectures.
+// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/DebugLog.h"
+#include <map>
+#include <string>
+
+#define DEBUG_TYPE "xegpu-uarch"
+
+using namespace mlir;
+using namespace mlir::xegpu::uArch;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+/// Common base for Xe2-derived architectures (e.g., PVC and BMG). Extends the
+/// generic uArch description with XeCore-level resource information.
+struct Xe2Plus : public uArch {
+  /// XeCore-level resources (threads, shared memory, vector/matrix units).
+  XeCoreInfo xeCore;
+
+  /// \param archName        short architecture name (e.g., "pvc").
+  /// \param archDescription human-readable description of the architecture.
+  /// \param xeCore          XeCore-level resource description.
+  /// \param regInfo         register-file info keyed by register-file type.
+  /// \param cacheInfo       cache descriptions.
+  /// \param instrs          supported instructions keyed by instruction kind.
+  Xe2Plus(const std::string &archName, const std::string &archDescription,
+          const XeCoreInfo &xeCore,
+          const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
+          const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+          const std::map<InstructionKind, std::shared_ptr<Instruction>>
+              &instrs = {})
+      : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
+        xeCore(xeCore) {}
+};
+
+/// DPAS (Dot Product Accumulate Systolic) instruction. Implements the MMA
+/// utility interface; the member functions are defined out of line later in
+/// this header.
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+  DPASInstruction()
+      : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+
+  // Overrides of all pure virtuals from MMAInstructionInterface. `override`
+  // already implies `virtual`, so the redundant keyword is omitted.
+  llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
+  llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
+  bool checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                                    std::pair<uint32_t, uint32_t> BShape,
+                                    std::pair<uint32_t, uint32_t> CShape,
+                                    std::pair<uint32_t, uint32_t> DShape,
+                                    Type AType, Type BType, Type CType,
+                                    Type DType) override;
+  bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                           Type DType) override;
+  bool validate(std::pair<uint32_t, uint32_t> AShape,
+                std::pair<uint32_t, uint32_t> BShape,
+                std::pair<uint32_t, uint32_t> CShape,
+                std::pair<uint32_t, uint32_t> DShape, Type AType, Type BType,
+                Type CType, Type DType) override;
+  llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
+  llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
+  llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
+};
+
+/// Ponte Vecchio (PVC) uArch description.
+struct PVCuArch : public Xe2Plus {
+  // Keeps the instruction objects registered by this uArch alive.
+  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+
+  PVCuArch()
+      : Xe2Plus(/*archName=*/"pvc",
+                /*archDescription=*/"Ponte Vecchio Architecture",
+                /*xeCore=*/XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8),
+                /*regInfo=*/{}, /*cacheInfo=*/{}, /*instrs=*/{}) {
+    // General register file: 64 * 1024 bits, with small and large GRF modes
+    // providing 128 and 256 registers per thread respectively.
+    RegisterFileInfo grf(64 * 1024,
+                         {RegisterFileMode::Small, RegisterFileMode::Large},
+                         {128, 256});
+    registerFileInfo.emplace(RegisterFileType::GRF, grf);
+
+    // L1 cache (XeCore level) followed by L2 cache (XeStack level).
+    cacheInfo.push_back(CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
+    cacheInfo.push_back(CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
+
+    // Register the supported instructions. The instruction map and
+    // owned_instructions share ownership of the same object.
+    auto dpasInst = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpasInst->getInstructionKind(), dpasInst);
+    owned_instructions.push_back(dpasInst);
+  }
+};
+
+/// Battlemage (BMG) uArch description.
+struct BMGuArch : public Xe2Plus {
+  // Keeps the instruction objects registered by this uArch alive.
+  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+
+  BMGuArch()
+      : Xe2Plus(/*archName=*/"bmg",
+                /*archDescription=*/"Battlemage Architecture",
+                /*xeCore=*/XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8),
+                /*regInfo=*/{}, /*cacheInfo=*/{}, /*instrs=*/{}) {
+    // General register file: 64 * 1024 bits, with small and large GRF modes
+    // providing 128 and 256 registers per thread respectively.
+    RegisterFileInfo grf(64 * 1024,
+                         {RegisterFileMode::Small, RegisterFileMode::Large},
+                         {128, 256});
+    registerFileInfo.emplace(RegisterFileType::GRF, grf);
+
+    // L1 cache (XeCore level) followed by L2 cache (XeStack level).
+    cacheInfo.push_back(CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
+    cacheInfo.push_back(
+        CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
+
+    // Register the supported instructions. The instruction map and
+    // owned_instructions share ownership of the same object.
+    auto dpasInst = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpasInst->getInstructionKind(), dpasInst);
+    owned_instructions.push_back(dpasInst);
+  }
+};
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+/// Returns every supported (rows, cols) shape of operand `matrixType` for
+/// element type `dataType`: A is M x K, B is K x N, and C/D are M x N.
+inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+  using ShapeList = llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>;
+  // Cartesian product of the supported row sizes and column sizes.
+  auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &rows,
+                           const llvm::SmallVector<uint32_t, 8> &cols) {
+    ShapeList result;
+    result.reserve(rows.size() * cols.size());
+    for (uint32_t r : rows)
+      for (uint32_t c : cols)
+        result.emplace_back(r, c);
+    return result;
+  };
+
+  // K is queried only for A and B. C/D shapes do not depend on K, and
+  // getSupportedK() needs an int/float element type, whereas the accumulator
+  // (C) type is allowed to be null by checkSupportedTypes().
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndKind::MatrixB:
+    return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
+  }
+  return {};
+}
+
+/// Returns the element types DPAS accepts for operand `matrixType`. The types
+/// are created in `context`, so they compare equal to other types uniqued in
+/// that same context.
+inline llvm::SmallVector<Type, 8>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndKind matrixType) {
+  Type bf16 = BFloat16Type::get(&context);
+  Type f16 = Float16Type::get(&context);
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+  case MMAOpndKind::MatrixB:
+    // Input operands.
+    return {bf16, f16, FloatTF32Type::get(&context)};
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    // Accumulator and result operands.
+    return {bf16, f16, Float32Type::get(&context)};
+  }
+  return {};
+}
+
+/// Checks whether the element types of A, B, C (accumulator; may be null) and
+/// D (result) form a supported DPAS combination:
+///  - f16 inputs:  A == B; C (if present) and D are each f32 or f16.
+///  - bf16 inputs: A == B; C (if present) and D are each f32 or bf16.
+///  - tf32 inputs: A == B; C (if present) and D are f32.
+///  - otherwise A and B must each be i2, i4 or i8.
+inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
+                                                 Type CType, Type DType) {
+  if (AType.isF16() || BType.isF16()) {
+    if (AType != BType || (CType && !CType.isF32() && !CType.isF16()) ||
+        (!DType.isF32() && !DType.isF16())) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  } else if (AType.isBF16() || BType.isBF16()) {
+    if (AType != BType || (CType && !CType.isF32() && !CType.isBF16()) ||
+        (!DType.isF32() && !DType.isBF16())) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  } else if (AType.isTF32() || BType.isTF32()) {
+    // Both the accumulator (if present) and the result must be f32. (The
+    // accumulator check previously tested DType by mistake, so C was never
+    // validated.)
+    if (AType != BType || (CType && !CType.isF32()) || !DType.isF32()) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  } else {
+    // Integer DPAS: *each* input must be i2, i4 or i8 (previously the check
+    // only failed when both were unsupported). A and B are not required to
+    // have the same width.
+    // TODO(review): C/D element types are not validated for integer inputs.
+    auto isDpasInt = [](Type t) {
+      return t.isInteger(2) || t.isInteger(4) || t.isInteger(8);
+    };
+    if (!isDpasInt(AType) || !isDpasInt(BType)) {
+      LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Returns true iff the element types form a supported combination and every
+/// operand shape is among the shapes supported for its type.
+///
+/// The type check runs first, so most unsupported element types are rejected
+/// before the shape queries, which may assert or hit llvm_unreachable on them.
+/// A null C type means "no accumulator", as in checkSupportedTypes(), and its
+/// shape is then not checked.
+inline bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
+  if (!checkSupportedTypes(AType, BType, CType, DType))
+    return false;
+
+  auto isSupported = [&](Type type, MMAOpndKind kind,
+                         std::pair<uint32_t, uint32_t> shape) {
+    return llvm::is_contained(getSupportedShapes(type, kind), shape);
+  };
+  return isSupported(AType, MMAOpndKind::MatrixA, AShape) &&
+         isSupported(BType, MMAOpndKind::MatrixB, BShape) &&
+         (!CType || isSupported(CType, MMAOpndKind::MatrixC, CShape)) &&
+         isSupported(DType, MMAOpndKind::MatrixD, DShape);
+}
+
+/// Full DPAS validation. It currently consists solely of the combined shape
+/// and type check.
+inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
+                                      std::pair<uint32_t, uint32_t> BShape,
+                                      std::pair<uint32_t, uint32_t> CShape,
+                                      std::pair<uint32_t, uint32_t> DShape,
+                                      Type AType, Type BType, Type CType,
+                                      Type DType) {
+  const bool isValid = checkSupportedShapesAndTypes(
+      AShape, BShape, CShape, DShape, AType, BType, CType, DType);
+  return isValid;
+}
+
+/// Supported M values (rows of A, C and D): 1 through 8 inclusive. The
+/// element type does not affect the result.
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedM(Type /*type*/) {
+  llvm::SmallVector<uint32_t, 8> rows;
+  for (uint32_t m = 1; m <= 8; ++m)
+    rows.push_back(m);
+  return rows;
+}
+
+/// Supported K values (columns of A / rows of B) for element `type`.
+/// K * bitwidth is 256 for 8-, 16- and 32-bit elements, and K is 64 for
+/// 2- and 4-bit elements.
+///
+/// Unsupported element types (null, non-int/float, or any other bit width)
+/// yield an empty list instead of asserting or reaching llvm_unreachable
+/// (undefined behavior in release builds), so shape checks simply fail.
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedK(Type type) {
+  if (!type || !type.isIntOrFloat())
+    return {};
+  switch (type.getIntOrFloatBitWidth()) {
+  case 2:
+  case 4:
+    return {64};
+  case 8:
+    return {32};
+  case 16:
+    return {16};
+  case 32:
+    return {8};
+  default:
+    return {};
+  }
+}
+
+/// Supported N values (columns of B, C and D). Only 16 is supported, for any
+/// element type.
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedN(Type /*type*/) {
+  constexpr uint32_t kSupportedN = 16;
+  return {kSupportedN};
+}
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..955994ea5ecf5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,265 @@
+//===- uArchBase.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// Base uArch definition for different architectures.
+//
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
+
+#include <any>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <mutex>
+#include <shared_mutex>
+#include <tuple>
+
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+/// Execution scope of an instruction.
+enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
+
+/// Kinds of instructions a uArch can describe.
+enum class InstructionKind {
+  /// Dot Product Accumulate Systolic (DPAS): a matrix multiply-add operation.
+  DPAS,
+  // TODO: Add more instructions as needed.
+};
+
+/// Basic, generic information about an instruction, used when describing the
+/// instructions of a uArch. A specific instruction in a uArch can derive from
+/// this struct and add more fields as needed.
+struct Instruction {
+  Instruction(InstructionKind kind, InstructionScope scope)
+      : instKind(kind), scope(scope) {}
+
+  virtual ~Instruction() = default;
+
+  // Get methods. They do not modify the object, so they are `const` and can
+  // also be called through const references and pointers.
+  InstructionKind getInstructionKind() const { return instKind; }
+  InstructionScope getScope() const { return scope; }
+
+  /// Returns the lowercase mnemonic of `instKind` (e.g., "dpas").
+  static llvm::StringRef toString(InstructionKind instKind) {
+    switch (instKind) {
+    case InstructionKind::DPAS:
+      return "dpas";
+    }
+    llvm_unreachable("Unknown InstructionKind");
+  }
+
+  /// Parses an instruction mnemonic case-insensitively. Returns std::nullopt
+  /// if `str` does not name a known instruction.
+  static std::optional<InstructionKind>
+  parseInstructionKind(llvm::StringRef str) {
+    if (str.equals_insensitive("dpas"))
+      return InstructionKind::DPAS;
+    return std::nullopt;
+  }
+
+protected:
+  InstructionKind instKind; // Specific InstructionKind (e.g., DPAS).
+  InstructionScope scope;   // Scope of the instruction (e.g., lane, subgroup,
+                            // workgroup, cluster).
+  // TODO: Add more fields as needed.
+};
+
+/// Register-file operating mode (e.g., small vs. large GRF mode).
+enum class RegisterFileMode : uint8_t {
+  Small,
+  Large,
+};
+
+/// Kind of register file.
+enum class RegisterFileType : uint8_t {
+  GRF,
+  ARF,
+};
+
+/// Describes a register file: its size, its supported modes, and the number
+/// of registers per thread per mode. The modes list and the register-count
+/// list are parallel (presumably index-aligned: the i-th count belongs to the
+/// i-th mode; confirm with users of this struct).
+struct RegisterFileInfo {
+  RegisterFileInfo() = default;
+  RegisterFileInfo(uint32_t size,
+                   const llvm::SmallVector<RegisterFileMode, 4> &mode,
+                   const llvm::SmallVector<uint32_t, 4> &numRegs)
+      : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+
+  const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
+    return mode;
+  }
+
+  const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
+    return numRegsPerThreadPerMode;
+  }
+
+protected:
+  // Size in bits. Initialized to 0 so that a default-constructed object
+  // (e.g., one created by std::map::operator[]) has no indeterminate value.
+  // NOTE(review): originally documented as "per register", but the PVC/BMG
+  // definitions pass 64 * 1024 with the comment "size in bits"; confirm
+  // whether this is per register or for the whole register file.
+  uint32_t size = 0;
+  // Register-file modes, e.g., small and large GRF modes.
+  llvm::SmallVector<RegisterFileMode, 4> mode;
+  // Number of registers per thread in each mode.
+  llvm::SmallVector<uint32_t, 4> numRegsPerThreadPerMode;
+};
+
+/// Cache hierarchy level; each enumerator's value equals its level number.
+enum class CacheHierarchyLevel {
+  L1 = 1,
+  L2 = 2,
+  L3 = 3,
+};
+
+/// Describes one cache: its size, line size and hierarchy level.
+struct CacheInfo {
+  CacheInfo() = default;
+  CacheInfo(uint32_t size, uint32_t line_size,
+            CacheHierarchyLevel hierarchy_level)
+      : size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}
+
+  virtual ~CacheInfo() = default;
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+  uint32_t getLineSize() const { return line_size; }
+  CacheHierarchyLevel getHierarchyLevel() const { return hierarchy_level; }
+
+protected:
+  // Members are initialized so a default-constructed CacheInfo has no
+  // indeterminate values. The value-initialized hierarchy level (numeric 0)
+  // does not name any level, which marks it as "unset".
+  uint32_t size = 0;
+  uint32_t line_size = 0;
+  CacheHierarchyLevel hierarchy_level{};
+  // TODO: Add more fields as needed (e.g., associativity, num_banks,
+  // bank_size, num_ports, port_width, bank_conflicts, latency, throughput,
+  // bandwidth).
+};
+
+/// Microarchitecture (uArch) description of a target device: its name,
+/// description, register files, caches and the set of supported instructions.
+struct uArch {
+  uArch(
+      const std::string &name, const std::string &description,
+      const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
+      const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+      const std::map<InstructionKind, std::shared_ptr<Instruction>>
+          &instructions = {})
+      : name(name), description(description),
+        registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
+        instructions(instructions) {}
+
+  // Get methods
+  const std::string &getName() const { return name; }
+
+  const std::string &getDescription() const { return description; }
+
+  const std::map<RegisterFileType, RegisterFileInfo> &
+  getRegisterFileInfo() const {
+    return registerFileInfo;
+  }
+
+  const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
+    return cacheInfo;
+  }
+
+  const std::map<InstructionKind, std::shared_ptr<Instruction>> &
+  getInstructions() const {
+    return instructions;
+  }
+
+  /// Returns the mnemonics of all instructions registered in this uArch, in
+  /// InstructionKind order (the iteration order of the underlying std::map).
+  llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
+    llvm::SmallVector<StringRef, 8> instructionNames;
+    instructionNames.reserve(instructions.size());
+    for (const auto &entry : instructions)
+      instructionNames.push_back(Instruction::toString(entry.first));
+    return instructionNames;
+  }
+
+  /// Returns true if an instruction of kind `instr` is registered.
+  bool checkSupportedInstruction(InstructionKind instr) const {
+    return instructions.count(instr) != 0;
+  }
+
+protected:
+  std::string name; // Name of the uArch, similar to a target triple.
+  std::string description;
+  std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
+  llvm::SmallVector<CacheInfo, 4> cacheInfo;
+  // Set of instructions supported by the uArch.
+  std::map<InstructionKind, std::shared_ptr<Instruction>> instructions;
+};
+
+/// Shared-memory description: total size and alignment, both in bytes.
+struct SharedMemory {
+  SharedMemory(uint32_t size, uint32_t alignment)
+      : size(size), alignment(alignment) {}
+
+  /// Size in bytes.
+  uint32_t getSize() const { return size; }
+  /// Alignment in bytes.
+  uint32_t getAlignment() const { return alignment; }
+
+protected:
+  uint32_t size;      // Bytes.
+  uint32_t alignment; // Bytes.
+  // TODO: Add more fields as needed (e.g., latency, throughput, bandwidth).
+};
+
+/// Resources of a single XeCore: thread count, shared memory, and the number
+/// of vector and matrix units.
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+             uint32_t num_vector_units, uint32_t num_matrix_units)
+      : num_threads(num_threads), shared_memory(shared_memory),
+        num_vector_units(num_vector_units),
+        num_matrix_units(num_matrix_units) {}
+};
+
+//===----------------------------------------------------------------------===//
+// Interfaces
+//===----------------------------------------------------------------------===//
+/// Operand roles of an MMA operation D = A * B + C.
+enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
+
+/// Utility interface shared by the matrix multiply-add (MMA) family of
+/// instructions. Shapes are (rows, cols) pairs.
+struct MMAInstructionInterface {
+  /// Returns the supported shapes of operand `matrixType` for `dataType`.
+  virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
+
+  /// Returns the supported element types of operand `matrixType`.
+  ///
+  /// TODO: This method takes a context object so that the returned types are
+  /// created in the caller's context. Types are uniqued per context, so
+  /// checks such as "aType == bType" only work when both types come from the
+  /// same context. An alternative is an enum representing each type, but
+  /// that burdens users with enum <-> type conversion, and such a conversion
+  /// utility would still need the context object. Until there is a better
+  /// solution, the context is passed in.
+  virtual llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
+
+  /// Checks both the operand shapes and the element types.
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape, Type AType,
+                               Type BType, Type CType, Type DType) = 0;
+
+  /// Checks only the element-type combination.
+  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                                   Type DType) = 0;
+
+  /// Full validation of an MMA configuration.
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, Type AType,
+                        Type BType, Type CType, Type DType) = 0;
+
+  /// Supported M / K / N dimension sizes for element `type`.
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
+
+  virtual ~MMAInstructionInterface() = default;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..9beb22d517473 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
More information about the Mlir-commits
mailing list