[Mlir-commits] [mlir] [mlir][tosa] Optimize non-narrowing float casts (PR #191439)

Ian Tayler Lessa llvmlistbot at llvm.org
Fri Apr 10 08:26:47 PDT 2026


https://github.com/IanTaylerLessa-arm updated https://github.com/llvm/llvm-project/pull/191439

>From ea3a5d798595da0b1ffc7dd1cee51ace0cf2bdcb Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Thu, 9 Apr 2026 10:57:50 +0100
Subject: [PATCH] [mlir][tosa] Optimize non-narrowing float casts

Extend the existing NonNarrowingCastsOptimization to also cover casts
between floating point types f32, f16, bf16, f8E4M3FN and f8E5M2. Avoid
introducing direct casts between f8 types since those are not allowed in
TOSA.

Also expand the set of cases that are considered non-narrowing by only
checking whether the cast we're trying to remove is non-narrowing. For
example, i16 -> i32 -> i8 would have been rejected before, but it is now
safely converted to a single i16 -> i8 tosa.cast, since the behaviour
should be identical for the entire input space.

Finally, disallow the optimization when the cast that we would remove
involves integer types of different signedness.

Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: Iad742406f663d0f1f59c511705f84e2e65d4e370
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  97 ++++++++++++++--
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 109 ++++++++++++++++++
 2 files changed, 195 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ecd485ae8d641..1c186cd3ae122 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -914,28 +914,103 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
 
     const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
                                               outerOutputType};
+
     if (llvm::any_of(types, [](const ShapedType type) {
-          return !type.getElementType().isInteger();
+          const auto elemTy = type.getElementType();
+          // Support a specific set of floating point types since we need to be
+          // careful in not introducing unsupported type combinations
+          return !(elemTy.isInteger() ||
+                   llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
+                             Float16Type, Float32Type>(elemTy));
         }))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "only integer types are supported");
+      return rewriter.notifyMatchFailure(
+          castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
+                  "supported");
 
-    // Check inner cast is non-narrowing
-    const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
-    if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
-      return rewriter.notifyMatchFailure(castOp,
-                                         "inner cast operation is narrowing");
+    if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
+                  "legal in TOSA");
+    }
 
-    // Check outer cast is non-narrowing from the inner cast input
-    if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
+    if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
+                  "legal in TOSA");
+    }
+
+    // Check that the cast we're considering for removal is non-narrowing
+    if (isNarrowingCast(innerInputType, innerOutputType))
       return rewriter.notifyMatchFailure(castOp,
-                                         "outer cast operation is narrowing");
+                                         "inner cast operation is narrowing");
 
     rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
                                               innerCastInput);
 
     return success();
   }
+
+  bool supportsNaN(const llvm::fltSemantics &semantics) const {
+    return semantics.nonFiniteBehavior !=
+           llvm::fltNonfiniteBehavior::FiniteOnly;
+  }
+
+  bool supportsInf(const llvm::fltSemantics &semantics) const {
+    return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
+  }
+
+  bool isNarrowingCast(const ShapedType inType,
+                       const ShapedType outType) const {
+
+    if (inType.getElementType().isInteger() &&
+        outType.getElementType().isInteger()) {
+
+      const auto inTypeSignedness =
+          cast<IntegerType>(inType.getElementType()).getSignedness();
+      const auto outTypeSignedness =
+          cast<IntegerType>(outType.getElementType()).getSignedness();
+
+      return (inTypeSignedness != outTypeSignedness ||
+              inType.getElementTypeBitWidth() >
+                  outType.getElementTypeBitWidth());
+    }
+
+    if (inType.getElementType().isFloat() &&
+        outType.getElementType().isFloat()) {
+
+      FloatType inElemTy = cast<FloatType>(inType.getElementType());
+      FloatType outElemTy = cast<FloatType>(outType.getElementType());
+      llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
+      llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
+
+      // If the list of supported types needs to be updated in the future, the
+      // check down below will need to be revised, for example to account for
+      // unsigned floating point types, or types that use negative zero as the
+      // representation for NaN.
+      [[maybe_unused]] const auto isSupported = [](Type elemType) {
+        return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
+                         Float16Type, Float32Type>(elemType);
+      };
+
+      assert(isSupported(inElemTy) &&
+             "unsupported input element type in isNarrowingCast");
+      assert(isSupported(outElemTy) &&
+             "unsupported output element type in isNarrowingCast");
+
+      return (
+          inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
+          inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
+          inTypeSemantics.precision > outTypeSemantics.precision ||
+          (supportsNaN(inTypeSemantics) && !supportsNaN(outTypeSemantics)) ||
+          (supportsInf(inTypeSemantics) && !supportsInf(outTypeSemantics)));
+    }
+
+    // While some cases of int -> float casts can be non-narrowing, consider
+    // them narrowing for the purposes of this optimization
+    return true;
+  }
 };
 
 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 367a60c4d2a8d..9a7fa3efc8d3c 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1487,6 +1487,115 @@ func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21
 
 // -----
 
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i8_to_ui16_to_i8
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_narrowing_cast_i8_to_ui16_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xui16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xui16>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8_to_f16_to_f8
+// CHECK: return %arg0
+func.func @test_canonicalize_non_narrowing_cast_f8_to_f16_to_f8(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E4M3FN>
+  return %1 : tensor<13x21x3xf8E4M3FN>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8_to_f8E4M3FN_to_f16
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>)
+// CHECK: return %[[OUT]] : tensor<13x21x3xf16>
+func.func @test_canonicalize_non_narrowing_cast_f8_to_f8E4M3FN_to_f16(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8_to_f32_to_f16
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf16>
+func.func @test_canonicalize_non_narrowing_cast_f8_to_f32_to_f16(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f16_to_f32_to_f8
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E5M2>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf8E5M2>
+func.func @test_canonicalize_non_narrowing_cast_f16_to_f32_to_f8(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf8E5M2>
+  return %1 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E5M2
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E5M2(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E5M2>
+  return %1 : tensor<13x21x3xf8E5M2>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E4M3
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E4M3(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf8E4M3>
+  return %1 : tensor<13x21x3xf8E4M3>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3_to_f8E4M3FN_to_f16
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8E4M3_to_f8E4M3FN_to_f16(%arg0: tensor<13x21x3xf8E4M3>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3>) -> tensor<13x21x3xf8E4M3FN>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FNUZ_to_f16
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FNUZ_to_f16(%arg0: tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf8E4M3FNUZ>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FN_to_f16_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FN_to_f16_unsupported(%arg0: tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf6E3M2FN>) -> tensor<13x21x3xf8E4M3FN>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
 // CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1
 // CHECK: return %arg0, %arg1 : tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>
 func.func @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1(%arg0: tensor<15x3x2x256xf4E2M1FN>, %arg1: tensor<15x3x2x8xf8E8M0FNU>) -> (tensor<15x3x2x256xf4E2M1FN>, tensor<15x3x2x8xf8E8M0FNU>) {



More information about the Mlir-commits mailing list