[Mlir-commits] [mlir] [MLIR][Arith] Fix arith-emulate-unsupported-floats for extf on emulated float types (PR #189243)
Mehdi Amini
llvmlistbot at llvm.org
Sun Mar 29 06:20:28 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189243
The ArithEmulateUnsupportedFloats pass emulates arithmetic ops on unsupported float types (e.g. fp8 variants) by widening them to a supported target type. Previously, when an emulated op's result was immediately used by an arith.extf back to the target type, the pass emitted a truncf/extf round-trip (f32->fp8->f32) that left an arith.extf on the unsupported type in the output. convert-arith-to-llvm cannot lower that arith.extf for types that have no LLVM floating-point equivalent (e.g. f8E4M3FNUZ): the type converter maps such fp8 types to i8, so convert-arith-to-llvm generates invalid llvm.fpext i8->f32 operations.
Instead, when all uses of an emulated op's result are arith.extf ops extending from the unsupported type to the target type, the pass now skips the intermediate truncf round-trip and replaces the extf users directly with the wider emulated value. Note that this also skips rounding the result to the unsupported type, so those extf users now observe the unrounded wider value. This removes the result-side arith.extf on the unsupported type from the output, making the pipeline compatible with convert-arith-to-llvm.
Fixes #152287
Assisted-by: Claude Code
>From d93deb94c962e0e296aa75a760f2bf967c4c7b78 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 18:03:36 -0700
Subject: [PATCH] [MLIR][Arith] Fix arith-emulate-unsupported-floats for extf
on emulated float types
The ArithEmulateUnsupportedFloats pass emulates arithmetic ops on
unsupported float types (e.g. fp8 variants) by widening to a supported
target type. Previously, when an emulated op's result was immediately used
by an arith.extf back to the target type, the pass would emit a
truncf/extf round-trip (f32->fp8->f32) that left an arith.extf on the
unsupported type in the output. This arith.extf could not be lowered by
convert-arith-to-llvm for types without LLVM hardware support (e.g.
f8E4M3FNUZ), because the type converter maps fp8 to i8, causing
convert-arith-to-llvm to generate invalid llvm.fpext i8->f32 operations.
The fix: in EmulateFloatPattern, when all uses of an emulated op's result
are arith.extf ops extending from the unsupported type to the target type,
skip the intermediate truncf round-trip and replace the extf users directly
with the wider emulated value. Note that this also skips rounding the
result to the unsupported type, so those extf users now observe the
unrounded wider value. This removes the result-side arith.extf on the
unsupported type from the output, making the pipeline compatible with
convert-arith-to-llvm.
Fixes #152287
Assisted-by: Claude Code
---
.../Transforms/EmulateUnsupportedFloats.cpp | 31 ++++++++++++++++---
.../Arith/emulate-unsupported-floats.mlir | 30 ++++++++++++++++--
2 files changed, 53 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index b6e101952676a..f2ff3cf64be23 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -72,12 +72,33 @@ LogicalResult EmulateFloatPattern::matchAndRewrite(
rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
op->getAttrs(), op->getSuccessors(), /*regions=*/{});
SmallVector<Value> newResults(expandedOp->getResults());
- for (auto [res, oldType, newType] : llvm::zip_equal(
- MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
+ for (auto [res, origResult, oldType, newType] :
+ llvm::zip_equal(MutableArrayRef{newResults}, op->getResults(),
+ op->getResultTypes(), resultTypes)) {
if (oldType != newType) {
- auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
- truncFOp.setFastmath(arith::FastMathFlags::contract);
- res = truncFOp.getResult();
+ // If all uses of the original result are arith.extf ops that extend from
+ // the unsupported type to the target (wider) type, we can skip the
+ // intermediate truncf round-trip and directly replace those extf ops with
+ // the wider emulated value. This avoids emitting arith.extf on the
+ // unsupported type in the output, which cannot be lowered to LLVM for
+ // types that lack native hardware support (e.g. fp8 variants).
+ bool allUsersAreExtFToTargetType =
+ !origResult.use_empty() &&
+ llvm::all_of(origResult.getUsers(), [newType](Operation *user) {
+ auto extFOp = dyn_cast<arith::ExtFOp>(user);
+ return extFOp && extFOp.getType() == newType;
+ });
+ if (allUsersAreExtFToTargetType) {
+ // Replace all extf users directly with the wider emulated value.
+ for (Operation *user :
+ llvm::make_early_inc_range(origResult.getUsers()))
+ rewriter.replaceOp(user, res);
+ // No truncf needed; res already has the target (wider) type.
+ } else {
+ auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
+ truncFOp.setFastmath(arith::FastMathFlags::contract);
+ res = truncFOp.getResult();
+ }
}
}
rewriter.replaceOp(op, newResults);
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index fcd004ac554aa..00fb6d282f4a6 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=f8E4M3FNUZ target-type=f32" --convert-arith-to-llvm %s | FileCheck %s --check-prefix=LLVM
func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
@@ -60,14 +61,20 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
// -----
+// When the result of an emulated op is only used by an extf back to the
+// target type, the pass skips the truncf/extf round-trip and uses the
+// wider emulated value directly. This avoids emitting arith.extf on the
+// unsupported type, which cannot be lowered to LLVM.
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// CHECK-LABEL: @vectors
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
-// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
-// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[RET:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf {{%.+}} : vector<4xf8E4M3FNUZ>
// CHECK: return [[RET]]
+// LLVM-LABEL: @vectors
+// LLVM-NOT: llvm.fpext {{.*}} : vector<4xi8>
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
%ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
func.return %ret : vector<4xf32>
@@ -75,6 +82,23 @@ func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
// -----
+// When an emulated op's result has mixed users (not all are arith.extf to the
+// target type), the pass falls back to the truncf/extf round-trip.
+func.func @mixed_users(%a: bf16) -> (f32, bf16) {
+// CHECK-LABEL: @mixed_users
+// CHECK-SAME: [[A:%.+]]: bf16
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : bf16 to f32
+// CHECK: [[PROD:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : f32
+// CHECK: [[TRUNC:%.+]] = arith.truncf [[PROD]] fastmath<contract> : f32 to bf16
+// CHECK: [[EXT:%.+]] = arith.extf [[TRUNC]] : bf16 to f32
+// CHECK: return [[EXT]], [[TRUNC]]
+ %b = arith.mulf %a, %a : bf16
+ %ext = arith.extf %b : bf16 to f32
+ func.return %ext, %b : f32, bf16
+}
+
+// -----
+
func.func @no_expansion(%x: f32) -> f32 {
// CHECK-LABEL: @no_expansion
// CHECK-SAME: [[X:%.+]]: f32
More information about the Mlir-commits mailing list