[Mlir-commits] [mlir] [mlir][x86] Support for `f8` AMX tiled dot-product. (PR #194786)

Arun Thangamani llvmlistbot at llvm.org
Tue Apr 28 22:24:30 PDT 2026


https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/194786

This patch enables AMX tiled dot-product support for the `f8E4M3FN` and `f8E5M2` types in MLIR by lowering to the following LLVM intrinsics:

- `llvm.x86.tdpbf8ps`
- `llvm.x86.tdpbhf8ps`
- `llvm.x86.tdphbf8ps`
- `llvm.x86.tdphf8ps`

>From 414d0fb2ef48315204a8edda1d28a886f0b186dd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 28 Apr 2026 22:14:58 -0700
Subject: [PATCH] mlir support for f8 type lowering to llvm intrinsics.

---
 mlir/include/mlir/Dialect/X86/X86.td          | 36 ++++++++---
 mlir/lib/Dialect/X86/IR/X86Dialect.cpp        | 11 +++-
 .../Dialect/X86/AMX/legalize-for-llvm.mlir    | 64 +++++++++++++++++++
 3 files changed, 100 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 814bf884395bb..96b72c2e8c215 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -694,7 +694,7 @@ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = "amx." # typeMnemonic;
 }
 
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]> {
   let cppFunctionName = "isValidTileTypeElementType";
 }
 
@@ -743,16 +743,20 @@ class AMXTileOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
                       "::mlir::x86::AMXTileType">;
 
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]>;
 
 def AMXTileF32 : AMXTileOf<[F32]>;
 
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+def AMXTileF16OrBF16OrF8 : AMXTileOf<[F16, BF16, F8E4M3FN, F8E5M2]>;
 
 def AMXTileI32 : AMXTileOf<[I32]>;
 
 def AMXTileI8 : AMXTileOf<[I8]>;
 
+def AMXTileF8E4M3FN : AMXTileOf<[F8E4M3FN]>;
+
+def AMXTileF8E5M2 : AMXTileOf<[F8E5M2]>;
+
 //===----------------------------------------------------------------------===//
 // AMX Op definitions
 //===----------------------------------------------------------------------===//
@@ -961,10 +965,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
     into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
-    pairs of "bf16").
+    pairs of "bf16") and "f32 <- f8E5M2/f8E4M3FN x f8E5M2/f8E4M3FN".
     
-    The operation is eventually lowered into the "tdpbf16ps" instruction with
-    the corresponding tile configuration.
+    The operation is eventually lowered into the "tdpbf16ps/tdpbf8ps/tdpbhf8ps/
+    tdphbf8ps/tdphf8ps" instruction with the corresponding tile configuration.
 
     Example:
 
@@ -973,8 +977,9 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
         : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
     ```
   }];
-  let arguments = (ins AMXTileF16OrBF16:$lhs,
-                       AMXTileF16OrBF16:$rhs,
+
+  let arguments = (ins AMXTileF16OrBF16OrF8:$lhs,
+                       AMXTileF16OrBF16OrF8:$rhs,
                        AMXTileF32:$acc);
   let results = (outs AMXTileF32:$res);
   let extraClassDeclaration = [{
@@ -992,7 +997,20 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
       std::string intr = "llvm.x86.tdp";
       auto elementType =
         getLhsTileType().getElementType();
-      intr += elementType.isF16() ? "fp16" : "bf16";
+      auto elementTypeRhs =
+        getRhsTileType().getElementType();
+
+      intr += elementType.isF16()  ? "fp16" :
+              elementType.isBF16() ? "bf16" :
+              (elementType.isF8E4M3FN() || elementType.isF8E5M2())
+                  ? (
+                        (elementType.isF8E4M3FN() ? "b" : "h") +
+                        std::string(elementType != elementTypeRhs
+                                        ? (elementTypeRhs.isF8E4M3FN() ? "b" : "h")
+                                        : "") +
+                        "f8"
+                    )
+                  : "";
       intr += "ps.internal";
       return intr;
     }
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index b186652aaa866..2ef061a27ddad 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -295,15 +295,22 @@ LogicalResult x86::amx::TileMulFOp::verify() {
   x86::amx::TileType aType = getLhsTileType();
   x86::amx::TileType bType = getRhsTileType();
   x86::amx::TileType cType = getTileType();
+  unsigned scale = 1;
+  if (aType.getElementType().isF8E4M3FN() || aType.getElementType().isF8E5M2())
+    scale = 2;
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
-      failed(verifyMultShape(*this, aType, bType, cType, 1)))
+      failed(verifyMultShape(*this, aType, bType, cType, scale)))
     return failure();
   Type ta = aType.getElementType();
   Type tb = bType.getElementType();
   Type tc = cType.getElementType();
-  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
+  bool flag1 = !ta.isBF16() && !ta.isF16() &&
+               !((ta.isF8E4M3FN() || ta.isF8E5M2()) &&
+                 (tb.isF8E4M3FN() || tb.isF8E5M2()));
+  bool flag2 = (ta.isBF16() || ta.isF16()) ? (ta != tb) : false;
+  if (flag1 || flag2 || !tc.isF32())
     return emitOpError("unsupported type combination");
   return success();
 }
diff --git a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
index eb12e20b699b3..3b45135be35a7 100644
--- a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
@@ -60,6 +60,70 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   return
 }
 
+// CHECK-LABEL: mulf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulf8E4M3FNxf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbhf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FNxf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulf8E5M2xf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2xf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}	
+
+// CHECK-LABEL: mulf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
 /// Intrinsics require stride in number of bytes.
 // CHECK-LABEL: strides_implicit(
 // CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64



More information about the Mlir-commits mailing list