[Mlir-commits] [mlir] [uArch][XeGPU] Add XeGPU uArch definition. (PR #153706)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Aug 21 09:02:19 PDT 2025


================
@@ -0,0 +1,197 @@
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/YAMLTraits.h"
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace mlir::xegpu::uArch;
+using namespace mlir::xegpu::uArch::Xe2Plus;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+
+/// Returns every (rows, cols) shape that DPAS supports for operand
+/// `matrixType` when its elements have type `dataType`.
+///
+/// Each shape list is the Cartesian product of two supported dimension lists
+/// obtained from getSupportedM/K/N(dataType):
+///   - MatrixA: M x K
+///   - MatrixB: K x N
+///   - MatrixC, MatrixD: M x N
+/// An empty list is returned if `matrixType` is not one of these enumerators.
+std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(mlir::Type dataType,
+                                    MMAOpndEnum matrixType) {
+  using ShapeList = std::vector<std::pair<uint32_t, uint32_t>>;
+
+  // Pairs every entry of `rows` with every entry of `cols`, rows-major.
+  auto cartesianProduct = [](const std::vector<uint32_t> &rows,
+                             const std::vector<uint32_t> &cols) {
+    ShapeList shapes;
+    shapes.reserve(rows.size() * cols.size());
+    for (uint32_t r : rows)
+      for (uint32_t c : cols)
+        shapes.emplace_back(r, c);
+    return shapes;
+  };
+
+  // All three dimension lists are queried up front (as before), regardless of
+  // which operand is requested. Binding by const reference avoids a copy if
+  // the getters return by reference, and extends lifetime if they return
+  // temporaries.
+  const auto &M = getSupportedM(dataType);
+  const auto &K = getSupportedK(dataType);
+  const auto &N = getSupportedN(dataType);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+    return cartesianProduct(M, K);
+  case MMAOpndEnum::MatrixB:
+    return cartesianProduct(K, N);
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return cartesianProduct(M, N);
+  }
+  // Only reached for an out-of-range enum value.
+  return {};
+}
+
+/// Returns the element types that DPAS accepts for operand `matrixType`.
+///
+///   - MatrixA, MatrixB: bf16, f16, tf32
+///   - MatrixC, MatrixD: bf16, f16, f32
+///
+/// The types are obtained from `context`. An empty list is returned if
+/// `matrixType` is not one of these enumerators.
+std::vector<mlir::Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndEnum matrixType) {
+  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
+  mlir::Type f16Type = mlir::Float16Type::get(&context);
+  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
+  mlir::Type f32Type = mlir::Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+  case MMAOpndEnum::MatrixB:
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    return {bf16Type, f16Type, f32Type};
+  }
+  // Without this, an out-of-range enum value would flow off the end of a
+  // non-void function (undefined behavior; GCC also warns via -Wreturn-type).
+  return {};
+}
+
+bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+                                          mlir::Type CType, mlir::Type DType) {
+  if (AType.isF16() || BType.isF16()) {
+    if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
+        (!DType.isF32() && !DType.isF16())) {
+      llvm::errs()
----------------
adam-smnk wrote:

I can see this helper being used as part of a pass matcher, and I definitely don't want to get spammed with errors there 😉

Overall, these messages are really verbose, and I'm not sure they are that useful.
Maybe a table of all supported combinations could be part of the function docs (in the source or the header)?

A shorter message could be printed only as debug output instead, e.g. via `LDBG() << "msg"`.

https://github.com/llvm/llvm-project/pull/153706


More information about the Mlir-commits mailing list