[Mlir-commits] [mlir] [mlir][x86] Support for `f8` AMX tiled dot-product. (PR #194786)
Arun Thangamani
llvmlistbot at llvm.org
Tue Apr 28 22:24:30 PDT 2026
https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/194786
This patch enables AMX tiled dot-product support for the `f8E4M3FN` and `f8E5M2` types in MLIR by lowering to the following LLVM intrinsics:
- `llvm.x86.tdpbf8ps`
- `llvm.x86.tdpbhf8ps`
- `llvm.x86.tdphbf8ps`
- `llvm.x86.tdphf8ps`
>From 414d0fb2ef48315204a8edda1d28a886f0b186dd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 28 Apr 2026 22:14:58 -0700
Subject: [PATCH] mlir support for f8 type lowering to llvm instrincs.
---
mlir/include/mlir/Dialect/X86/X86.td | 36 ++++++++---
mlir/lib/Dialect/X86/IR/X86Dialect.cpp | 11 +++-
.../Dialect/X86/AMX/legalize-for-llvm.mlir | 64 +++++++++++++++++++
3 files changed, 100 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 814bf884395bb..96b72c2e8c215 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -694,7 +694,7 @@ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
let mnemonic = "amx." # typeMnemonic;
}
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]> {
let cppFunctionName = "isValidTileTypeElementType";
}
@@ -743,16 +743,20 @@ class AMXTileOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::x86::AMXTileType">;
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]>;
def AMXTileF32 : AMXTileOf<[F32]>;
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+def AMXTileF16OrBF16OrF8 : AMXTileOf<[F16, BF16, F8E4M3FN, F8E5M2]>;
def AMXTileI32 : AMXTileOf<[I32]>;
def AMXTileI8 : AMXTileOf<[I8]>;
+def AMXTileF8E4M3FN : AMXTileOf<[F8E4M3FN]>;
+
+def AMXTileF8E5M2 : AMXTileOf<[F8E5M2]>;
+
//===----------------------------------------------------------------------===//
// AMX Op definitions
//===----------------------------------------------------------------------===//
@@ -961,10 +965,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
- pairs of "bf16").
+ pairs of "bf16") and "f32 <- f8E5M2/f8E4M3FN x f8E5M2/f8E4M3FN".
- The operation is eventually lowered into the "tdpbf16ps" instruction with
- the corresponding tile configuration.
+ The operation is eventually lowered into the "tdpbf16ps/tdpbf8ps/tdpbhf8ps/
+ tdphbf8ps/tdphf8ps" instruction with the corresponding tile configuration.
Example:
@@ -973,8 +977,9 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
: !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
```
}];
- let arguments = (ins AMXTileF16OrBF16:$lhs,
- AMXTileF16OrBF16:$rhs,
+
+ let arguments = (ins AMXTileF16OrBF16OrF8:$lhs,
+ AMXTileF16OrBF16OrF8:$rhs,
AMXTileF32:$acc);
let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
@@ -992,7 +997,20 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
std::string intr = "llvm.x86.tdp";
auto elementType =
getLhsTileType().getElementType();
- intr += elementType.isF16() ? "fp16" : "bf16";
+ auto elementTypeRhs =
+ getRhsTileType().getElementType();
+
+ intr += elementType.isF16() ? "fp16" :
+ elementType.isBF16() ? "bf16" :
+ (elementType.isF8E4M3FN() || elementType.isF8E5M2())
+ ? (
+ (elementType.isF8E4M3FN() ? "b" : "h") +
+ std::string(elementType != elementTypeRhs
+ ? (elementTypeRhs.isF8E4M3FN() ? "b" : "h")
+ : "") +
+ "f8"
+ )
+ : "";
intr += "ps.internal";
return intr;
}
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index b186652aaa866..2ef061a27ddad 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -295,15 +295,22 @@ LogicalResult x86::amx::TileMulFOp::verify() {
x86::amx::TileType aType = getLhsTileType();
x86::amx::TileType bType = getRhsTileType();
x86::amx::TileType cType = getTileType();
+ unsigned scale = 1;
+ if (aType.getElementType().isF8E4M3FN() || aType.getElementType().isF8E5M2())
+ scale = 2;
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
- failed(verifyMultShape(*this, aType, bType, cType, 1)))
+ failed(verifyMultShape(*this, aType, bType, cType, scale)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
- if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
+ bool flag1 = !ta.isBF16() && !ta.isF16() &&
+ !((ta.isF8E4M3FN() || ta.isF8E5M2()) &&
+ (tb.isF8E4M3FN() || tb.isF8E5M2()));
+ bool flag2 = (ta.isBF16() || ta.isF16()) ? (ta != tb) : false;
+ if (flag1 || flag2 || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
diff --git a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
index eb12e20b699b3..3b45135be35a7 100644
--- a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
@@ -60,6 +60,70 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
return
}
+// CHECK-LABEL: mulf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulf8E4M3FNxf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbhf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FNxf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulf8E5M2xf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2xf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
/// Intrinsics require stride in number of bytes.
// CHECK-LABEL: strides_implicit(
// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64
More information about the Mlir-commits
mailing list