[Mlir-commits] [mlir] [mlir][tosa] Avoid introducing int <-> float casts (PR #195882)
Luke Hutton
llvmlistbot at llvm.org
Wed May 6 07:52:50 PDT 2026
================
@@ -907,40 +907,98 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
const Value innerCastInput = innerCastOp.getInput();
- const auto innerInputType =
+ const ShapedType innerInputType =
llvm::cast<ShapedType>(innerCastInput.getType());
- const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
- const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+ const ShapedType innerOutputType =
+ llvm::cast<ShapedType>(innerCastOp.getType());
+ const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
- const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
- outerOutputType};
+ const Type innerInputElemType = innerInputType.getElementType();
+ const Type innerOutputElemType = innerOutputType.getElementType();
+ const Type outerOutputElemType = outerOutputType.getElementType();
- if (llvm::any_of(types, [](const ShapedType type) {
- const auto elemTy = type.getElementType();
+ const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
+ outerOutputElemType};
+
+ if (llvm::any_of(types, [](const Type type) {
// Support a specific set of floating point types since we need to be
// careful in not introducing unsupported type combinations
- return !(elemTy.isInteger() ||
+ return !(type.isInteger() ||
llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
- Float16Type, Float32Type>(elemTy));
+ Float16Type, Float32Type>(type));
}))
return rewriter.notifyMatchFailure(
castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
"supported");
- if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
- llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
+ if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
+ llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
"legal in TOSA");
}
- if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
- llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
+ if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
+ llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
return rewriter.notifyMatchFailure(
castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
"legal in TOSA");
}
+ if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
+ outerOutputElemType.isInteger()) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing fp8 -> integer casts which are not "
+ "legal in TOSA");
+ }
+
+ if (innerInputElemType.isInteger() &&
+ llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing integer -> fp8 casts which are not "
+ "legal in TOSA");
+ }
+
+ if (llvm::isa<Float16Type>(innerInputElemType) &&
+ llvm::isa<BFloat16Type>(outerOutputElemType)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing fp16 -> bf16 casts which are not "
+ "legal in TOSA");
+ }
+
+ if (llvm::isa<BFloat16Type>(innerInputElemType) &&
+ llvm::isa<Float16Type>(outerOutputElemType)) {
+ return rewriter.notifyMatchFailure(
+ castOp, "avoid introducing bf16 -> fp16 casts which are not "
+ "legal in TOSA");
+ }
+
+ const auto isIntegerOneOf = [](Type type, size_t bitwidth1,
----------------
lhutton1 wrote:
Thanks, and agreed; happy for this to be something we improve upon in the future.
https://github.com/llvm/llvm-project/pull/195882
More information about the Mlir-commits mailing list