[Mlir-commits] [mlir] [mlir][x86] Support for `f8` AMX tiled dot-product. (PR #194786)

Arun Thangamani llvmlistbot at llvm.org
Wed Apr 29 04:07:18 PDT 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/194786

>From 414d0fb2ef48315204a8edda1d28a886f0b186dd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 28 Apr 2026 22:14:58 -0700
Subject: [PATCH 1/3] mlir support for f8 type lowering to llvm intrinsics.

---
 mlir/include/mlir/Dialect/X86/X86.td          | 36 ++++++++---
 mlir/lib/Dialect/X86/IR/X86Dialect.cpp        | 11 +++-
 .../Dialect/X86/AMX/legalize-for-llvm.mlir    | 64 +++++++++++++++++++
 3 files changed, 100 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 814bf884395bb..96b72c2e8c215 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -694,7 +694,7 @@ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = "amx." # typeMnemonic;
 }
 
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]> {
   let cppFunctionName = "isValidTileTypeElementType";
 }
 
@@ -743,16 +743,20 @@ class AMXTileOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
                       "::mlir::x86::AMXTileType">;
 
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]>;
 
 def AMXTileF32 : AMXTileOf<[F32]>;
 
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+def AMXTileF16OrBF16OrF8 : AMXTileOf<[F16, BF16, F8E4M3FN, F8E5M2]>;
 
 def AMXTileI32 : AMXTileOf<[I32]>;
 
 def AMXTileI8 : AMXTileOf<[I8]>;
 
+def AMXTileF8E4M3FN : AMXTileOf<[F8E4M3FN]>;
+
+def AMXTileF8E5M2 : AMXTileOf<[F8E5M2]>;
+
 //===----------------------------------------------------------------------===//
 // AMX Op definitions
 //===----------------------------------------------------------------------===//
@@ -961,10 +965,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
     into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
-    pairs of "bf16").
+    pairs of "bf16") and "f32 <- f8 x f8" (any mix of "f8E5M2" and "f8E4M3FN").
     
-    The operation is eventually lowered into the "tdpbf16ps" instruction with
-    the corresponding tile configuration.
+    The operation is eventually lowered into the "tdpbf16ps/tdpbf8ps/tdpbhf8ps/
+    tdphbf8ps/tdphf8ps" instruction with the corresponding tile configuration.
 
     Example:
 
@@ -973,8 +977,9 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
         : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
     ```
   }];
-  let arguments = (ins AMXTileF16OrBF16:$lhs,
-                       AMXTileF16OrBF16:$rhs,
+
+  let arguments = (ins AMXTileF16OrBF16OrF8:$lhs,
+                       AMXTileF16OrBF16OrF8:$rhs,
                        AMXTileF32:$acc);
   let results = (outs AMXTileF32:$res);
   let extraClassDeclaration = [{
@@ -992,7 +997,20 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
       std::string intr = "llvm.x86.tdp";
       auto elementType =
         getLhsTileType().getElementType();
-      intr += elementType.isF16() ? "fp16" : "bf16";
+      auto elementTypeRhs =
+        getRhsTileType().getElementType();
+
+      intr += elementType.isF16()  ? "fp16" :
+              elementType.isBF16() ? "bf16" :
+              (elementType.isF8E4M3FN() || elementType.isF8E5M2())
+                  ? (
+                        (elementType.isF8E4M3FN() ? "b" : "h") +
+                        std::string(elementType != elementTypeRhs
+                                        ? (elementTypeRhs.isF8E4M3FN() ? "b" : "h")
+                                        : "") +
+                        "f8"
+                    )
+                  : "";
       intr += "ps.internal";
       return intr;
     }
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index b186652aaa866..2ef061a27ddad 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -295,15 +295,22 @@ LogicalResult x86::amx::TileMulFOp::verify() {
   x86::amx::TileType aType = getLhsTileType();
   x86::amx::TileType bType = getRhsTileType();
   x86::amx::TileType cType = getTileType();
+  unsigned scale = 1;
+  if (aType.getElementType().isF8E4M3FN() || aType.getElementType().isF8E5M2())
+    scale = 2;
   if (failed(verifyTileSize(*this, aType)) ||
       failed(verifyTileSize(*this, bType)) ||
       failed(verifyTileSize(*this, cType)) ||
-      failed(verifyMultShape(*this, aType, bType, cType, 1)))
+      failed(verifyMultShape(*this, aType, bType, cType, scale)))
     return failure();
   Type ta = aType.getElementType();
   Type tb = bType.getElementType();
   Type tc = cType.getElementType();
-  if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
+  bool flag1 = !ta.isBF16() && !ta.isF16() &&
+               !((ta.isF8E4M3FN() || ta.isF8E5M2()) &&
+                 (tb.isF8E4M3FN() || tb.isF8E5M2()));
+  bool flag2 = (ta.isBF16() || ta.isF16()) ? (ta != tb) : false;
+  if (flag1 || flag2 || !tc.isF32())
     return emitOpError("unsupported type combination");
   return success();
 }
diff --git a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
index eb12e20b699b3..3b45135be35a7 100644
--- a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
@@ -60,6 +60,70 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   return
 }
 
+// CHECK-LABEL: mulf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulf8E4M3FNxf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbhf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FNxf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: mulf8E5M2xf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2xf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}	
+
+// CHECK-LABEL: mulf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+  %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+  %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+  %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
 /// Intrinsics require stride in number of bytes.
 // CHECK-LABEL: strides_implicit(
 // CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64

>From 0d6a04e60156ad089600c40dc798c6a7769e957b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 29 Apr 2026 03:06:00 -0700
Subject: [PATCH 2/3] code refactoring + few more test checks.

---
 mlir/include/mlir/Dialect/X86/X86.td | 23 ++++++++++---------
 mlir/test/Target/LLVMIR/amx.mlir     | 34 ++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 96b72c2e8c215..a152b5e7ae6db 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -1000,20 +1000,21 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
       auto elementTypeRhs =
         getRhsTileType().getElementType();
 
-      intr += elementType.isF16()  ? "fp16" :
-              elementType.isBF16() ? "bf16" :
-              (elementType.isF8E4M3FN() || elementType.isF8E5M2())
-                  ? (
-                        (elementType.isF8E4M3FN() ? "b" : "h") +
-                        std::string(elementType != elementTypeRhs
-                                        ? (elementTypeRhs.isF8E4M3FN() ? "b" : "h")
-                                        : "") +
-                        "f8"
-                    )
-                  : "";
+      if (elementType.isF16()) {
+        intr += "fp16";
+      } else if (elementType.isBF16()) {
+        intr += "bf16";
+      } else if (elementType.isF8E4M3FN() || elementType.isF8E5M2()) {
+        intr += elementType.isF8E4M3FN() ? "b" : "h";
+        if (elementType != elementTypeRhs)
+          intr += elementTypeRhs.isF8E4M3FN() ? "b" : "h";
+        intr += "f8";
+      }
+
       intr += "ps.internal";
       return intr;
     }
+
     SmallVector<Value> getIntrinsicOperands(
         ::mlir::ArrayRef<Value> operands,
         const ::mlir::LLVMTypeConverter &typeConverter,
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 4a4be24c2e3ab..df8645127ba37 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -72,6 +72,40 @@ func.func @amx_tile_mulf_f16(
   return
 }
 
+// CHECK-LABEL: define void @amx_tile_mulf_f8
+func.func @amx_tile_mulf_f8E4M3FN(
+    %matA: memref<?x?xf8E4M3FN>, %matB: memref<?x?xf8E5M2>, %idx: index,
+    %out: memref<?x?xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+  // CHECK-COUNT-4: call x86_amx @llvm.x86.tileloadd64.internal
+  %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+  %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+  %tA1 = x86.amx.tile_load %matA[%c0, %c0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+  %tB1 = x86.amx.tile_load %matB[%c0, %c0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+  // CHECK: call x86_amx @llvm.x86.tdpbf8ps.internal
+  // CHECK: call x86_amx @llvm.x86.tdpbhf8ps.internal
+  // CHECK: call x86_amx @llvm.x86.tdphbf8ps.internal
+  // CHECK: call x86_amx @llvm.x86.tdphf8ps.internal
+  %tRes = x86.amx.tile_mulf %tA, %tA1, %acc
+    : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+  %tRes1 = x86.amx.tile_mulf %tA, %tB, %acc
+      : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+  %tRes2 = x86.amx.tile_mulf %tB, %tA, %acc
+      : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+  %tRes3 = x86.amx.tile_mulf %tB, %tB1, %acc
+    : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+  // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
+  x86.amx.tile_store %out[%idx, %c0], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %out[%idx, %c16], %tRes1 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %out[%c0, %idx], %tRes2 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  x86.amx.tile_store %out[%c16, %idx], %tRes3 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+  return
+}
+
 // CHECK-LABEL: define void @amx_tile_muli
 func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
     %matC: memref<?x?xi32>, %idx: index, %out: memref<?x?xi8>)

>From 075484898701440dc37ace27ae0ed0e4a9742a3f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 29 Apr 2026 04:06:57 -0700
Subject: [PATCH 3/3] cleanup: correct naming

---
 mlir/test/Target/LLVMIR/amx.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index df8645127ba37..be74dc199af72 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -73,7 +73,7 @@ func.func @amx_tile_mulf_f16(
 }
 
 // CHECK-LABEL: define void @amx_tile_mulf_f8
-func.func @amx_tile_mulf_f8E4M3FN(
+func.func @amx_tile_mulf_f8(
     %matA: memref<?x?xf8E4M3FN>, %matB: memref<?x?xf8E5M2>, %idx: index,
     %out: memref<?x?xf32>)
 {



More information about the Mlir-commits mailing list