[Mlir-commits] [mlir] [mlir][tosa] Optimize non-narrowing float casts (PR #191439)

Ian Tayler Lessa llvmlistbot at llvm.org
Fri Apr 10 07:55:30 PDT 2026


https://github.com/IanTaylerLessa-arm created https://github.com/llvm/llvm-project/pull/191439

Extend the existing NonNarrowingCastsOptimization to also cover casts between floating point types f32, f16, bf16, f8E4M3FN and f8E5M2. Avoid introducing direct casts between f8 types since those are not allowed in TOSA.

Also expand the set of cases that are considered non-narrowing by only checking whether the cast we're trying to remove is non-narrowing. For example, i16 -> i32 -> i8 would have been rejected before, but it is now safely converted to a single i16 -> i8 tosa.cast, since the behaviour should be identical for the entire input space.

Finally, disallow the optimization when the cast that we would remove involves integer types of different signedness.


Change-Id: Iad742406f663d0f1f59c511705f84e2e65d4e370

>From 44327fc3032ceb183b9e6675d2495325bdc1d059 Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Thu, 9 Apr 2026 10:57:50 +0100
Subject: [PATCH] [mlir][tosa] Optimize non-narrowing float casts

Extend the existing NonNarrowingCastsOptimization to also cover casts
between floating point types f32, f16, bf16, f8E4M3FN and f8E5M2. Avoid
introducing direct casts between f8 types since those are not allowed in
TOSA.

Also expand the set of cases that are considered non-narrowing by only
checking whether the cast we're trying to remove is non-narrowing. For
example, i16 -> i32 -> i8 would have been rejected before, but it is now
safely converted to a single i16 -> i8 tosa.cast, since the behaviour
should be identical for the entire input space.

Finally, disallow the optimization when the cast that we would remove
involves integer types of different signedness.

Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: Iad742406f663d0f1f59c511705f84e2e65d4e370
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  97 ++++++++++++++--
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 109 ++++++++++++++++++
 2 files changed, 195 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ecd485ae8d641..573225056cb9b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -914,28 +914,121 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
 
     const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
                                               outerOutputType};
+
     if (llvm::any_of(types, [](const ShapedType type) {
-          return !type.getElementType().isInteger();
+          // Only integers and a fixed set of floating point types are
+          // supported, see isSupportedElementType.
+          return !isSupportedElementType(type.getElementType());
         }))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "only integer types are supported");
+      return rewriter.notifyMatchFailure(
+          castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
+                  "supported");
 
-    // Check inner cast is non-narrowing
-    const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
-    if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
-      return rewriter.notifyMatchFailure(castOp,
-                                         "inner cast operation is narrowing");
+    // Both original casts are legal, but the combined one may not be, e.g.
+    // TOSA has no direct f8 <-> f8, f8 <-> int or f16 <-> bf16 casts.
+    if (!isSupportedTosaCast(innerInputType.getElementType(),
+                             outerOutputType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing a cast which is not legal in TOSA");
 
-    // Check outer cast is non-narrowing from the inner cast input
-    if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
+    // Only the inner cast is removed, so only it must be non-narrowing
+    if (isNarrowingCast(innerInputType, innerOutputType))
       return rewriter.notifyMatchFailure(castOp,
-                                         "outer cast operation is narrowing");
+                                         "inner cast operation is narrowing");
 
     rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
                                               innerCastInput);
 
     return success();
   }
+
+  /// Returns true for the element types this pattern reasons about: any
+  /// integer type plus a fixed set of floating point types. Extending the
+  /// floating point set requires revisiting isNarrowingCast (e.g. for
+  /// unsigned floating point types, or types that use negative zero to
+  /// represent NaN) and isSupportedTosaCast.
+  static bool isSupportedElementType(Type elemTy) {
+    return elemTy.isInteger() ||
+           llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
+                     Float16Type, Float32Type>(elemTy);
+  }
+
+  /// Returns true if TOSA defines a CAST from `inElemTy` to `outElemTy`,
+  /// assuming both are accepted by isSupportedElementType. Casting a type to
+  /// itself is accepted since such a cast folds away.
+  static bool isSupportedTosaCast(Type inElemTy, Type outElemTy) {
+    if (inElemTy == outElemTy)
+      return true;
+
+    const auto isFp8 = [](Type type) {
+      return llvm::isa<Float8E4M3FNType, Float8E5M2Type>(type);
+    };
+
+    // f8 types can only be cast to and from f16, bf16 and f32.
+    if (isFp8(inElemTy) || isFp8(outElemTy))
+      return llvm::isa<Float16Type, BFloat16Type, Float32Type>(
+          isFp8(inElemTy) ? outElemTy : inElemTy);
+
+    // There is no direct cast between f16 and bf16.
+    if (llvm::isa<Float16Type, BFloat16Type>(inElemTy) &&
+        llvm::isa<Float16Type, BFloat16Type>(outElemTy))
+      return false;
+
+    // Booleans can only be cast to and from other integer types.
+    if (inElemTy.isInteger(1) || outElemTy.isInteger(1))
+      return inElemTy.isInteger() && outElemTy.isInteger();
+
+    return true;
+  }
+
+  static bool supportsNaN(const llvm::fltSemantics &semantics) {
+    return semantics.nonFiniteBehavior !=
+           llvm::fltNonfiniteBehavior::FiniteOnly;
+  }
+
+  static bool supportsInf(const llvm::fltSemantics &semantics) {
+    return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
+  }
+
+  /// Returns true if casting from `inType` to `outType` may lose information,
+  /// in which case the cast cannot be dropped from a chain of casts.
+  static bool isNarrowingCast(const ShapedType inType,
+                              const ShapedType outType) {
+    Type inElemTy = inType.getElementType();
+    Type outElemTy = outType.getElementType();
+
+    if (inElemTy.isInteger() && outElemTy.isInteger()) {
+      // A change of signedness may alter values, so treat it as narrowing.
+      return cast<IntegerType>(inElemTy).getSignedness() !=
+                 cast<IntegerType>(outElemTy).getSignedness() ||
+             inType.getElementTypeBitWidth() >
+                 outType.getElementTypeBitWidth();
+    }
+
+    if (inElemTy.isFloat() && outElemTy.isFloat()) {
+      assert(isSupportedElementType(inElemTy) &&
+             "unsupported input element type in isNarrowingCast");
+      assert(isSupportedElementType(outElemTy) &&
+             "unsupported output element type in isNarrowingCast");
+
+      const llvm::fltSemantics &inSemantics =
+          cast<FloatType>(inElemTy).getFloatSemantics();
+      const llvm::fltSemantics &outSemantics =
+          cast<FloatType>(outElemTy).getFloatSemantics();
+
+      // Every input value is preserved only if the output type covers the
+      // input's exponent range, precision and non-finite values.
+      return inSemantics.maxExponent > outSemantics.maxExponent ||
+             inSemantics.minExponent < outSemantics.minExponent ||
+             inSemantics.precision > outSemantics.precision ||
+             (supportsNaN(inSemantics) && !supportsNaN(outSemantics)) ||
+             (supportsInf(inSemantics) && !supportsInf(outSemantics));
+    }
+
+    // While some int -> float casts are non-narrowing, consider them
+    // narrowing for the purposes of this optimization.
+    return true;
+  }
 };
 
 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 367a60c4d2a8d..9a7fa3efc8d3c 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1487,6 +1487,159 @@ func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21
 
 // -----
 
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i8_to_ui16_to_i8
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_narrowing_cast_i8_to_ui16_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xui16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xui16>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i16_to_i32_to_i8
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi8>
+func.func @test_canonicalize_non_narrowing_cast_i16_to_i32_to_i8(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi8> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8_to_f16_to_f8
+// CHECK: return %arg0
+func.func @test_canonicalize_non_narrowing_cast_f8_to_f16_to_f8(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E4M3FN>
+  return %1 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f32_to_f16
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf16>
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f32_to_f16(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E5M2_to_f32_to_f16
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf16>
+func.func @test_canonicalize_non_narrowing_cast_f8E5M2_to_f32_to_f16(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f16_to_f32_to_f8E5M2
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E5M2>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf8E5M2>
+func.func @test_canonicalize_non_narrowing_cast_f16_to_f32_to_f8E5M2(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf8E5M2>
+  return %1 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E5M2
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E5M2(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E5M2>
+  return %1 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E4M3
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E4M3(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E4M3>
+  return %1 : tensor<13x21x3xf8E4M3>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3_to_f8E4M3FN_to_f16
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3_to_f8E4M3FN_to_f16(%arg0: tensor<13x21x3xf8E4M3>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3>) -> tensor<13x21x3xf8E4M3FN>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FNUZ_to_f16
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FNUZ_to_f16(%arg0: tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf8E4M3FNUZ>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FN_to_f16_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FN_to_f16_unsupported(%arg0: tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf8E4M3FN>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f32_to_i32_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f32_to_i32_unsupported(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xi32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xi32>
+  return %1 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f16_to_f32_to_bf16_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f16_to_f32_to_bf16_unsupported(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xbf16>
+  return %1 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i1_to_i8_to_f32_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i1_to_i8_to_f32_unsupported(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1
 // CHECK: return %arg0, %arg1 : tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>
 func.func @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1(%arg0: tensor<15x3x2x256xf4E2M1FN>, %arg1: tensor<15x3x2x8xf8E8M0FNU>) -> (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>) {



More information about the Mlir-commits mailing list