[Mlir-commits] [mlir] [uArch][XeGPU] Add XeGPU uArch definition. (PR #153706)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Aug 21 09:02:19 PDT 2025
================
@@ -0,0 +1,197 @@
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/YAMLTraits.h"
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace mlir::xegpu::uArch;
+using namespace mlir::xegpu::uArch::Xe2Plus;
+
+namespace mlir {
+namespace xegpu {
+namespace uArch {
+namespace Xe2Plus {
+
+/// Enumerates the 2-D shapes available for one DPAS operand.
+///
+/// The per-dimension candidates come from getSupportedM/K/N(dataType); the
+/// result is every pairing of the two dimensions that belong to the requested
+/// operand:
+///   - MatrixA: (M, K)
+///   - MatrixB: (K, N)
+///   - MatrixC and MatrixD: (M, N)
+///
+/// \param dataType   element type forwarded to getSupportedM/K/N.
+/// \param matrixType which operand the shapes are requested for.
+/// \returns all (first-dim, second-dim) pairs, ordered with the first
+///          dimension varying slowest; empty if `matrixType` matches none of
+///          the handled enumerators.
+std::vector<std::pair<uint32_t, uint32_t>>
+DPASInstruction::getSupportedShapes(mlir::Type dataType,
+                                    MMAOpndEnum matrixType) {
+  using ShapeList = std::vector<std::pair<uint32_t, uint32_t>>;
+
+  // Cartesian product of two size lists; `rows` is the outer loop.
+  auto crossProduct = [](const std::vector<uint32_t> &rows,
+                         const std::vector<uint32_t> &cols) -> ShapeList {
+    ShapeList shapes;
+    for (unsigned r : rows)
+      for (unsigned c : cols)
+        shapes.emplace_back(r, c);
+    return shapes;
+  };
+
+  auto mSizes = getSupportedM(dataType);
+  auto kSizes = getSupportedK(dataType);
+  auto nSizes = getSupportedN(dataType);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+    return crossProduct(mSizes, kSizes);
+  case MMAOpndEnum::MatrixB:
+    return crossProduct(kSizes, nSizes);
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    // Accumulator and result share the same M x N footprint.
+    return crossProduct(mSizes, nSizes);
+  }
+  return {};
+}
+
+/// Lists the element types accepted for a given DPAS operand.
+///
+///   - MatrixA / MatrixB: bf16, f16, tf32
+///   - MatrixC / MatrixD: bf16, f16, f32
+///
+/// \param context    MLIR context used to obtain the builtin float types.
+/// \param matrixType which operand the types are requested for.
+/// \returns the supported element types in the order listed above, or an
+///          empty vector if `matrixType` is not one of the handled
+///          enumerators.
+std::vector<mlir::Type>
+DPASInstruction::getSupportedTypes(MLIRContext &context,
+                                   MMAOpndEnum matrixType) {
+  mlir::Type bf16Type = mlir::BFloat16Type::get(&context);
+  mlir::Type f16Type = mlir::Float16Type::get(&context);
+  mlir::Type tf32Type = mlir::FloatTF32Type::get(&context);
+  mlir::Type f32Type = mlir::Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndEnum::MatrixA:
+  case MMAOpndEnum::MatrixB:
+    // Multiplicand operands.
+    return {bf16Type, f16Type, tf32Type};
+  case MMAOpndEnum::MatrixC:
+  case MMAOpndEnum::MatrixD:
+    // Accumulator and result operands.
+    return {bf16Type, f16Type, f32Type};
+  }
+
+  // Without this, an out-of-range enum value would flow off the end of a
+  // non-void function (undefined behaviour, and a -Wreturn-type warning on
+  // GCC). Report "nothing supported" instead.
+  return {};
+}
+
+bool DPASInstruction::checkSupportedTypes(mlir::Type AType, mlir::Type BType,
+ mlir::Type CType, mlir::Type DType) {
+ if (AType.isF16() || BType.isF16()) {
+ if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
+ (!DType.isF32() && !DType.isF16())) {
+ llvm::errs()
----------------
adam-smnk wrote:
I see this helper being used as part of a pass matcher. I definitely don't want to get spammed with errors 😉
Overall, these messages are really verbose and I'm not sure they are that useful.
Maybe a table of all supported combinations could be part of the function docs (source or header)?
A shorter error could be hidden under debug `LDBG() << "msg"`
https://github.com/llvm/llvm-project/pull/153706
More information about the Mlir-commits
mailing list