[Mlir-commits] [mlir] [mlir][tosa] Avoid introducing int <-> float casts (PR #195882)
Ian Tayler Lessa
llvmlistbot at llvm.org
Wed May 6 09:45:28 PDT 2026
https://github.com/IanTaylerLessa-arm updated https://github.com/llvm/llvm-project/pull/195882
>From 79b7b1c91305fcae4db80367f3d1f9e04fb8b2db Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Tue, 5 May 2026 16:58:45 +0100
Subject: [PATCH 1/3] [mlir][tosa] Avoid introducing int <-> float casts
As part of the NonNarrowingCastsOptimization pattern, we were optimizing
away some cast chains where the inner cast's input was an integer and
the outer cast's output was a float (or vice versa).
Not all of the resulting dtype combinations for these cases are
supported by TOSA, so these chains are no longer optimized as part of
canonicalization.
Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: I4adfe86d4a9f19427fc6425e687de299ebfe9b1f
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 7 +++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 12 ++++++++++++
2 files changed, 19 insertions(+)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 1c186cd3ae122..33f5633fc534d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -941,6 +941,13 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
"legal in TOSA");
}
+ if (innerInputType.getElementType().isInteger() !=
+ outerOutputType.getElementType().isInteger()) {
+ return rewriter.notifyMatchFailure(
+ castOp, "integer to float and float to integer casts are not "
+ "supported to avoid introducing illegal type combinations");
+ }
+
// Check that the cast we're considering for removal is non-narrowing
if (isNarrowingCast(innerInputType, innerOutputType))
return rewriter.notifyMatchFailure(castOp,
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 9a7fa3efc8d3c..5ea1b0f73444d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1594,6 +1594,18 @@ func.func @test_canonicalize_non_narrowing_cast_f6E3M2FN_to_f8E4M3FN_to_f16_unsu
return %1 : tensor<13x21x3xf16>
}
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i1_to_f32_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i1_to_f32_unsupported(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf32>
+ return %1 : tensor<13x21x3xf32>
+}
+
+
// -----
// CHECK-LABEL: @test_canonicalize_cast_from_cast_to_block_scaled_f4E2M1
>From 5c079002e838e5ab26d4045058dadfa852af13cc Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Wed, 6 May 2026 10:26:18 +0100
Subject: [PATCH 2/3] [mlir][tosa] Add more type checks to optimization
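Adds more type checks to NonNarrowingCastsOptimization, restricting
cases like i8 -> i64, i16 -> i64, bf16 <-> fp16 and integer <-> fp8.
The blanket integer <-> float restriction from the previous commit is
replaced by targeted bool/i64 <-> float checks, so legal chains such as
i8 -> i32 -> f16 can still be folded.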
Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: I41a2809ab663b741d2bb761bd6050b0d71ab8a8a
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 85 +++++++++++++++----
mlir/test/Dialect/Tosa/canonicalize.mlir | 43 ++++++++++
2 files changed, 111 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 33f5633fc534d..2b558d8d80a59 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -907,45 +907,96 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
const Value innerCastInput = innerCastOp.getInput();
- const auto innerInputType =
+ const ShapedType innerInputType =
llvm::cast<ShapedType>(innerCastInput.getType());
- const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
- const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+ const ShapedType innerOutputType =
+ llvm::cast<ShapedType>(innerCastOp.getType());
+ const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
- const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
- outerOutputType};
+ const Type innerInputElemType = innerInputType.getElementType();
+ const Type innerOutputElemType = innerOutputType.getElementType();
+ const Type outerOutputElemType = outerOutputType.getElementType();
- if (llvm::any_of(types, [](const ShapedType type) {
- const auto elemTy = type.getElementType();
+ const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
+ outerOutputElemType};
+
+ if (llvm::any_of(types, [](const Type type) {
// Support a specific set of floating point types since we need to be
// careful in not introducing unsupported type combinations
- return !(elemTy.isInteger() ||
+ return !(type.isInteger() ||
llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
- Float16Type, Float32Type>(elemTy));
+ Float16Type, Float32Type>(type));
}))
return rewriter.notifyMatchFailure(
castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
"supported");
- if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
- llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
+ if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
+ llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
"legal in TOSA");
}
- if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
- llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
+ if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
+ llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
"legal in TOSA");
}
- if (innerInputType.getElementType().isInteger() !=
- outerOutputType.getElementType().isInteger()) {
+ if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
+ outerOutputElemType.isInteger()) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing fp8 -> integer casts which are not "
+ "legal in TOSA");
+ }
+
+ if (innerInputElemType.isInteger() &&
+ llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing integer -> fp8 casts which are not "
+ "legal in TOSA");
+ }
+
+ if (llvm::isa<Float16Type>(innerInputElemType) &&
+ llvm::isa<BFloat16Type>(outerOutputElemType)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing fp16 -> bf16 casts which are not "
+ "legal in TOSA");
+ }
+
+ if (llvm::isa<BFloat16Type>(innerInputElemType) &&
+ llvm::isa<Float16Type>(outerOutputElemType)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing bf16 -> fp16 casts which are not "
+ "legal in TOSA");
+ }
+
+ const auto isIntegerOneOf = [](Type type, size_t bitwidth1,
+ size_t bitwidth2) {
+ return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
+ };
+
+ if (isIntegerOneOf(innerInputElemType, 8, 16) &&
+ outerOutputElemType.isInteger(64)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing i8/i16 -> i64 casts which are not "
+ "legal in TOSA");
+ }
+
+ if (isIntegerOneOf(innerInputElemType, 1, 64) &&
+ !outerOutputElemType.isInteger()) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing bool/i64 to float casts which are not "
+ "supported in all versions of TOSA");
+ }
+
+ if (!innerInputElemType.isInteger() &&
+ isIntegerOneOf(outerOutputElemType, 1, 64)) {
return rewriter.notifyMatchFailure(
- castOp, "integer to float and float to integer casts are not "
- "supported to avoid introducing illegal type combinations");
+ castOp, "avoid introducing float to bool/i64 casts which are not "
+ "supported in all versions of TOSA");
}
// Check that the cast we're considering for removal is non-narrowing
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 5ea1b0f73444d..11f1cb906b40e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1541,6 +1541,17 @@ func.func @test_canonicalize_non_narrowing_cast_f16_to_f32_to_f8(%arg0: tensor<1
// -----
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32_to_f16
+// CHECK: %[[OUT:.+]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf16>
+// CHECK: return %[[OUT]] : tensor<13x21x3xf16>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32_to_f16(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xf16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf16>
+ return %1 : tensor<13x21x3xf16>
+}
+
+// -----
+
// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8E4M3FN_to_f16_to_f8E5M2
// CHECK: tosa.cast
// CHECK: tosa.cast
@@ -1605,6 +1616,38 @@ func.func @test_canonicalize_non_narrowing_cast_i1_to_f32_unsupported(%arg0: ten
return %1 : tensor<13x21x3xf32>
}
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i64_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i64_unsupported(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi64> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64>
+ return %1 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f16_to_bf16_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f16_to_bf16_unsupported(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xbf16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+ %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xbf16>
+ return %1 : tensor<13x21x3xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_f8E4M3FN_unsupported
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_i8_to_f8E4M3FN_unsupported(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xf8E4M3FN> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xf32>
+ %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf8E4M3FN>
+ return %1 : tensor<13x21x3xf8E4M3FN>
+}
// -----
>From 4b79af0e034859c12701ef046d0f9213c5ec3d5e Mon Sep 17 00:00:00 2001
From: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Date: Wed, 6 May 2026 14:56:02 +0100
Subject: [PATCH 3/3] [mlir][tosa] Rename isIntegerOneOf to isIntegerOneOfWidth
The lambda was previously called isIntegerOneOf; the new name reflects that its arguments are integer bit widths.
Signed-off-by: Ian Tayler Lessa <ian.taylerlessa at arm.com>
Change-Id: Icc60ee86c2f09f1653723b54c419c8fcb623d89b
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 2b558d8d80a59..c22651eb3fb02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -973,19 +973,19 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
"legal in TOSA");
}
- const auto isIntegerOneOf = [](Type type, size_t bitwidth1,
- size_t bitwidth2) {
+ const auto isIntegerOneOfWidth = [](Type type, size_t bitwidth1,
+ size_t bitwidth2) {
return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
};
- if (isIntegerOneOf(innerInputElemType, 8, 16) &&
+ if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
outerOutputElemType.isInteger(64)) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing i8/i16 -> i64 casts which are not "
"legal in TOSA");
}
- if (isIntegerOneOf(innerInputElemType, 1, 64) &&
+ if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
!outerOutputElemType.isInteger()) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing bool/i64 to float casts which are not "
@@ -993,7 +993,7 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
}
if (!innerInputElemType.isInteger() &&
- isIntegerOneOf(outerOutputElemType, 1, 64)) {
+ isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing float to bool/i64 casts which are not "
"supported in all versions of TOSA");
More information about the Mlir-commits mailing list