[Mlir-commits] [mlir] [mlir][x86] Support for `f8` AMX tiled dot-product. (PR #194786)
Arun Thangamani
llvmlistbot at llvm.org
Wed Apr 29 04:07:18 PDT 2026
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/194786
>From 414d0fb2ef48315204a8edda1d28a886f0b186dd Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Tue, 28 Apr 2026 22:14:58 -0700
Subject: [PATCH 1/3] MLIR support for lowering f8 types to LLVM intrinsics.
---
mlir/include/mlir/Dialect/X86/X86.td | 36 ++++++++---
mlir/lib/Dialect/X86/IR/X86Dialect.cpp | 11 +++-
.../Dialect/X86/AMX/legalize-for-llvm.mlir | 64 +++++++++++++++++++
3 files changed, 100 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 814bf884395bb..96b72c2e8c215 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -694,7 +694,7 @@ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
let mnemonic = "amx." # typeMnemonic;
}
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]> {
let cppFunctionName = "isValidTileTypeElementType";
}
@@ -743,16 +743,20 @@ class AMXTileOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::x86::AMXTileType">;
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
+def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8, F8E4M3FN, F8E5M2]>;
def AMXTileF32 : AMXTileOf<[F32]>;
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
+def AMXTileF16OrBF16OrF8 : AMXTileOf<[F16, BF16, F8E4M3FN, F8E5M2]>;
def AMXTileI32 : AMXTileOf<[I32]>;
def AMXTileI8 : AMXTileOf<[I8]>;
+def AMXTileF8E4M3FN : AMXTileOf<[F8E4M3FN]>;
+
+def AMXTileF8E5M2 : AMXTileOf<[F8E5M2]>;
+
//===----------------------------------------------------------------------===//
// AMX Op definitions
//===----------------------------------------------------------------------===//
@@ -961,10 +965,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
- pairs of "bf16").
+ pairs of "bf16"), "f32 <- f16 x f16" and "f32 <- f8E5M2/f8E4M3FN x f8E5M2/f8E4M3FN".
- The operation is eventually lowered into the "tdpbf16ps" instruction with
- the corresponding tile configuration.
+ The operation is eventually lowered into the "tdpbf16ps/tdpfp16ps/tdpbf8ps/
+ tdpbhf8ps/tdphbf8ps/tdphf8ps" instruction with the corresponding tile configuration.
Example:
@@ -973,8 +977,9 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
: !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32>
```
}];
- let arguments = (ins AMXTileF16OrBF16:$lhs,
- AMXTileF16OrBF16:$rhs,
+
+ let arguments = (ins AMXTileF16OrBF16OrF8:$lhs,
+ AMXTileF16OrBF16OrF8:$rhs,
AMXTileF32:$acc);
let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
@@ -992,7 +997,20 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
std::string intr = "llvm.x86.tdp";
auto elementType =
getLhsTileType().getElementType();
- intr += elementType.isF16() ? "fp16" : "bf16";
+ auto elementTypeRhs =
+ getRhsTileType().getElementType();
+
+ intr += elementType.isF16() ? "fp16" :
+ elementType.isBF16() ? "bf16" :
+ (elementType.isF8E4M3FN() || elementType.isF8E5M2())
+ ? (
+ (elementType.isF8E4M3FN() ? "b" : "h") +
+ std::string(elementType != elementTypeRhs
+ ? (elementTypeRhs.isF8E4M3FN() ? "b" : "h")
+ : "") +
+ "f8"
+ )
+ : "";
intr += "ps.internal";
return intr;
}
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index b186652aaa866..2ef061a27ddad 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -295,15 +295,22 @@ LogicalResult x86::amx::TileMulFOp::verify() {
x86::amx::TileType aType = getLhsTileType();
x86::amx::TileType bType = getRhsTileType();
x86::amx::TileType cType = getTileType();
+ unsigned scale = 1;
+ if (aType.getElementType().isF8E4M3FN() || aType.getElementType().isF8E5M2())
+ scale = 2;
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
- failed(verifyMultShape(*this, aType, bType, cType, 1)))
+ failed(verifyMultShape(*this, aType, bType, cType, scale)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
- if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
+ bool flag1 = !ta.isBF16() && !ta.isF16() &&
+ !((ta.isF8E4M3FN() || ta.isF8E5M2()) &&
+ (tb.isF8E4M3FN() || tb.isF8E5M2()));
+ bool flag2 = (ta.isBF16() || ta.isF16()) ? (ta != tb) : false;
+ if (flag1 || flag2 || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
diff --git a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
index eb12e20b699b3..3b45135be35a7 100644
--- a/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir
@@ -60,6 +60,70 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
return
}
+// CHECK-LABEL: mulf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulf8E4M3FNxf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdpbhf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E4M3FNxf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E4M3FN>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulf8E5M2xf8E4M3FN(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphbf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2xf8E4M3FN(%arg0: memref<?x?xf8E4M3FN>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
+// CHECK-LABEL: mulf8E5M2(
+// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tdphf8ps.internal"
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"
+func.func @mulf8E5M2(%arg0: memref<?x?xf8E5M2>, %arg1: memref<?x?xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xf8E5M2>
+ %2 = x86.amx.tile_load %arg0[%0, %0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+ %3 = x86.amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
+ %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
/// Intrinsics require stride in number of bytes.
// CHECK-LABEL: strides_implicit(
// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64
>From 0d6a04e60156ad089600c40dc798c6a7769e957b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 29 Apr 2026 03:06:00 -0700
Subject: [PATCH 2/3] Code refactoring and a few more test checks.
---
mlir/include/mlir/Dialect/X86/X86.td | 23 ++++++++++---------
mlir/test/Target/LLVMIR/amx.mlir | 34 ++++++++++++++++++++++++++++
2 files changed, 46 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 96b72c2e8c215..a152b5e7ae6db 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -1000,20 +1000,21 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
auto elementTypeRhs =
getRhsTileType().getElementType();
- intr += elementType.isF16() ? "fp16" :
- elementType.isBF16() ? "bf16" :
- (elementType.isF8E4M3FN() || elementType.isF8E5M2())
- ? (
- (elementType.isF8E4M3FN() ? "b" : "h") +
- std::string(elementType != elementTypeRhs
- ? (elementTypeRhs.isF8E4M3FN() ? "b" : "h")
- : "") +
- "f8"
- )
- : "";
+ // Build the intrinsic suffix from the operand element types:
+ //   f16 -> "fp16", bf16 -> "bf16";
+ //   f8 types -> one letter for the lhs ("b" for f8E4M3FN, "h" for
+ //   f8E5M2), a second letter for the rhs only when the rhs element type
+ //   differs from the lhs, then "f8".
+ // NOTE(review): Intel's AMX-FP8 / AVX10.2 documentation appears to name
+ // E5M2 "BF8" and E4M3 "HF8", i.e. the opposite of the letters chosen
+ // here. TODO confirm against the ISA reference; if the mapping is
+ // swapped, the tdp*f8ps CHECK lines in legalize-for-llvm.mlir and
+ // amx.mlir must be updated together with this code.
+ if (elementType.isF16()) {
+ intr += "fp16";
+ } else if (elementType.isBF16()) {
+ intr += "bf16";
+ } else if (elementType.isF8E4M3FN() || elementType.isF8E5M2()) {
+ // Lhs letter.
+ intr += elementType.isF8E4M3FN() ? "b" : "h";
+ // Rhs letter, omitted when both operands share the same f8 type.
+ if (elementType != elementTypeRhs)
+ intr += elementTypeRhs.isF8E4M3FN() ? "b" : "h";
+ intr += "f8";
+ }
+
intr += "ps.internal";
return intr;
}
+
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 4a4be24c2e3ab..df8645127ba37 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -72,6 +72,40 @@ func.func @amx_tile_mulf_f16(
return
}
+// CHECK-LABEL: define void @amx_tile_mulf_f8
+func.func @amx_tile_mulf_f8E4M3FN(
+ %matA: memref<?x?xf8E4M3FN>, %matB: memref<?x?xf8E5M2>, %idx: index,
+ %out: memref<?x?xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+ %acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32>
+ // CHECK-COUNT-4: call x86_amx @llvm.x86.tileloadd64.internal
+ %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+ %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+ %tA1 = x86.amx.tile_load %matA[%c0, %c0] : memref<?x?xf8E4M3FN> into !x86.amx.tile<16x64xf8E4M3FN>
+ %tB1 = x86.amx.tile_load %matB[%c0, %c0] : memref<?x?xf8E5M2> into !x86.amx.tile<16x64xf8E5M2>
+ // CHECK: call x86_amx @llvm.x86.tdpbf8ps.internal
+ // CHECK: call x86_amx @llvm.x86.tdpbhf8ps.internal
+ // CHECK: call x86_amx @llvm.x86.tdphbf8ps.internal
+ // CHECK: call x86_amx @llvm.x86.tdphf8ps.internal
+ %tRes = x86.amx.tile_mulf %tA, %tA1, %acc
+ : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+ %tRes1 = x86.amx.tile_mulf %tA, %tB, %acc
+ : !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+ %tRes2 = x86.amx.tile_mulf %tB, %tA, %acc
+ : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E4M3FN>, !x86.amx.tile<16x16xf32>
+ %tRes3 = x86.amx.tile_mulf %tB, %tB1, %acc
+ : !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x64xf8E5M2>, !x86.amx.tile<16x16xf32>
+ // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal
+ x86.amx.tile_store %out[%idx, %c0], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %out[%idx, %c16], %tRes1 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %out[%c0, %idx], %tRes2 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ x86.amx.tile_store %out[%c16, %idx], %tRes3 : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
+ return
+}
+
// CHECK-LABEL: define void @amx_tile_muli
func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
%matC: memref<?x?xi32>, %idx: index, %out: memref<?x?xi8>)
>From 075484898701440dc37ace27ae0ed0e4a9742a3f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Wed, 29 Apr 2026 04:06:57 -0700
Subject: [PATCH 3/3] cleanup: correct naming
---
mlir/test/Target/LLVMIR/amx.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index df8645127ba37..be74dc199af72 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -73,7 +73,7 @@ func.func @amx_tile_mulf_f16(
}
// CHECK-LABEL: define void @amx_tile_mulf_f8
-func.func @amx_tile_mulf_f8E4M3FN(
+func.func @amx_tile_mulf_f8(
%matA: memref<?x?xf8E4M3FN>, %matB: memref<?x?xf8E5M2>, %idx: index,
%out: memref<?x?xf32>)
{
More information about the Mlir-commits
mailing list