[Mlir-commits] [mlir] [MLIR][Arith] Fix arith-emulate-unsupported-floats for extf on emulated float types (PR #189243)

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 29 06:20:28 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189243

The ArithEmulateUnsupportedFloats pass emulates arithmetic ops on unsupported float types (e.g. fp8 variants) by widening to a supported target type. Previously, when an emulated op's result was immediately used by an arith.extf back to the target type, the pass would emit a truncf/extf round-trip (f32->fp8->f32) that left an arith.extf on the unsupported type in the output. This arith.extf could not be lowered by convert-arith-to-llvm for types without LLVM hardware support (e.g. f8E4M3FNUZ), because the type converter maps fp8 to i8, causing convert-arith-to-llvm to generate invalid llvm.fpext i8->f32 operations.

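Instead, when all uses of an emulated op's result are arith.extf ops extending from the unsupported type to the target type, the pass now skips the intermediate truncf round-trip and replaces those extf users directly with the wider emulated value. This eliminates the arith.extf on the unsupported type from the output, making the pipeline compatible with convert-arith-to-llvm. Note that, because the truncf is no longer emitted, those extf users now observe the unrounded target-type value rather than a value first rounded to the unsupported type.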

Fixes #152287

Assisted-by: Claude Code

>From d93deb94c962e0e296aa75a760f2bf967c4c7b78 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 18:03:36 -0700
Subject: [PATCH] [MLIR][Arith] Fix arith-emulate-unsupported-floats for extf
 on emulated float types

The ArithEmulateUnsupportedFloats pass emulates arithmetic ops on
unsupported float types (e.g. fp8 variants) by widening to a supported
target type. Previously, when an emulated op's result was immediately used
by an arith.extf back to the target type, the pass would emit a
truncf/extf round-trip (f32->fp8->f32) that left an arith.extf on the
unsupported type in the output. This arith.extf could not be lowered by
convert-arith-to-llvm for types without LLVM hardware support (e.g.
f8E4M3FNUZ), because the type converter maps fp8 to i8, causing
convert-arith-to-llvm to generate invalid llvm.fpext i8->f32 operations.

The fix: in EmulateFloatPattern, when all uses of an emulated op's result
are arith.extf ops extending from the unsupported type to the target type,
the pattern skips the intermediate truncf round-trip and replaces the extf
users directly with the wider emulated value. This eliminates the arith.extf on the
unsupported type from the output, making the pipeline compatible with
convert-arith-to-llvm.

Fixes #152287

Assisted-by: Claude Code
---
 .../Transforms/EmulateUnsupportedFloats.cpp   | 31 ++++++++++++++++---
 .../Arith/emulate-unsupported-floats.mlir     | 30 ++++++++++++++++--
 2 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index b6e101952676a..f2ff3cf64be23 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -72,12 +72,33 @@ LogicalResult EmulateFloatPattern::matchAndRewrite(
       rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
                       op->getAttrs(), op->getSuccessors(), /*regions=*/{});
   SmallVector<Value> newResults(expandedOp->getResults());
-  for (auto [res, oldType, newType] : llvm::zip_equal(
-           MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
+  for (auto [res, origResult, oldType, newType] :
+       llvm::zip_equal(MutableArrayRef{newResults}, op->getResults(),
+                       op->getResultTypes(), resultTypes)) {
     if (oldType != newType) {
-      auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      res = truncFOp.getResult();
+      // If all uses of the original result are arith.extf ops that extend from
+      // the unsupported type to the target (wider) type, we can skip the
+      // intermediate truncf round-trip and directly replace those extf ops with
+      // the wider emulated value. This avoids emitting arith.extf on the
+      // unsupported type in the output, which cannot be lowered to LLVM for
+      // types that lack native hardware support (e.g. fp8 variants).
+      bool allUsersAreExtFToTargetType =
+          !origResult.use_empty() &&
+          llvm::all_of(origResult.getUsers(), [newType](Operation *user) {
+            auto extFOp = dyn_cast<arith::ExtFOp>(user);
+            return extFOp && extFOp.getType() == newType;
+          });
+      if (allUsersAreExtFToTargetType) {
+        // Replace all extf users directly with the wider emulated value.
+        for (Operation *user :
+             llvm::make_early_inc_range(origResult.getUsers()))
+          rewriter.replaceOp(user, res);
+        // No truncf needed; res already has the target (wider) type.
+      } else {
+        auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
+        truncFOp.setFastmath(arith::FastMathFlags::contract);
+        res = truncFOp.getResult();
+      }
     }
   }
   rewriter.replaceOp(op, newResults);
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index fcd004ac554aa..00fb6d282f4a6 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=f8E4M3FNUZ target-type=f32" --convert-arith-to-llvm %s | FileCheck %s --check-prefix=LLVM
 
 func.func @basic_expansion(%x: bf16) -> bf16 {
 // CHECK-LABEL: @basic_expansion
@@ -60,14 +61,20 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
 
 // -----
 
+// When the result of an emulated op is only used by an extf back to the
+// target type, the pass skips the truncf/extf round-trip and uses the
+// wider emulated value directly. This avoids emitting arith.extf on the
+// unsupported type, which cannot be lowered to LLVM.
 func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
 // CHECK-LABEL: @vectors
 // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
 // CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
-// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
-// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
-// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
+// CHECK: [[RET:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
+// CHECK-NOT: arith.truncf
+// CHECK-NOT: arith.extf {{%.+}} : vector<4xf8E4M3FNUZ>
 // CHECK: return [[RET]]
+// LLVM-LABEL: @vectors
+// LLVM-NOT: llvm.fpext {{.*}} : vector<4xi8>
   %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>
   %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32>
   func.return %ret : vector<4xf32>
@@ -75,6 +82,23 @@ func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
 
 // -----
 
+// When an emulated op's result has mixed users (not all are arith.extf to the
+// target type), the pass falls back to the truncf/extf round-trip.
+func.func @mixed_users(%a: bf16) -> (f32, bf16) {
+// CHECK-LABEL: @mixed_users
+// CHECK-SAME: [[A:%.+]]: bf16
+// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : bf16 to f32
+// CHECK: [[PROD:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : f32
+// CHECK: [[TRUNC:%.+]] = arith.truncf [[PROD]] fastmath<contract> : f32 to bf16
+// CHECK: [[EXT:%.+]] = arith.extf [[TRUNC]] : bf16 to f32
+// CHECK: return [[EXT]], [[TRUNC]]
+  %b = arith.mulf %a, %a : bf16
+  %ext = arith.extf %b : bf16 to f32
+  func.return %ext, %b : f32, bf16
+}
+
+// -----
+
 func.func @no_expansion(%x: f32) -> f32 {
 // CHECK-LABEL: @no_expansion
 // CHECK-SAME: [[X:%.+]]: f32



More information about the Mlir-commits mailing list