[Mlir-commits] [mlir] 83d3a2e - [uArch][XeGPU] Add XeGPU uArch definition. (#153706)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 10 07:21:31 PDT 2025


Author: Md Abdullah Shahneous Bari
Date: 2025-10-10T09:21:27-05:00
New Revision: 83d3a2efe462e16165ef2c9d58e30f8d3ad2f2b0

URL: https://github.com/llvm/llvm-project/commit/83d3a2efe462e16165ef2c9d58e30f8d3ad2f2b0
DIFF: https://github.com/llvm/llvm-project/commit/83d3a2efe462e16165ef2c9d58e30f8d3ad2f2b0.diff

LOG: [uArch][XeGPU] Add XeGPU uArch definition. (#153706)

The uArch infrastructure provides:
- A set of data structures to represent a uArch and its necessary
components (e.g., instructions, register-files, caches).
- A set of utility interfaces that are common to a family of ops (e.g.,
mma ops, 2DBlockIO ops). The implementation of these interfaces is
provided by the specific instructions. Each family of ops provides these
5 common APIs. However, some families of ops may have more utility APIs.
The 5 common APIs are:
	- getSupportedShapes
	- getSupportedTypes
	- checkSupportedShapesAndTypes
	- checkSupportedTypes
	- validate

Add support for PVC and BMG architectures.
Add support for DPAS instruction.

Added: 
    mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
    mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h

Modified: 
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
new file mode 100644
index 0000000000000..0519f7b2e277d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -0,0 +1,297 @@
+//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
+// This file defines the uArch details for Xe2 and its derived architectures.
+// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
+#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
+
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/DebugLog.h"
+#include <map>
+#include <string>
+
+#define DEBUG_TYPE "xegpu-uarch"
+
+using namespace mlir;
+using namespace mlir::xegpu::uArch;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+// Common base for the Xe2-family uArch descriptions in this file: a generic
+// `uArch` extended with XeCore-level information.
+struct Xe2Plus : public uArch {
+  // Per-XeCore description, copied from the constructor argument.
+  XeCoreInfo xeCore;
+  // Forwards the name, description, register-file table, cache list and
+  // instruction table to the `uArch` base and stores `xeCore`. The three
+  // table arguments default to empty containers.
+  Xe2Plus(const std::string &archName, const std::string &archDescription,
+          const XeCoreInfo &xeCore,
+          const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
+          const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+          const std::map<InstructionKind, std::shared_ptr<Instruction>>
+              &instrs = {})
+      : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
+        xeCore(xeCore) {}
+};
+
+// Describes the DPAS instruction: a subgroup-scope `Instruction` that also
+// implements the MMA instruction query interface.
+struct DPASInstruction : public Instruction, public MMAInstructionInterface {
+  DPASInstruction()
+      : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+
+  // Overrides of every pure virtual declared by MMAInstructionInterface.
+  // `override` already implies `virtual`, so the keyword is not repeated.
+  llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
+  llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
+  bool checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                                    std::pair<uint32_t, uint32_t> BShape,
+                                    std::pair<uint32_t, uint32_t> CShape,
+                                    std::pair<uint32_t, uint32_t> DShape,
+                                    Type AType, Type BType, Type CType,
+                                    Type DType) override;
+  bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                           Type DType) override;
+  bool validate(std::pair<uint32_t, uint32_t> AShape,
+                std::pair<uint32_t, uint32_t> BShape,
+                std::pair<uint32_t, uint32_t> CShape,
+                std::pair<uint32_t, uint32_t> DShape, Type AType, Type BType,
+                Type CType, Type DType) override;
+  llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
+  llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
+  llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
+};
+
+// uArch description for the "pvc" (Ponte Vecchio) target. The constructor
+// passes empty tables to the base class and then fills in the register-file
+// info, the cache list and the supported instructions.
+struct PVCuArch : public Xe2Plus {
+  // Maintains ownership of the instructions owned by PVCuArch
+  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+  PVCuArch()
+      : Xe2Plus("pvc",                        // archName
+                "Ponte Vecchio Architecture", // archDescription
+                // XeCoreInfo(num_threads, shared_memory(size, alignment in
+                // bytes), num_vector_units, num_matrix_units)
+                XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
+                {/* registerFileInfo */}, // Optional: empty
+                {/* cacheInfo */},        // Optional: empty
+                {/* instructions */}      // Optional: empty
+        ) {
+    // Initialize register file info
+    // GRF
+    this->registerFileInfo.emplace(
+        RegisterFileType::GRF,
+        RegisterFileInfo(
+            64 * 1024,                                          // size in bits
+            {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
+            {128, 256} // registers per thread per mode
+            ));
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cacheInfo.push_back(
+        CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
+    // L2 cache, XeStack level
+    this->cacheInfo.push_back(
+        CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
+
+    // Add the instructions. The same shared_ptr is stored both in the
+    // instruction table (keyed by kind) and in owned_instructions.
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getInstructionKind(), dpas);
+    owned_instructions.push_back(dpas);
+  }
+};
+
+// uArch description for the "bmg" (Battlemage) target. The constructor
+// passes empty tables to the base class and then fills in the register-file
+// info, the cache list and the supported instructions.
+struct BMGuArch : public Xe2Plus {
+  // Maintains ownership of the instructions owned by BMGuArch
+  llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+  BMGuArch()
+      : Xe2Plus("bmg",                     // archName
+                "Battlemage Architecture", // archDescription
+                // XeCoreInfo(num_threads, shared_memory(size, alignment in
+                // bytes), num_vector_units, num_matrix_units)
+                XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
+                {/* registerFileInfo */}, // Optional: empty
+                {/* cacheInfo */},        // Optional: empty
+                {/* instructions */}      // Optional: empty
+        ) {
+    // Initialize register file info
+    // GRF (operator[] default-constructs the entry, then assigns to it)
+    this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
+        64 * 1024,                                          // size in bits
+        {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
+        {128, 256} // registers per thread per mode
+    );
+    // Initialize cache info
+    // L1 cache, XeCore level
+    this->cacheInfo.push_back(
+        CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
+    // L2 cache, XeStack level
+    this->cacheInfo.push_back(
+        CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
+
+    // Add the instructions. The same shared_ptr is stored both in the
+    // instruction table (keyed by kind) and in owned_instructions.
+    auto dpas = std::make_shared<DPASInstruction>();
+    instructions.emplace(dpas->getInstructionKind(), dpas);
+    owned_instructions.push_back(dpas);
+  }
+};
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+// Returns every (rows, cols) shape DPAS supports for operand `matrixType`
+// holding elements of `dataType`:
+//   MatrixA -> M x K, MatrixB -> K x N, MatrixC / MatrixD -> M x N.
+// Only the dimension lists the operand actually needs are queried, so the
+// element-type assertions in getSupportedK() are not triggered for C/D
+// operands, whose shapes do not involve K.
+inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+  // Cartesian product of the supported row extents and column extents.
+  auto crossProduct = [](const llvm::SmallVector<uint32_t, 8> &rows,
+                         const llvm::SmallVector<uint32_t, 8> &cols) {
+    llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> shapes;
+    shapes.reserve(rows.size() * cols.size());
+    for (uint32_t r : rows)
+      for (uint32_t c : cols)
+        shapes.emplace_back(r, c);
+    return shapes;
+  };
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    return crossProduct(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndKind::MatrixB:
+    return crossProduct(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    return crossProduct(getSupportedM(dataType), getSupportedN(dataType));
+  }
+  // Out-of-range enum value: report no supported shapes.
+  return {};
+}
+
+// Lists the element types reported as supported for operand `matrixType`.
+// All types are created in `context` so they compare equal to other types
+// uniqued in the same context. A and B share one list; C and D share another.
+inline llvm::SmallVector<Type, 8>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndKind matrixType) {
+  MLIRContext *ctx = &context;
+  Type bf16 = BFloat16Type::get(ctx);
+  Type f16 = Float16Type::get(ctx);
+  Type tf32 = FloatTF32Type::get(ctx);
+  Type f32 = Float32Type::get(ctx);
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+  case MMAOpndKind::MatrixB:
+    // Multiplicand operands.
+    return {bf16, f16, tf32};
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    // Accumulator and result operands.
+    return {bf16, f16, f32};
+  }
+  return {};
+}
+
+// Checks whether the element-type combination (A, B, C, D) is accepted.
+// `CType` may be null (no accumulator); it is only inspected when non-null.
+// Rules, selected by the type of A or B:
+//   f16  : A == B, C (if present) and D each f32 or f16
+//   bf16 : A == B, C (if present) and D each f32 or bf16
+//   tf32 : A == B, C (if present) and D each f32
+//   else : A and B must both be 2-, 4- or 8-bit integers
+// Emits a debug log message and returns false on any unsupported combination.
+inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
+                                                 Type CType, Type DType) {
+  auto isSmallInt = [](Type t) {
+    return t.isInteger(2) || t.isInteger(4) || t.isInteger(8);
+  };
+
+  // Each expression keeps the left-to-right short-circuit order
+  // A/B -> C -> D, so later operands are only inspected when earlier checks
+  // pass.
+  bool supported = false;
+  if (AType.isF16() || BType.isF16()) {
+    supported = AType == BType &&
+                (!CType || CType.isF32() || CType.isF16()) &&
+                (DType.isF32() || DType.isF16());
+  } else if (AType.isBF16() || BType.isBF16()) {
+    supported = AType == BType &&
+                (!CType || CType.isF32() || CType.isBF16()) &&
+                (DType.isF32() || DType.isBF16());
+  } else if (AType.isTF32() || BType.isTF32()) {
+    // The accumulator check must look at CType only; previously it also
+    // consulted DType, which let a non-f32 accumulator through.
+    supported = AType == BType && (!CType || CType.isF32()) && DType.isF32();
+  } else {
+    // Integer path: *both* multiplicands must be small integers. Previously
+    // the combination was rejected only when neither of them was.
+    // NOTE(review): C/D types are not validated on this path (unchanged from
+    // the original) — confirm whether an accumulator-type check is wanted.
+    supported = isSmallInt(AType) && isSmallInt(BType);
+  }
+
+  if (!supported) {
+    LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
+  }
+  return supported;
+}
+
+// Returns true when the element types form a supported combination and every
+// operand shape is one of the shapes supported for its element type.
+//
+// Types are checked first: getSupportedShapes() may assert on element types
+// it does not know, so unsupported types must be rejected before any shape
+// query is made.
+inline bool DPASInstruction::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
+  if (!checkSupportedTypes(AType, BType, CType, DType))
+    return false;
+
+  // True when `shape` is among the shapes supported for (`type`, `kind`).
+  auto isSupportedShape = [this](Type type, MMAOpndKind kind,
+                                 std::pair<uint32_t, uint32_t> shape) {
+    return llvm::is_contained(getSupportedShapes(type, kind), shape);
+  };
+
+  if (!isSupportedShape(AType, MMAOpndKind::MatrixA, AShape) ||
+      !isSupportedShape(BType, MMAOpndKind::MatrixB, BShape))
+    return false;
+  // checkSupportedTypes() treats a null CType as "no accumulator"; mirror
+  // that here and skip the C shape check instead of querying a null type.
+  if (CType && !isSupportedShape(CType, MMAOpndKind::MatrixC, CShape))
+    return false;
+  return isSupportedShape(DType, MMAOpndKind::MatrixD, DShape);
+}
+
+// Validates a DPAS operand configuration. This is a thin wrapper: it returns
+// exactly what checkSupportedShapesAndTypes() returns for the same arguments.
+inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
+                                      std::pair<uint32_t, uint32_t> BShape,
+                                      std::pair<uint32_t, uint32_t> CShape,
+                                      std::pair<uint32_t, uint32_t> DShape,
+                                      Type AType, Type BType, Type CType,
+                                      Type DType) {
+  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
+                                      BType, CType, DType);
+}
+
+// Supported M extents: 1 through 8. The `type` argument is not consulted.
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedM(Type type) {
+  return {1, 2, 3, 4, 5, 6, 7, 8};
+}
+
+// Supported K extent for an element type, chosen by its bit width:
+//   2 or 4 bits -> 64, 8 bits -> 32, 16 bits -> 16, 32 bits -> 8.
+// `type` must be an integer or float type; any other bit width is treated as
+// unreachable.
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedK(Type type) {
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  switch (type.getIntOrFloatBitWidth()) {
+  case 2:
+  case 4:
+    return {64};
+  case 8:
+    return {32};
+  case 16:
+    return {16};
+  case 32:
+    return {8};
+  default:
+    llvm_unreachable("Invalid int or float");
+  }
+}
+
+// Supported N extents: only 16. The `type` argument is not consulted.
+inline llvm::SmallVector<uint32_t, 8>
+DPASInstruction::getSupportedN(Type type) {
+  return {16};
+}
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H

diff  --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
new file mode 100644
index 0000000000000..955994ea5ecf5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -0,0 +1,265 @@
+//===- uArchBase.h ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// Base uArch definition for different architectures.
+//
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
+#define MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H
+
+#include <any>
+#include <functional>
+#include <iostream>
+#include <map>
+#include <mutex>
+#include <shared_mutex>
+#include <tuple>
+
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+
+// An enum class to represent the scope of an instruction, i.e. the level at
+// which it operates: a lane, a subgroup, a workgroup or a cluster.
+enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
+// Identifies a specific instruction described by the uArch. Used as the key
+// of the uArch instruction table.
+enum class InstructionKind {
+  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
+        // multiply-add operation
+  // @TODO: Add more instructions as needed
+};
+
+// A struct to represent basic information about an instruction.
+// The primary purpose of the Instruction struct is to provide a generic way to
+// represent information about an instruction and to use this information to
+// generate the uArch. Specific instructions in a uArch can inherit from this
+// struct and add more fields as needed.
+struct Instruction {
+  Instruction(InstructionKind kind, InstructionScope scope)
+      : instKind(kind), scope(scope) {}
+
+  virtual ~Instruction() = default;
+  // Get methods. They are const so they can be called through const
+  // references and pointers-to-const.
+  InstructionKind getInstructionKind() const { return instKind; }
+  InstructionScope getScope() const { return scope; }
+
+  // Returns the lower-case textual name of `instKind` (e.g. "dpas").
+  static llvm::StringRef toString(InstructionKind instKind) {
+    switch (instKind) {
+    case InstructionKind::DPAS:
+      return "dpas";
+    }
+    llvm_unreachable("Unknown InstructionKind");
+  }
+
+  // Parses an instruction name case-insensitively. Returns std::nullopt when
+  // `str` does not name a known instruction.
+  static std::optional<InstructionKind>
+  parseInstructionKind(llvm::StringRef str) {
+    if (str.equals_insensitive("dpas"))
+      return InstructionKind::DPAS;
+    return std::nullopt;
+  }
+
+protected:
+  InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+  InstructionScope scope;   // scope of the instruction (e.g., lane, subgroup,
+                            // workgroup, cluster)
+  // @TODO: Add more fields as needed
+};
+
+// Register-file operating modes ("small" / "large") a register file can
+// offer; stored per register file in RegisterFileInfo.
+enum class RegisterFileMode : uint8_t { Small, Large };
+// Kinds of register file a uArch can describe; used as the key of the uArch
+// register-file table.
+enum class RegisterFileType : uint8_t { GRF, ARF };
+
+// A struct to represent register file information: a size, the modes the
+// register file supports, and the number of registers per thread per mode.
+struct RegisterFileInfo {
+  // Default constructor: size 0 and empty mode / register-count lists. It is
+  // required by containers that default-construct entries (e.g.
+  // std::map::operator[]).
+  RegisterFileInfo() = default;
+  // `mode` and `numRegs` are stored as given.
+  // NOTE(review): presumably numRegs[i] corresponds to mode[i] — confirm.
+  RegisterFileInfo(uint32_t size,
+                   const llvm::SmallVector<RegisterFileMode, 4> &mode,
+                   const llvm::SmallVector<uint32_t, 4> &numRegs)
+      : size(size), mode(mode), numRegsPerThreadPerMode(numRegs) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+
+  const llvm::SmallVector<RegisterFileMode, 4> &getModes() const {
+    return mode;
+  }
+
+  const llvm::SmallVector<uint32_t, 4> &getNumRegsPerThreadPerMode() const {
+    return numRegsPerThreadPerMode;
+  }
+
+protected:
+  // Size in bits. Initialized so a default-constructed object never holds an
+  // indeterminate value.
+  // NOTE(review): originally documented as "size per register in bits", but
+  // the Xe2 definitions pass 64 * 1024 — confirm the intended unit.
+  uint32_t size = 0;
+  llvm::SmallVector<RegisterFileMode, 4>
+      mode; // e.g., "small", "large" GRF modes
+  llvm::SmallVector<uint32_t, 4>
+      numRegsPerThreadPerMode; // number of registers per thread per mode
+};
+
+// Level of a cache in the cache hierarchy; the enumerator value equals the
+// level number.
+enum class CacheHierarchyLevel { L1 = 1, L2 = 2, L3 = 3 };
+
+// A struct to represent cache information: total size, line size and the
+// level of the cache in the hierarchy.
+struct CacheInfo {
+  // Default constructor: an empty (size 0, line size 0) L1 entry. The members
+  // carry in-class initializers so a default-constructed object never holds
+  // indeterminate values.
+  CacheInfo() = default;
+  CacheInfo(uint32_t size, uint32_t line_size,
+            CacheHierarchyLevel hierarchy_level)
+      : size(size), line_size(line_size), hierarchy_level(hierarchy_level) {}
+
+  virtual ~CacheInfo() = default;
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+  uint32_t getLineSize() const { return line_size; }
+  CacheHierarchyLevel getHierarchyLevel() const { return hierarchy_level; }
+
+protected:
+  // NOTE(review): units are not stated here; presumably bytes — confirm.
+  uint32_t size = 0;
+  uint32_t line_size = 0;
+  CacheHierarchyLevel hierarchy_level = CacheHierarchyLevel::L1;
+  // @TODO: Add more fields as needed (e.g., associativity, num_banks,
+  // bank_size, num_ports, port_width, bank_conflicts, hierarchy_level,
+  // latency, throughput, bandwidth)
+};
+
+// Describes the microarchitecture of a target device: a name, a free-form
+// description, register-file information, a list of caches and the table of
+// supported instructions.
+struct uArch {
+  // Copies all arguments into the corresponding members. The register-file
+  // table, cache list and instruction table default to empty.
+  uArch(
+      const std::string &name, const std::string &description,
+      const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
+      const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
+      const std::map<InstructionKind, std::shared_ptr<Instruction>>
+          &instructions = {})
+      : name(name), description(description),
+        registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
+        instructions(instructions) {}
+
+  // Read-only accessors.
+  const std::string &getName() const { return name; }
+
+  const std::string &getDescription() const { return description; }
+
+  const std::map<RegisterFileType, RegisterFileInfo> &
+  getRegisterFileInfo() const {
+    return registerFileInfo;
+  }
+
+  const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
+    return cacheInfo;
+  }
+
+  const std::map<InstructionKind, std::shared_ptr<Instruction>> &
+  getInstructions() const {
+    return instructions;
+  }
+
+  // Returns the textual names of all instructions registered in this uArch,
+  // in the iteration order of the instruction table.
+  llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
+    llvm::SmallVector<StringRef, 8> names;
+    for (const auto &entry : instructions)
+      names.push_back(Instruction::toString(entry.first));
+    return names;
+  }
+
+  // Returns true when `instr` has an entry in the instruction table.
+  bool checkSupportedInstruction(InstructionKind instr) const {
+    return instructions.count(instr) != 0;
+  }
+
+protected:
+  std::string name; // Name of the uArch, similar to target triple
+  std::string description;
+  std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
+  llvm::SmallVector<CacheInfo, 4> cacheInfo;
+  std::map<InstructionKind, std::shared_ptr<Instruction>>
+      instructions; // set of instructions supported by the uArch
+};
+
+// A struct to represent shared memory information: its size and alignment,
+// both in bytes.
+struct SharedMemory {
+  // Constructor; both values are stored as given.
+  SharedMemory(uint32_t size, uint32_t alignment)
+      : size(size), alignment(alignment) {}
+
+  // Get methods
+  uint32_t getSize() const { return size; }
+  uint32_t getAlignment() const { return alignment; }
+
+protected:
+  uint32_t size;      // in bytes
+  uint32_t alignment; // in bytes
+  // @TODO: Add more fields as needed (e.g., latency, throughput, bandwidth)
+};
+
+// Plain aggregate describing one XeCore: thread count, shared memory, and
+// the number of vector and matrix units. All fields are public and are set
+// once by the constructor.
+struct XeCoreInfo {
+  uint32_t num_threads;
+  SharedMemory shared_memory;
+  uint32_t num_vector_units;
+  uint32_t num_matrix_units;
+
+  // Stores each argument in the member of the same name.
+  XeCoreInfo(uint32_t num_threads, const SharedMemory &shared_memory,
+             uint32_t num_vector_units, uint32_t num_matrix_units)
+      : num_threads(num_threads), shared_memory(shared_memory),
+        num_vector_units(num_vector_units), num_matrix_units(num_matrix_units) {
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Interfaces
+//===----------------------------------------------------------------------===//
+// Selects which operand of a matrix multiply-accumulate instruction a query
+// refers to.
+enum class MMAOpndKind { MatrixA, MatrixB, MatrixC, MatrixD };
+// Pure-virtual query interface implemented by matrix multiply-accumulate
+// instructions. Shapes are passed as (uint32_t, uint32_t) pairs and element
+// types as mlir `Type`s.
+struct MMAInstructionInterface {
+  // Get supported Matrix shapes
+  virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) = 0;
+  // @TODO: This method takes a context object as a parameter, this is to
+  // create the Type objects from the same context. Since type objects are
+  // uniqued in a specific context, to do things like "aType == bType" (where
+  // aType and bType are both same type) kind of checks, both types should
+  // be from the same context.
+  //
+  // One alternative to this is to create an enum to represent each type, but
+  // this adds an extra burden on the user to convert these enums to specific
+  // types. In fact the utility that would convert enumToType() and vice versa
+  // would still have to use the context object.
+  //
+  // Until we have a better solution, we stick to passing the context object
+  // to this method.
+  virtual llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) = 0;
+  // Checks the four operand shapes together with the four element types.
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape, Type AType,
+                               Type BType, Type CType, Type DType) = 0;
+  // Checks only the combination of element types.
+  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                                   Type DType) = 0;
+  // Validates a full operand configuration (shapes and types).
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, Type AType,
+                        Type BType, Type CType, Type DType) = 0;
+  // Supported extents of the M, K and N dimensions for an element type.
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
+  virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
+
+  virtual ~MMAInstructionInterface() = default;
+};
+
+} // namespace uArch
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UARCH_UARCHBASE_H

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..9beb22d517473 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"


        


More information about the Mlir-commits mailing list