[Mlir-commits] [mlir] [mlir][tosa] Avoid introducing int <-> float casts (PR #195882)

Luke Hutton llvmlistbot at llvm.org
Wed May 6 07:52:50 PDT 2026


================
@@ -907,40 +907,98 @@ struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
 
     const Value innerCastInput = innerCastOp.getInput();
 
-    const auto innerInputType =
+    const ShapedType innerInputType =
         llvm::cast<ShapedType>(innerCastInput.getType());
-    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
-    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+    const ShapedType innerOutputType =
+        llvm::cast<ShapedType>(innerCastOp.getType());
+    const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
 
-    const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
-                                              outerOutputType};
+    const Type innerInputElemType = innerInputType.getElementType();
+    const Type innerOutputElemType = innerOutputType.getElementType();
+    const Type outerOutputElemType = outerOutputType.getElementType();
 
-    if (llvm::any_of(types, [](const ShapedType type) {
-          const auto elemTy = type.getElementType();
+    const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
+                                        outerOutputElemType};
+
+    if (llvm::any_of(types, [](const Type type) {
           // Support a specific set of floating point types since we need to be
           // careful in not introducing unsupported type combinations
-          return !(elemTy.isInteger() ||
+          return !(type.isInteger() ||
                    llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
-                             Float16Type, Float32Type>(elemTy));
+                             Float16Type, Float32Type>(type));
         }))
       return rewriter.notifyMatchFailure(
           castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
                   "supported");
 
-    if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
-        llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
+    if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
+        llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
                   "legal in TOSA");
     }
 
-    if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
-        llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
+    if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
+        llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
       return rewriter.notifyMatchFailure(
           castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
                   "legal in TOSA");
     }
 
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
+        outerOutputElemType.isInteger()) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing fp8 -> integer casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (innerInputElemType.isInteger() &&
+        llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing integer -> fp8 casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (llvm::isa<Float16Type>(innerInputElemType) &&
+        llvm::isa<BFloat16Type>(outerOutputElemType)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing fp16 -> bf16 casts which are not "
+                  "legal in TOSA");
+    }
+
+    if (llvm::isa<BFloat16Type>(innerInputElemType) &&
+        llvm::isa<Float16Type>(outerOutputElemType)) {
+      return rewriter.notifyMatchFailure(
+          castOp, "avoid introducing bf16 -> fp16 casts which are not "
+                  "legal in TOSA");
+    }
+
+    const auto isIntegerOneOf = [](Type type, size_t bitwidth1,
----------------
lhutton1 wrote:

Thanks, and agreed. I'm happy for this to be something we improve upon in the future.

https://github.com/llvm/llvm-project/pull/195882


More information about the Mlir-commits mailing list