[Mlir-commits] [mlir] [mlir][tosa] Avoid introducing int <-> float casts (PR #195882)

Ian Tayler Lessa llvmlistbot at llvm.org
Wed May 6 09:45:28 PDT 2026


https://github.com/IanTaylerLessa-arm updated https://github.com/llvm/llvm-project/pull/195882

>From 79b7b1c91305fcae4db80367f3d1f9e04fb8b2db Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Tue, 5 May 2026 16:58:45 +0100
Subject: [PATCH 1/3] [mlir][tosa] Avoid introducing int <-> float casts

As part of NonNarrowingCastsOptimization, we were collapsing a pair of
casts into a single cast in some cases where the inner input was an
integer and the outer output was a float (or vice versa).

Not all of the resulting element-type combinations are supported by
TOSA, so these cases are no longer folded during canonicalization.

Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: I4adfe86d4a9f19427fc6425e687de299ebfe9b1f
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp |  7 +++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir           | 12 ++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1c186cd3ae122..33f5633fc534d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -941,6 +941,13 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
                   "legal in TOSA");
     }
 
+    if (innerInputType.getElementType().isInteger() !=
+        outerOutputType.getElementType().isInteger()) {
+      return rewriter.notifyMatchFailure(
+          castOp, "integer to float and float to integer casts are not "
+                  "supported to avoid introducing illegal type combinations");
+    }
+
     // Check that the cast we're considering for removal is non-narrowing
     if (isNarrowingCast(innerInputType, innerOutputType))
       return rewriter.notifyMatchFailure(castOp,
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 9a7fa3efc8d3c..5ea1b0f73444d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1594,6 +1594,18 @@ func.func @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FN_to_f16_unsu
   return %1 : tensor<13x21x3xf16>
 }
 
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i1_to_f32_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i1_to_f32_unsupported(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
+}
+
+
 // -----
 
 // CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1

>From 5c079002e838e5ab26d4045058dadfa852af13cc Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Wed, 6 May 2026 10:26:18 +0100
Subject: [PATCH 2/3] [mlir][tosa] Add more type checks to optimization

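Adds more type checks to NonNarrowingCastsOptimization so that it no
longer introduces casts that are illegal in TOSA, such as i8 -> i64,
i16 -> i64 and bf16 <-> fp16. These finer-grained checks replace the
blanket integer <-> float restriction from the previous patch, so legal
combinations such as i8 -> i32 -> f16 are folded again.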

Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: I41a2809ab663b741d2bb761bd6050b0d71ab8a8a
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 85 +++++++++++++++----
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 43 ++++++++++
 2 files changed, 111 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 33f5633fc534d..2b558d8d80a59 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -907,45 +907,96 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
 
     const Value innerCastInput = innerCastOp.getInput();
 
-    const auto innerInputType =
+    const ShapedType innerInputType =
         llvm::cast<ShapedType>(innerCastInput.getType());
-    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
-    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+    const ShapedType innerOutputType =
+        llvm::cast<ShapedType>(innerCastOp.getType());
+    const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
 
-    const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
-                                              outerOutputType};
+    const Type innerInputElemType = innerInputType.getElementType();
+    const Type innerOutputElemType = innerOutputType.getElementType();
+    const Type outerOutputElemType = outerOutputType.getElementType();
 
-    if (llvm::any_of(types, [](const ShapedType type) {
-          const auto elemTy = type.getElementType();
+    const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
+                                        outerOutputElemType};
+
+    if (llvm::any_of(types, [](const Type type) {
           // Support a specific set of floating point types since we need to be
           // careful in not introducing unsupported type combinations
-          return !(elemTy.isInteger() ||
+          return !(type.isInteger() ||
                    llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
-                             Float16Type, Float32Type>(elemTy));
+                             Float16Type, Float32Type>(type));
         }))
       return rewriter.notifyMatchFailure(
           castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
                   "supported");
 
-    if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
-        llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
+    if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
+        llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
                   "legal in TOSA");
     }
 
-    if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
-        llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
+    if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
+        llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
                   "legal in TOSA");
     }
 
-    if (innerInputType.getElementType().isInteger() !=
-        outerOutputType.getElementType().isInteger()) {
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
+        outerOutputElemType.isInteger()) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing fp8 -> integer casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (innerInputElemType.isInteger() &&
+        llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing integer -> fp8 casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (llvm::isa<Float16Type>(innerInputElemType) &&
+        llvm::isa<BFloat16Type>(outerOutputElemType)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing fp16 -> bf16 casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (llvm::isa<BFloat16Type>(innerInputElemType) &&
+        llvm::isa<Float16Type>(outerOutputElemType)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing bf16 -> fp16 casts which are not "
+                  "legal in TOSA");
+    }
+
+    const auto isIntegerOneOf = [](Type type, size_t bitwidth1,
+                                   size_t bitwidth2) {
+      return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
+    };
+
+    if (isIntegerOneOf(innerInputElemType, 8, 16) &&
+        outerOutputElemType.isInteger(64)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing i8/i16 -> i64 casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (isIntegerOneOf(innerInputElemType, 1, 64) &&
+        !outerOutputElemType.isInteger()) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing bool/i64 to float casts which are not "
+                  "supported in all versions of TOSA");
+    }
+
+    if (!innerInputElemType.isInteger() &&
+        isIntegerOneOf(outerOutputElemType, 1, 64)) {
       return rewriter.notifyMatchFailure(
-          castOp, "integer to float and float to integer casts are not "
-                  "supported to avoid introducing illegal type combinations");
+          castOp, "avoid introducing float to bool/i64 casts which are not "
+                  "supported in all versions of TOSA");
     }
 
     // Check that the cast we're considering for removal is non-narrowing
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5ea1b0f73444d..11f1cb906b40e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1541,6 +1541,17 @@ func.func @test_canonicalize_non_narrowing_cast_f16_to_f32_to_f8(%arg0: tensor<1
 
 // -----
 
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32_to_f16
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf16>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf16>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32_to_f16(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf16>
+  return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
 // CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E5M2
 // CHECK: tosa.cast
 // CHECK: tosa.cast
@@ -1605,6 +1616,38 @@ func.func @test_canonicalize_non_narrowing_cast_i1_to_f32_unsupported(%arg0: ten
   return %1 : tensor<13x21x3xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i64_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i64_unsupported(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi64> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64>
+  return %1 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f16_to_bf16_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f16_to_bf16_unsupported(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xbf16>
+  return %1 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_f8E4M3FN_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i8_to_f8E4M3FN_unsupported(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xf8E4M3FN> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf8E4M3FN>
+  return %1 : tensor<13x21x3xf8E4M3FN>
+}
 
 // -----
 

>From 4b79af0e034859c12701ef046d0f9213c5ec3d5e Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Wed, 6 May 2026 14:56:02 +0100
Subject: [PATCH 3/3] [mlir][tosa] Rename isIntegerOneOf to isIntegerOneOfWidth

The helper lambda was previously called isIntegerOneOf; the new name makes it clear that its arguments are bit widths.

Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: Icc60ee86c2f09f1653723b54c419c8fcb623d89b
---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 2b558d8d80a59..c22651eb3fb02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -973,19 +973,19 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
                   "legal in TOSA");
     }
 
-    const auto isIntegerOneOf = [](Type type, size_t bitwidth1,
-                                   size_t bitwidth2) {
+    const auto isIntegerOneOfWidth = [](Type type, size_t bitwidth1,
+                                        size_t bitwidth2) {
       return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
     };
 
-    if (isIntegerOneOf(innerInputElemType, 8, 16) &&
+    if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
         outerOutputElemType.isInteger(64)) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing i8/i16 -> i64 casts which are not "
                   "legal in TOSA");
     }
 
-    if (isIntegerOneOf(innerInputElemType, 1, 64) &&
+    if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
         !outerOutputElemType.isInteger()) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing bool/i64 to float casts which are not "
@@ -993,7 +993,7 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
     }
 
     if (!innerInputElemType.isInteger() &&
-        isIntegerOneOf(outerOutputElemType, 1, 64)) {
+        isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing float to bool/i64 casts which are not "
                   "supported in all versions of TOSA");



More information about the Mlir-commits mailing list