[Mlir-commits] [mlir] [mlir][tosa] Extend narrowing pass (PR #170712)

Vitalii Shutov llvmlistbot at llvm.org
Mon Dec 15 02:53:54 PST 2025


https://github.com/Lallapallooza updated https://github.com/llvm/llvm-project/pull/170712

>From 02ae69e292333879fdc0d2ef744c6e3e5000cd7b Mon Sep 17 00:00:00 2001
From: Vitalii Shutov <vitalii.shutov at arm.com>
Date: Thu, 4 Dec 2025 17:38:50 +0000
Subject: [PATCH 1/3] [mlir][tosa] Extend narrowing pass

- unify the i64->i32 and f64->f32 narrowing logic inside the shared
  implementation
- register tosa::ConstOp in the non-aggressive rewrite set so
  standalone constants are narrowed

Signed-off-by: Vitalii Shutov <vitalii.shutov at arm.com>
Change-Id: I3bb17e1d514121a08ac4716c6c991ad7e87b4c17
---
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  25 +
 .../Dialect/Tosa/Transforms/CMakeLists.txt    |   2 +-
 .../Tosa/Transforms/TosaNarrowI64ToI32.cpp    | 310 ----------
 .../Tosa/Transforms/TosaNarrowTypes.cpp       | 558 ++++++++++++++++++
 .../tosa-narrow-f64-to-f32-aggressive.mlir    |  70 +++
 .../Dialect/Tosa/tosa-narrow-f64-to-f32.mlir  | 160 +++++
 .../Dialect/Tosa/tosa-narrow-i64-to-i32.mlir  |  12 +
 7 files changed, 826 insertions(+), 311 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
 create mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir
 create mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 12f520297b702..4a5f283bc66c8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -198,4 +198,29 @@ def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
   ];
 }
 
+def TosaNarrowF64ToF32Pass : Pass<"tosa-narrow-f64-to-f32", "func::FuncOp"> {
+  let summary = "Narrow F64 TOSA operations to F32";
+  let description = [{
+    This pass narrows TOSA operations with 64-bit floating-point tensor types to
+    32-bit floating-point tensor types. While TOSA itself has no double
+    precision support, upstream conversions or frontends may still materialize
+    F64 tensors temporarily, so this pass removes them before handing off to a
+    TOSA backend.
+  }];
+
+  let options = [
+    Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
+      "If enabled, all TOSA operations are rewritten, regardless of whether the narrowing "
+      "is safe. This option may lead to data loss if not used carefully.">,
+    Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
+      "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
+      "be inserted at the I/O boundaries.">
+  ];
+
+  let dependentDialects = [
+    "func::FuncDialect",
+    "tosa::TosaDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 091b481d6394b..0ff68b1bb54f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaTypeConverters.cpp
   TosaProfileCompliance.cpp
   TosaValidation.cpp
-  TosaNarrowI64ToI32.cpp
+  TosaNarrowTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
deleted file mode 100644
index ddaf7d8a5e033..0000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
+++ /dev/null
@@ -1,310 +0,0 @@
-//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This pass narrows TOSA operations with 64-bit integer tensor types to
-// 32-bit integer tensor types. This can be useful for backends that do not
-// support the EXT-INT64 extension of TOSA. The pass has two options:
-//
-// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
-//     regardless or whether the narrowing is safe. This option may lead to
-//     data loss if not used carefully.
-// - convert-function-boundaries - If enabled, the pass will convert function
-//     I/O types as well. Otherwise casts will be inserted at the I/O
-//     boundaries.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/IR/Verifier.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace tosa {
-#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
-#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
-} // namespace tosa
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-LogicalResult convertGenericOp(Operation *op, ValueRange operands,
-                               ConversionPatternRewriter &rewriter,
-                               const TypeConverter *typeConverter) {
-  // Convert types of results
-  SmallVector<Type, 4> newResults;
-  if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
-    return failure();
-
-  // Create a new operation state
-  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                       newResults, {}, op->getSuccessors());
-
-  for (const NamedAttribute &namedAttribute : op->getAttrs()) {
-    const Attribute attribute = namedAttribute.getValue();
-
-    // Convert integer attribute type
-    if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
-      const std::optional<Attribute> convertedAttribute =
-          typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
-      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
-      continue;
-    }
-
-    if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
-      Type type = typeAttr.getValue();
-      const std::optional<Attribute> convertedAttribute =
-          typeConverter->convertTypeAttribute(type, attribute);
-      if (!convertedAttribute)
-        return rewriter.notifyMatchFailure(op,
-                                           "Failed to convert type attribute.");
-      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
-      continue;
-    }
-
-    if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
-      const Type type = denseElementsAttr.getType();
-      const std::optional<Attribute> convertedAttribute =
-          typeConverter->convertTypeAttribute(type, denseElementsAttr);
-      if (!convertedAttribute)
-        return rewriter.notifyMatchFailure(
-            op, "Failed to convert dense elements attribute.");
-      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
-      continue;
-    }
-
-    state.addAttribute(namedAttribute.getName(), attribute);
-  }
-
-  for (Region &region : op->getRegions()) {
-    Region *newRegion = state.addRegion();
-    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-    if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
-      return failure();
-  }
-
-  Operation *newOp = rewriter.create(state);
-  rewriter.replaceOp(op, newOp->getResults());
-  return success();
-}
-
-// ===========================
-// Aggressive rewrite patterns
-// ===========================
-
-class ConvertGenericOp : public ConversionPattern {
-public:
-  ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (!isa<tosa::TosaOp>(op))
-      return rewriter.notifyMatchFailure(
-          op,
-          "Support for operations other than TOSA has not been implemented.");
-
-    return convertGenericOp(op, operands, rewriter, typeConverter);
-  }
-};
-
-// ===============================
-// Bounds checked rewrite patterns
-// ===============================
-
-class ConvertArgMaxOpWithBoundsChecking
-    : public OpConversionPattern<tosa::ArgMaxOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    // Output type can be narrowed based on the size of the axis dimension
-    const int32_t axis = op.getAxis();
-    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
-    if (!inputType || !inputType.isStaticDim(axis))
-      return rewriter.notifyMatchFailure(
-          op, "Requires a static axis dimension for bounds checking.");
-    const int64_t axisDim = inputType.getDimSize(axis);
-    if (axisDim >= std::numeric_limits<int32_t>::max())
-      return rewriter.notifyMatchFailure(
-          op, "Axis dimension is too large to narrow safely.");
-
-    const Type resultType = op.getOutput().getType();
-    const Type newResultType = typeConverter->convertType(resultType);
-    rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
-                                                adaptor.getInput(), axis);
-    return success();
-  }
-};
-
-class ConvertCastOpWithBoundsChecking
-    : public OpConversionPattern<tosa::CastOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
-    const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
-    if (!inputType || !resultType)
-      return failure();
-
-    const auto elementInputIntType =
-        dyn_cast<IntegerType>(inputType.getElementType());
-    const auto elementResultIntType =
-        dyn_cast<IntegerType>(resultType.getElementType());
-    if (elementInputIntType && elementResultIntType &&
-        elementInputIntType.getWidth() > elementResultIntType.getWidth())
-      return rewriter.notifyMatchFailure(
-          op, "Narrowing cast may lead to data loss.");
-
-    rewriter.replaceOpWithNewOp<tosa::CastOp>(
-        op, typeConverter->convertType(resultType), adaptor.getInput());
-    return success();
-  }
-};
-
-template <typename OpTy>
-class ConvertTypedOp : public OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    return convertGenericOp(op, adaptor.getOperands(), rewriter,
-                            this->getTypeConverter());
-  }
-};
-
-struct TosaNarrowI64ToI32
-    : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
-public:
-  explicit TosaNarrowI64ToI32() = default;
-  explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
-      : TosaNarrowI64ToI32() {
-    this->aggressiveRewrite = options.aggressiveRewrite;
-    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
-  }
-
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type type) -> Type { return type; });
-    typeConverter.addConversion([](IntegerType type) -> Type {
-      if (!type.isInteger(64))
-        return type;
-      return IntegerType::get(type.getContext(), 32);
-    });
-    typeConverter.addConversion(
-        [&typeConverter](RankedTensorType type) -> Type {
-          const Type elementType = type.getElementType();
-          if (!elementType.isInteger(64))
-            return type;
-          return RankedTensorType::get(type.getShape(),
-                                       typeConverter.convertType(elementType));
-        });
-
-    const auto materializeCast = [](OpBuilder &builder, Type resultType,
-                                    ValueRange inputs, Location loc) -> Value {
-      if (inputs.size() != 1)
-        return Value();
-      return tosa::CastOp::create(builder, loc, resultType, inputs.front());
-    };
-    typeConverter.addSourceMaterialization(materializeCast);
-    typeConverter.addTargetMaterialization(materializeCast);
-
-    typeConverter.addTypeAttributeConversion(
-        [](IntegerType type, IntegerAttr attribute) -> Attribute {
-          const APInt value = attribute.getValue().truncSSat(32);
-          return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
-                                  value);
-        });
-    typeConverter.addTypeAttributeConversion(
-        [&typeConverter](ShapedType type,
-                         DenseIntElementsAttr attr) -> Attribute {
-          const ShapedType newType =
-              cast<ShapedType>(typeConverter.convertType(type));
-          const auto oldElementType = cast<IntegerType>(type.getElementType());
-          const auto newElementType =
-              cast<IntegerType>(newType.getElementType());
-          if (oldElementType.getWidth() == newElementType.getWidth())
-            return attr;
-
-          DenseElementsAttr mapped =
-              attr.mapValues(newElementType, [&](const APInt &v) {
-                return v.truncSSat(newElementType.getWidth());
-              });
-          return mapped;
-        });
-
-    ConversionTarget target(*context);
-    target.addDynamicallyLegalDialect<tosa::TosaDialect>(
-        [&typeConverter](Operation *op) {
-          return typeConverter.isLegal(op->getResultTypes()) &&
-                 typeConverter.isLegal(op->getOperandTypes());
-        });
-    if (convertFunctionBoundaries) {
-      target.addDynamicallyLegalOp<func::FuncOp>(
-          [&typeConverter](func::FuncOp op) {
-            return typeConverter.isSignatureLegal(op.getFunctionType()) &&
-                   typeConverter.isLegal(&op.getBody());
-          });
-      target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
-        const FunctionType funcType =
-            op->getParentOfType<func::FuncOp>().getFunctionType();
-        return llvm::equal(op.getOperandTypes(), funcType.getResults());
-      });
-    } else {
-      target.addDynamicallyLegalOp<func::FuncOp>(
-          [](func::FuncOp op) { return true; });
-      target.addDynamicallyLegalOp<func::ReturnOp>(
-          [](func::ReturnOp op) { return true; });
-    }
-
-    RewritePatternSet patterns(context);
-    if (convertFunctionBoundaries) {
-      populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-          patterns, typeConverter);
-      populateReturnOpTypeConversionPattern(patterns, typeConverter);
-    }
-    if (aggressiveRewrite) {
-      patterns.add<ConvertGenericOp>(typeConverter, context);
-    } else {
-      // Tensor
-      patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
-      // Data layout
-      patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
-      // Type conversion
-      patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
-      // Controlflow
-      patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
-    }
-
-    if (failed(
-            applyFullConversion(getOperation(), target, std::move(patterns))))
-      signalPassFailure();
-  }
-};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
new file mode 100644
index 0000000000000..0895bb3905d67
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
@@ -0,0 +1,558 @@
+//===- TosaNarrowTypes.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TOSA narrowing passes that rewrite tensor element
+// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "llvm/ADT/APFloat.h"
+
+#include <limits>
+#include <type_traits>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// Selects which element type the shared narrowing code rewrites: i64 or f64.
+enum class TosaNarrowKind { Int64ToInt32, Float64ToFloat32 };
+
+// ---------------------------------------------------------------------------
+// Shared helpers
+// ---------------------------------------------------------------------------
+
+template <TosaNarrowKind Kind> // Kind picks the narrowing mode at compile time.
+static bool isSourceInteger(IntegerType type) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+    return type.isInteger(64); // Only i64 is a narrowing source.
+  return false; // Kinds other than Int64ToInt32 leave all integers alone.
+}
+
+template <TosaNarrowKind Kind> // Kind picks the narrowing mode at compile time.
+static bool isSourceFloat(FloatType type) {
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
+    return type.isF64(); // Only f64 is a narrowing source.
+  return false; // Kinds other than Float64ToFloat32 leave all floats alone.
+}
+
+template <TosaNarrowKind Kind>
+static Type convertInteger(IntegerType type) {
+  if (!isSourceInteger<Kind>(type))
+    return type; // Not a narrowing source: keep the type unchanged.
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+    return IntegerType::get(type.getContext(), 32); // i64 becomes i32.
+  return type; // Fallback for kinds that do not narrow integers.
+}
+
+template <TosaNarrowKind Kind>
+static Type convertFloat(FloatType type) {
+  if (!isSourceFloat<Kind>(type))
+    return type; // Not a narrowing source: keep the type unchanged.
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
+    return Float32Type::get(type.getContext()); // f64 becomes f32.
+  return type; // Fallback for kinds that do not narrow floats.
+}
+
+template <TosaNarrowKind Kind>
+static bool isSourceElement(Type type) { // Dispatches on integer vs. float.
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    return isSourceInteger<Kind>(intTy);
+  if (auto floatTy = dyn_cast<FloatType>(type))
+    return isSourceFloat<Kind>(floatTy);
+  return false; // Neither integer nor float: never a narrowing source.
+}
+
+template <TosaNarrowKind Kind>
+static Type convertElement(Type type) { // Dispatches on integer vs. float.
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    return convertInteger<Kind>(intTy);
+  if (auto floatTy = dyn_cast<FloatType>(type))
+    return convertFloat<Kind>(floatTy);
+  return type; // Neither integer nor float: returned unchanged.
+}
+
+template <TosaNarrowKind Kind>
+static bool typeNeedsConversion(Type type) {
+  if (auto shaped = dyn_cast<ShapedType>(type))
+    return isSourceElement<Kind>(shaped.getElementType()); // Look inside.
+  return isSourceElement<Kind>(type); // Non-shaped: test the type itself.
+}
+
+// Narrows a scalar IntegerAttr/FloatAttr whose type is a source type for Kind.
+// Returns true and sets resultAttr on conversion; otherwise returns false.
+template <TosaNarrowKind Kind>
+static bool tryConvertScalarAttribute(Attribute attribute,
+                                      Attribute &resultAttr) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+      if (const auto intType = dyn_cast<IntegerType>(intAttr.getType());
+          intType && isSourceInteger<Kind>(intType)) {
+        const auto convertedType =
+            cast<IntegerType>(convertInteger<Kind>(intType));
+        const APInt truncated =
+            intAttr.getValue().truncSSat(convertedType.getWidth()); // saturates
+        resultAttr = IntegerAttr::get(convertedType, truncated);
+        return true;
+      }
+    }
+  }
+
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    if (const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
+      if (const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
+          floatType && isSourceFloat<Kind>(floatType)) {
+        const auto convertedType =
+            cast<FloatType>(convertFloat<Kind>(floatType));
+        APFloat value = floatAttr.getValue();
+        bool losesInfo = false; // Written by convert(); not inspected here.
+        value.convert(convertedType.getFloatSemantics(),
+                      APFloat::rmNearestTiesToEven, &losesInfo);
+        resultAttr = FloatAttr::get(convertedType, value);
+        return true;
+      }
+    }
+  }
+
+  return false; // Attribute is not a scalar of a source type for this Kind.
+}
+
+template <TosaNarrowKind Kind, typename AttrT>
+static LogicalResult
+convertAttributeWithTypeConverter(AttrT attr, Type type,
+                                  const TypeConverter *typeConverter,
+                                  Attribute &resultAttr) {
+  if (!typeNeedsConversion<Kind>(type)) { // Nothing to narrow: pass through.
+    resultAttr = attr;
+    return success();
+  }
+
+  const std::optional<Attribute> convertedAttribute =
+      typeConverter->convertTypeAttribute(type, attr);
+  if (!convertedAttribute)
+    return failure(); // The converter produced no attribute for (type, attr).
+
+  resultAttr = convertedAttribute.value();
+  return success();
+}
+
+// Fails the match when the cast narrows (int->narrower int in i64 mode,
+// float->narrower float in f64 mode); every other cast is accepted.
+template <TosaNarrowKind Kind>
+static LogicalResult
+verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
+                               ShapedType resultType,
+                               ConversionPatternRewriter &rewriter) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    const auto elementInputIntType =
+        dyn_cast<IntegerType>(inputType.getElementType());
+    const auto elementResultIntType =
+        dyn_cast<IntegerType>(resultType.getElementType());
+    if (elementInputIntType && elementResultIntType &&
+        elementInputIntType.getWidth() > elementResultIntType.getWidth())
+      return rewriter.notifyMatchFailure(
+          op, "Narrowing cast may lead to data loss.");
+  }
+
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    const auto elementInputFloatType =
+        dyn_cast<FloatType>(inputType.getElementType());
+    const auto elementResultFloatType =
+        dyn_cast<FloatType>(resultType.getElementType());
+    if (elementInputFloatType && elementResultFloatType &&
+        elementInputFloatType.getIntOrFloatBitWidth() >
+            elementResultFloatType.getIntOrFloatBitWidth())
+      return rewriter.notifyMatchFailure(
+          op, "Narrowing cast may lead to data loss.");
+  }
+
+  return success(); // Mixed int/float casts and non-narrowing casts pass.
+}
+
+template <TosaNarrowKind Kind, typename DenseAttrT, typename ElementTypeT,
+          typename MapValueFn>
+static Attribute convertDenseElementsAttr(ShapedType type, DenseAttrT attr,
+                                          const TypeConverter &typeConverter,
+                                          MapValueFn &&mapValueFn) {
+  const auto oldElementType = dyn_cast<ElementTypeT>(type.getElementType());
+  if (!oldElementType)
+    return attr; // Element type is not of the expected class: leave as is.
+
+  if constexpr (std::is_same_v<ElementTypeT, IntegerType>) {
+    if (!isSourceInteger<Kind>(oldElementType))
+      return attr; // Integer elements that this Kind does not narrow.
+  } else {
+    if (!isSourceFloat<Kind>(oldElementType))
+      return attr; // Float elements that this Kind does not narrow.
+  }
+
+  const auto newType =
+      dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
+  if (!newType)
+    return attr; // Conversion failed or gave a non-shaped type.
+
+  const auto newElementType = dyn_cast<ElementTypeT>(newType.getElementType());
+  if (!newElementType)
+    return attr; // Converted element type changed class: leave as is.
+
+  return attr.mapValues(newElementType, [&](const auto &value) {
+    return mapValueFn(newElementType, value); // Per-element narrowing.
+  });
+}
+
+// ---------------------------------------------------------------------------
+// Conversion patterns
+// ---------------------------------------------------------------------------
+
+// Rebuilds `op` under the same name with converted result types, narrowed
+// attributes, and regions moved over and type-converted; replaces the original.
+template <TosaNarrowKind Kind>
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+                               ConversionPatternRewriter &rewriter,
+                               const TypeConverter *typeConverter) {
+  SmallVector<Type, 4> newResults;
+  if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+    return failure();
+
+  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                       newResults, {}, op->getSuccessors());
+
+  // Narrow scalar, type, and dense attributes; copy all others unchanged.
+  for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+    const Attribute attribute = namedAttribute.getValue();
+
+    Attribute convertedAttr;
+    if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr)) {
+      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      continue;
+    }
+
+    if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+      if (failed(convertAttributeWithTypeConverter<Kind>(
+              typeAttr, typeAttr.getValue(), typeConverter, convertedAttr)))
+        return rewriter.notifyMatchFailure(op,
+                                           "Failed to convert type attribute.");
+      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      continue;
+    }
+
+    if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+      if (failed(convertAttributeWithTypeConverter<Kind>(
+              denseElementsAttr, denseElementsAttr.getType(), typeConverter,
+              convertedAttr)))
+        return rewriter.notifyMatchFailure(
+            op, "Failed to convert dense elements attribute.");
+      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      continue;
+    }
+
+    state.addAttribute(namedAttribute.getName(), attribute); // Unchanged.
+  }
+
+  for (Region &region : op->getRegions()) {
+    Region *newRegion = state.addRegion();
+    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+    if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+      return failure();
+  }
+
+  Operation *newOp = rewriter.create(state);
+  rewriter.replaceOp(op, newOp->getResults());
+  return success();
+}
+
+template <TosaNarrowKind Kind> // Pattern that matches any op (MatchAnyOpTypeTag).
+class ConvertGenericOp : public ConversionPattern {
+public:
+  ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!isa<tosa::TosaOp>(op)) // Only TOSA ops are rewritten by this pattern.
+      return rewriter.notifyMatchFailure(
+          op,
+          "Support for operations other than TOSA has not been implemented.");
+
+    return convertGenericOp<Kind>(op, operands, rewriter, typeConverter);
+  }
+};
+
+template <typename OpTy, TosaNarrowKind Kind> // Per-op wrapper, no extra checks.
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
+                                  this->getTypeConverter()); // Shared rewrite.
+  }
+};
+
+// ---------------------------------------------------------------------------
+// Kind-specific helpers and patterns
+// ---------------------------------------------------------------------------
+
+// Converts tosa.cast result types, refusing casts that already narrow data.
+template <TosaNarrowKind Kind>
+class ConvertCastOpWithBoundsChecking
+    : public OpConversionPattern<tosa::CastOp> {
+  using OpConversionPattern<tosa::CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::CastOp op, typename tosa::CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+    const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+    if (!inputType || !resultType)
+      return failure(); // Both sides must be shaped types to be compared.
+
+    const TypeConverter *typeConverter = this->getTypeConverter();
+    if (failed(verifyCastDoesNotLosePrecision<Kind>(op, inputType, resultType,
+                                                    rewriter)))
+      return failure(); // Match failure was already reported by the helper.
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(
+        op, typeConverter->convertType(resultType), adaptor.getInput());
+    return success();
+  }
+};
+
+// ArgMax results index the axis, so narrow only if its static size fits i32.
+class ConvertArgMaxOpWithBoundsChecking
+    : public OpConversionPattern<tosa::ArgMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ArgMaxOp op, typename tosa::ArgMaxOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    const int32_t axis = op.getAxis();
+    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+    if (!inputType || !inputType.isStaticDim(axis))
+      return rewriter.notifyMatchFailure(
+          op, "Requires a static axis dimension for bounds checking.");
+    const int64_t axisDim = inputType.getDimSize(axis);
+    if (axisDim >= std::numeric_limits<int32_t>::max()) // Also rejects == max.
+      return rewriter.notifyMatchFailure(
+          op, "Axis dimension is too large to narrow safely.");
+
+    const Type resultType = op.getOutput().getType();
+    const Type newResultType =
+        this->getTypeConverter()->convertType(resultType);
+    rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
+                                                adaptor.getInput(), axis);
+    return success();
+  }
+};
+
+// Shared driver for both passes; Kind picks the types/attributes to narrow.
+// convertFunctionBoundaries also rewrites func signatures and returns.
+template <TosaNarrowKind Kind>
+LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
+                               bool convertFunctionBoundaries) {
+  MLIRContext *context = op->getContext();
+  // Identity fallback; the narrowing rules registered below override it.
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) -> Type { return type; });
+  // Narrow scalars directly; rebuild tensors only if their element narrows.
+  typeConverter.addConversion(
+      [](IntegerType type) -> Type { return convertInteger<Kind>(type); });
+  typeConverter.addConversion(
+      [](FloatType type) -> Type { return convertFloat<Kind>(type); });
+  typeConverter.addConversion([&typeConverter](RankedTensorType type) -> Type {
+    Type elementType = type.getElementType();
+    if (!isSourceElement<Kind>(elementType))
+      return type;
+    Type converted = typeConverter.convertType(elementType);
+    if (!converted || converted == elementType)
+      return type;
+    return RankedTensorType::get(type.getShape(), converted,
+                                 type.getEncoding());
+  });
+  typeConverter.addConversion(
+      [&typeConverter](UnrankedTensorType type) -> Type {
+        Type elementType = type.getElementType();
+        if (!isSourceElement<Kind>(elementType))
+          return type;
+        Type converted = typeConverter.convertType(elementType);
+        if (!converted || converted == elementType)
+          return type;
+        return UnrankedTensorType::get(converted);
+      });
+  // Bridge old and new types with a tosa.cast (single-input cases only).
+  const auto materializeCast = [](OpBuilder &builder, Type resultType,
+                                  ValueRange inputs, Location loc) -> Value {
+    if (inputs.size() != 1)
+      return Value();
+    return tosa::CastOp::create(builder, loc, resultType, inputs.front());
+  };
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+  // i64 mode: integer attributes are narrowed with signed saturation.
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    typeConverter.addTypeAttributeConversion(
+        [](IntegerType type, IntegerAttr attribute) -> Attribute {
+          Attribute convertedAttr;
+          if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr))
+            return convertedAttr;
+          return attribute;
+        });
+    typeConverter.addTypeAttributeConversion([&typeConverter](
+                                                 ShapedType type,
+                                                 DenseIntElementsAttr attr)
+                                                 -> Attribute {
+      return convertDenseElementsAttr<Kind, DenseIntElementsAttr, IntegerType>(
+          type, attr, typeConverter,
+          [](IntegerType newElementType, const APInt &value) {
+            return value.truncSSat(newElementType.getWidth());
+          });
+    });
+  }
+  // f64 mode: float attributes are rounded to nearest-even; loss is ignored.
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    typeConverter.addTypeAttributeConversion(
+        [](FloatType type, FloatAttr attribute) -> Attribute {
+          Attribute convertedAttr;
+          if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr))
+            return convertedAttr;
+          return attribute;
+        });
+    typeConverter.addTypeAttributeConversion(
+        [&typeConverter](ShapedType type,
+                         DenseFPElementsAttr attr) -> Attribute {
+          return convertDenseElementsAttr<Kind, DenseFPElementsAttr, FloatType>(
+              type, attr, typeConverter,
+              [](FloatType newElementType, const APFloat &value) {
+                APFloat converted(value);
+                bool losesInfo = false;
+                converted.convert(newElementType.getFloatSemantics(),
+                                  APFloat::rmNearestTiesToEven, &losesInfo);
+                return converted.bitcastToAPInt();
+              });
+        });
+  }
+  // A TOSA op is legal once all its operand and result types are narrowed.
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalDialect<tosa::TosaDialect>(
+      [&typeConverter](Operation *op) {
+        return typeConverter.isLegal(op->getResultTypes()) &&
+               typeConverter.isLegal(op->getOperandTypes());
+      });
+  if (convertFunctionBoundaries) {
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [&typeConverter](func::FuncOp op) {
+          return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+                 typeConverter.isLegal(&op.getBody());
+        });
+    target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
+      const FunctionType funcType =
+          op->getParentOfType<func::FuncOp>().getFunctionType();
+      return llvm::equal(op.getOperandTypes(), funcType.getResults());
+    });
+  } else {
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [](func::FuncOp) { return true; });
+    target.addDynamicallyLegalOp<func::ReturnOp>(
+        [](func::ReturnOp) { return true; });
+  }
+  // Patterns: optional func boundary rewrites, then generic or per-op rules.
+  RewritePatternSet patterns(context);
+  if (convertFunctionBoundaries) {
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+  }
+  if (aggressiveRewrite) {
+    patterns.add<ConvertGenericOp<Kind>>(typeConverter, context);
+  } else {
+    if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+      patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::ConstOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::ConcatOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::PadOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::ReshapeOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::ReverseOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::SliceOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::TileOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::TransposeOp, Kind>>(typeConverter,
+                                                          context);
+    patterns.add<ConvertTypedOp<tosa::IdentityOp, Kind>>(typeConverter,
+                                                         context);
+    patterns.add<ConvertCastOpWithBoundsChecking<Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::IfOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::WhileOp, Kind>>(typeConverter, context);
+  }
+  patterns.add<ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
+  // Full conversion: any op left illegal makes the whole call fail.
+  if (failed(applyFullConversion(op, target, std::move(patterns))))
+    return failure();
+  return success();
+}
+
+// ---------------------------------------------------------------------------
+// Pass adapters that forward to the shared implementation
+// ---------------------------------------------------------------------------
+
+struct TosaNarrowI64ToI32
+    : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
+  using Base = tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32>;
+
+  TosaNarrowI64ToI32() = default;
+  // Mirrors the supplied options into the inherited option members.
+  explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) {
+    this->aggressiveRewrite = options.aggressiveRewrite;
+    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+  // Runs the shared driver in Int64ToInt32 mode; failure fails the pass.
+  void runOnOperation() override {
+    if (failed(runTosaNarrowing<TosaNarrowKind::Int64ToInt32>(
+            getOperation(), this->aggressiveRewrite,
+            this->convertFunctionBoundaries)))
+      signalPassFailure();
+  }
+};
+
+struct TosaNarrowF64ToF32
+    : public tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32> {
+  using Base = tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32>;
+
+  TosaNarrowF64ToF32() = default;
+  // Mirrors the supplied options into the inherited option members.
+  explicit TosaNarrowF64ToF32(const TosaNarrowF64ToF32PassOptions &options) {
+    this->aggressiveRewrite = options.aggressiveRewrite;
+    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+  // Runs the shared driver in Float64ToFloat32 mode; failure fails the pass.
+  void runOnOperation() override {
+    if (failed(runTosaNarrowing<TosaNarrowKind::Float64ToFloat32>(
+            getOperation(), this->aggressiveRewrite,
+            this->convertFunctionBoundaries)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir
new file mode 100644
index 0000000000000..69547194dee3f
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// -----
+// Aggressive mode narrows tosa.add; casts appear only at f64 boundaries.
+// CHECK-LABEL: test_f64_add
+// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xf64>, %[[IN1:.*]]: tensor<13x21x3xf64>
+// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xf32>, %[[IN1:.*]]: tensor<13x21x3xf32>
+func.func @test_f64_add(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> {
+  // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xf64>) -> tensor<13x21x1xf32>
+  // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+  // COMMON: %[[ADD:.*]] = tosa.add %{{.*}}, %{{.*}} : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf64>, tensor<13x21x3xf64>) -> tensor<13x21x3xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf64>
+  // DEFAULT: return %[[OUT]] : tensor<13x21x3xf64>
+  // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf64>
+}
+
+// -----
+// Aggressive mode narrows tosa.cond_if and the ops inside both branches.
+// CHECK-LABEL: test_f64_regions
+// DEFAULT: %[[IN0:.*]]: tensor<1xf64>, %[[IN1:.*]]: tensor<1xf64>
+func.func @test_f64_regions(%arg0: tensor<1xf64>, %arg1: tensor<1xf64>, %arg2: tensor<i1>) -> tensor<1xf64> {
+  // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<1xf64>) -> tensor<1xf32>
+  // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<1xf64>) -> tensor<1xf32>
+  // COMMON: %[[IF:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf64> {
+    // COMMON: %[[ADD:.*]] = tosa.add %{{.*}}, %{{.*}} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+    %1 = tosa.add %arg0, %arg1 : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64>
+    tosa.yield %1 : tensor<1xf64>
+  } else {
+    // COMMON: %[[SUB:.*]] = tosa.sub %{{.*}}, %{{.*}} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+    %1 = tosa.sub %arg0, %arg1 : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64>
+    tosa.yield %1 : tensor<1xf64>
+  }
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF]] : (tensor<1xf32>) -> tensor<1xf64>
+  // DEFAULT: return %[[OUT]] : tensor<1xf64>
+  // FUNCBOUND: return %[[IF]] : tensor<1xf32>
+  return %0 : tensor<1xf64>
+}
+
+// -----
+// An f64->f32 tosa.cast degenerates to f32->f32 once its input is narrowed.
+// CHECK-LABEL: test_convert_input_parameters
+// DEFAULT: %[[IN:.*]]: tensor<1x3xf64>
+// FUNCBOUND: %[[IN:.*]]: tensor<1x3xf32>
+func.func @test_convert_input_parameters(%arg0: tensor<1x3xf64>) -> tensor<1x3xf32> {
+  // DEFAULT: %[[CAST_IN:.*]] = tosa.cast %[[IN]] : (tensor<1x3xf64>) -> tensor<1x3xf32>
+  // DEFAULT: %[[IDENTITY:.*]] = tosa.identity %[[CAST_IN]] : (tensor<1x3xf32>) -> tensor<1x3xf32>
+  // FUNCBOUND: %[[IDENTITY:.*]] = tosa.identity %[[IN]] : (tensor<1x3xf32>) -> tensor<1x3xf32>
+  %0 = tosa.identity %arg0 : (tensor<1x3xf64>) -> tensor<1x3xf64>
+  // COMMON: %[[TO_F32:.*]] = tosa.cast %[[IDENTITY]] : (tensor<1x3xf32>) -> tensor<1x3xf32>
+  %1 = tosa.cast %0 : (tensor<1x3xf64>) -> tensor<1x3xf32>
+  // DEFAULT: return %[[TO_F32]] : tensor<1x3xf32>
+  // FUNCBOUND: return %[[TO_F32]] : tensor<1x3xf32>
+  return %1 : tensor<1x3xf32>
+}
+
+// -----
+// The dense f64 payload of tosa.const is rewritten as f32 values.
+// CHECK-LABEL: test_f64_const
+func.func @test_f64_const() -> tensor<2xf64> {
+  // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+  %0 = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}> : () -> tensor<2xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xf32>) -> tensor<2xf64>
+  // DEFAULT: return %[[OUT]] : tensor<2xf64>
+  // FUNCBOUND: return %[[CONST]] : tensor<2xf32>
+  return %0 : tensor<2xf64>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
new file mode 100644
index 0000000000000..0588adaf20c6b
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
@@ -0,0 +1,160 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// -----
+// Chained identities narrow together; casts appear only at f64 boundaries.
+// CHECK-LABEL: test_f64_identity_chain
+func.func @test_f64_identity_chain(%arg0: tensor<1xf64>) -> tensor<1xf64> {
+  // DEFAULT: %[[CAST_IN:.*]] = tosa.cast %arg0 : (tensor<1xf64>) -> tensor<1xf32>
+  // DEFAULT: %[[ID1:.*]] = tosa.identity %[[CAST_IN]] : (tensor<1xf32>) -> tensor<1xf32>
+  // FUNCBOUND: %[[ID1:.*]] = tosa.identity %arg0 : (tensor<1xf32>) -> tensor<1xf32>
+  %0 = tosa.identity %arg0 : (tensor<1xf64>) -> tensor<1xf64>
+  // COMMON: %[[ID2:.*]] = tosa.identity %[[ID1]] : (tensor<1xf32>) -> tensor<1xf32>
+  %1 = tosa.identity %0 : (tensor<1xf64>) -> tensor<1xf64>
+  // DEFAULT: %[[CAST_OUT:.*]] = tosa.cast %[[ID2]] : (tensor<1xf32>) -> tensor<1xf64>
+  // DEFAULT: return %[[CAST_OUT]] : tensor<1xf64>
+  // FUNCBOUND: return %[[ID2]] : tensor<1xf32>
+  return %1 : tensor<1xf64>
+}
+
+// -----
+// A standalone tosa.const is narrowed, including its dense f64 payload.
+// CHECK-LABEL: test_f64_const
+func.func @test_f64_const() -> tensor<2xf64> {
+  // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+  %0 = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}> : () -> tensor<2xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xf32>) -> tensor<2xf64>
+  // DEFAULT: return %[[OUT]] : tensor<2xf64>
+  // FUNCBOUND: return %[[CONST]] : tensor<2xf32>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+// tosa.concat narrows both operands and its result to f32.
+// CHECK-LABEL: test_f64_concat
+// DEFAULT: %[[A0:.*]]: tensor<13x21x3xf64>, %[[A1:.*]]: tensor<13x21x3xf64>
+// FUNCBOUND: %[[A0:.*]]: tensor<13x21x3xf32>, %[[A1:.*]]: tensor<13x21x3xf32>
+func.func @test_f64_concat(%arg0: tensor<13x21x3xf64>, %arg1: tensor<13x21x3xf64>) -> tensor<26x21x3xf64> {
+  // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[A0]] : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+  // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[A1]] : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+  // COMMON: %[[CONCAT:.*]] = tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf64>, tensor<13x21x3xf64>) -> tensor<26x21x3xf64>
+  // DEFAULT: %[[CAST_OUT:.*]] = tosa.cast %[[CONCAT]] : (tensor<26x21x3xf32>) -> tensor<26x21x3xf64>
+  // DEFAULT: return %[[CAST_OUT]] : tensor<26x21x3xf64>
+  // FUNCBOUND: return %[[CONCAT]] : tensor<26x21x3xf32>
+  return %0 : tensor<26x21x3xf64>
+}
+
+// -----
+// tosa.pad narrows the input and pad value; the shape operand is untouched.
+// CHECK-LABEL: test_f64_pad
+func.func @test_f64_pad(%arg0: tensor<13x21x3xf64>, %arg1: tensor<1xf64>) -> tensor<15x23x5xf64> {
+  %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // DEFAULT-DAG: %[[IN_CAST:.*]] = tosa.cast %arg0 : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+  // DEFAULT-DAG: %[[PAD_CAST:.*]] = tosa.cast %arg1 : (tensor<1xf64>) -> tensor<1xf32>
+  // COMMON: %[[PAD:.*]] = tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<15x23x5xf32>
+  %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xf64>, !tosa.shape<6>, tensor<1xf64>) -> tensor<15x23x5xf64>
+  // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[PAD]] : (tensor<15x23x5xf32>) -> tensor<15x23x5xf64>
+  // DEFAULT: return %[[OUT_CAST]] : tensor<15x23x5xf64>
+  // FUNCBOUND: return %[[PAD]] : tensor<15x23x5xf32>
+  return %1 : tensor<15x23x5xf64>
+}
+
+// -----
+// tosa.reshape narrows its tensor operand and result; the shape is kept.
+// CHECK-LABEL: test_f64_reshape
+func.func @test_f64_reshape(%arg0: tensor<13x21x3xf64>) -> tensor<1x819xf64> {
+  %shape = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // COMMON: %[[RESHAPE:.*]] = tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+  %0 = tosa.reshape %arg0, %shape : (tensor<13x21x3xf64>, !tosa.shape<2>) -> tensor<1x819xf64>
+  // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[RESHAPE]] : (tensor<1x819xf32>) -> tensor<1x819xf64>
+  // DEFAULT: return %[[OUT_CAST]] : tensor<1x819xf64>
+  // FUNCBOUND: return %[[RESHAPE]] : tensor<1x819xf32>
+  return %0 : tensor<1x819xf64>
+}
+
+// -----
+// tosa.reverse narrows to f32 and keeps its axis attribute.
+// CHECK-LABEL: test_f64_reverse
+func.func @test_f64_reverse(%arg0: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> {
+  // COMMON: %[[REV:.*]] = tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf64>) -> tensor<13x21x3xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[REV]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf64>
+  // DEFAULT: return %[[OUT]] : tensor<13x21x3xf64>
+  // FUNCBOUND: return %[[REV]] : tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf64>
+}
+
+// -----
+// tosa.slice narrows its tensor operand; both shape operands are unchanged.
+// CHECK-LABEL: test_f64_slice
+func.func @test_f64_slice(%arg0: tensor<13x21x3xf64>) -> tensor<4x11x1xf64> {
+  %size = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %start = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // COMMON: %[[SLICE:.*]] = tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32>
+  %0 = tosa.slice %arg0, %size, %start : (tensor<13x21x3xf64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[SLICE]] : (tensor<4x11x1xf32>) -> tensor<4x11x1xf64>
+  // DEFAULT: return %[[OUT]] : tensor<4x11x1xf64>
+  // FUNCBOUND: return %[[SLICE]] : tensor<4x11x1xf32>
+  return %0 : tensor<4x11x1xf64>
+}
+
+// -----
+// tosa.tile narrows its tensor operand and result; the shape operand is kept.
+// CHECK-LABEL: test_f64_tile
+func.func @test_f64_tile(%arg0: tensor<13x21x3xf64>) -> tensor<39x21x6xf64> {
+  %multipliers = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+  // COMMON: %[[TILE:.*]] = tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
+  %0 = tosa.tile %arg0, %multipliers : (tensor<13x21x3xf64>, !tosa.shape<3>) -> tensor<39x21x6xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[TILE]] : (tensor<39x21x6xf32>) -> tensor<39x21x6xf64>
+  // DEFAULT: return %[[OUT]] : tensor<39x21x6xf64>
+  // FUNCBOUND: return %[[TILE]] : tensor<39x21x6xf32>
+  return %0 : tensor<39x21x6xf64>
+}
+
+// -----
+// tosa.transpose narrows to f32 and keeps its perms attribute.
+// CHECK-LABEL: test_f64_transpose
+func.func @test_f64_transpose(%arg0: tensor<13x21x3xf64>) -> tensor<3x13x21xf64> {
+  // COMMON: %[[TRANSPOSE:.*]] = tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
+  %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf64>) -> tensor<3x13x21xf64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[TRANSPOSE]] : (tensor<3x13x21xf32>) -> tensor<3x13x21xf64>
+  // DEFAULT: return %[[OUT]] : tensor<3x13x21xf64>
+  // FUNCBOUND: return %[[TRANSPOSE]] : tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf64>
+}
+
+// -----
+// tosa.cond_if, the ops in both regions and their yields all narrow to f32.
+module {
+// CHECK-LABEL: test_f64_regions
+func.func @test_f64_regions(%arg0: tensor<1xf64>, %arg1: tensor<1xf64>, %arg2: tensor<i1>) -> tensor<1xf64> {
+  // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf32>
+  %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf64> {
+    // COMMON: %[[ID0:.*]] = tosa.identity %{{.*}} : (tensor<1xf32>) -> tensor<1xf32>
+    %1 = tosa.identity %arg0 : (tensor<1xf64>) -> tensor<1xf64>
+    // COMMON: tosa.yield %[[ID0]] : tensor<1xf32>
+    tosa.yield %1 : tensor<1xf64>
+  } else {
+    // COMMON: %[[ID1:.*]] = tosa.identity %{{.*}} : (tensor<1xf32>) -> tensor<1xf32>
+    %1 = tosa.identity %arg1 : (tensor<1xf64>) -> tensor<1xf64>
+    // COMMON: tosa.yield %[[ID1]] : tensor<1xf32>
+    tosa.yield %1 : tensor<1xf64>
+  }
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<1xf32>) -> tensor<1xf64>
+  // DEFAULT: return %[[OUT]] : tensor<1xf64>
+  // FUNCBOUND: return %[[IF_RESULT]] : tensor<1xf32>
+  return %0 : tensor<1xf64>
+}
+}
+
+// -----
+// Without aggressive mode tosa.add has no pattern, so legalization fails.
+module {
+// CHECK-LABEL: test_f64_add_diagnostic
+func.func @test_f64_add_diagnostic(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> {
+  // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+  %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf64>, tensor<13x21x3xf64>) -> tensor<13x21x3xf64>
+  return %0 : tensor<13x21x3xf64>
+}
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
index a14483fcdd7b0..fcaf53b0cc15c 100644
--- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
@@ -15,6 +15,18 @@ func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi
 
 // -----
 
+// CHECK-LABEL: test_i64_const
+func.func @test_i64_const() -> tensor<2xi64> { // standalone i64 constant
+  // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+  // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64>
+  // DEFAULT: return %[[OUT]] : tensor<2xi64>
+  // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
+  return %0 : tensor<2xi64>
+}
+
+// -----
+
 // CHECK-LABEL: test_i64_argmax_cast
 func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> {
   // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>

>From f60d17f1bea59a6dddd18f423b69f7636c32a5b8 Mon Sep 17 00:00:00 2001
From: Vitalii Shutov <vitalii.shutov at arm.com>
Date: Thu, 4 Dec 2025 17:38:50 +0000
Subject: [PATCH 2/3] [mlir][tosa] Extend narrowing pass

Co-authored-by: Luke Hutton <Luke.Hutton at arm.com>
Change-Id: I6bf220a821c45137fe0b447b6ddbfb83e9f63a39
---
 .../Tosa/Transforms/TosaNarrowTypes.cpp       | 377 ++++++++++++------
 .../tosa-narrow-i64-to-i32-aggressive.mlir    |   9 +
 .../Dialect/Tosa/tosa-narrow-i64-to-i32.mlir  |   9 +
 3 files changed, 273 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
index 0895bb3905d67..d9651f7321269 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
@@ -45,21 +45,21 @@ enum class TosaNarrowKind { Int64ToInt32, Float64ToFloat32 };
 // ---------------------------------------------------------------------------
 
 template <TosaNarrowKind Kind>
-static bool isSourceInteger(IntegerType type) {
+bool isSourceInteger(IntegerType type) {
   if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
     return type.isInteger(64);
   return false;
 }
 
 template <TosaNarrowKind Kind>
-static bool isSourceFloat(FloatType type) {
+bool isSourceFloat(FloatType type) {
   if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
     return type.isF64();
   return false;
 }
 
 template <TosaNarrowKind Kind>
-static Type convertInteger(IntegerType type) {
+Type convertInteger(IntegerType type) {
   if (!isSourceInteger<Kind>(type))
     return type;
   if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
@@ -68,7 +68,7 @@ static Type convertInteger(IntegerType type) {
 }
 
 template <TosaNarrowKind Kind>
-static Type convertFloat(FloatType type) {
+Type convertFloat(FloatType type) {
   if (!isSourceFloat<Kind>(type))
     return type;
   if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
@@ -77,7 +77,7 @@ static Type convertFloat(FloatType type) {
 }
 
 template <TosaNarrowKind Kind>
-static bool isSourceElement(Type type) {
+bool isSourceElement(Type type) {
   if (auto intTy = dyn_cast<IntegerType>(type))
     return isSourceInteger<Kind>(intTy);
   if (auto floatTy = dyn_cast<FloatType>(type))
@@ -86,7 +86,7 @@ static bool isSourceElement(Type type) {
 }
 
 template <TosaNarrowKind Kind>
-static Type convertElement(Type type) {
+Type convertElement(Type type) {
   if (auto intTy = dyn_cast<IntegerType>(type))
     return convertInteger<Kind>(intTy);
   if (auto floatTy = dyn_cast<FloatType>(type))
@@ -95,73 +95,168 @@ static Type convertElement(Type type) {
 }
 
 template <TosaNarrowKind Kind>
-static bool typeNeedsConversion(Type type) {
+bool typeNeedsConversion(Type type) {
   if (auto shaped = dyn_cast<ShapedType>(type))
     return isSourceElement<Kind>(shaped.getElementType());
   return isSourceElement<Kind>(type);
 }
 
+FailureOr<APInt> convertIntegerConstant(IntegerType targetType,
+                                        const APInt &value,
+                                        bool allowLossyConversion) {
+  const unsigned targetWidth = targetType.getWidth();
+  if (!allowLossyConversion && !value.isSignedIntN(targetWidth))
+    return failure();
+  // Saturate in lossy mode; otherwise the range check above keeps it exact.
+  if (allowLossyConversion)
+    return value.truncSSat(targetWidth);
+  return value.sextOrTrunc(targetWidth);
+}
+
+FailureOr<APFloat> convertFloatConstant(FloatType targetType,
+                                        const APFloat &value,
+                                        bool allowLossyConversion) {
+  APFloat converted(value);
+  bool losesInfo = false; // Set by convert() when the result is inexact.
+  converted.convert(targetType.getFloatSemantics(),
+                    APFloat::rmNearestTiesToEven, &losesInfo);
+  if (!allowLossyConversion && losesInfo)
+    return failure();
+  return converted;
+}
+
 // Narrows scalar constant attributes so they keep matching the converted
 // element types.
 template <TosaNarrowKind Kind>
-static bool tryConvertScalarAttribute(Attribute attribute,
-                                      Attribute &resultAttr) {
+FailureOr<Attribute> tryConvertScalarAttribute(Attribute attribute,
+                                               bool allowLossyConversion) {
   if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
     if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
       if (const auto intType = dyn_cast<IntegerType>(intAttr.getType());
           intType && isSourceInteger<Kind>(intType)) {
         const auto convertedType =
             cast<IntegerType>(convertInteger<Kind>(intType));
-        const APInt truncated =
-            intAttr.getValue().truncSSat(convertedType.getWidth());
-        resultAttr = IntegerAttr::get(convertedType, truncated);
-        return true;
+        FailureOr<APInt> convertedValue = convertIntegerConstant(
+            convertedType, intAttr.getValue(), allowLossyConversion);
+        if (failed(convertedValue))
+          return failure();
+        return IntegerAttr::get(convertedType, convertedValue.value());
       }
     }
-  }
-
-  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+  } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
     if (const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
       if (const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
           floatType && isSourceFloat<Kind>(floatType)) {
         const auto convertedType =
             cast<FloatType>(convertFloat<Kind>(floatType));
-        APFloat value = floatAttr.getValue();
-        bool losesInfo = false;
-        value.convert(convertedType.getFloatSemantics(),
-                      APFloat::rmNearestTiesToEven, &losesInfo);
-        resultAttr = FloatAttr::get(convertedType, value);
-        return true;
+        FailureOr<APFloat> convertedValue = convertFloatConstant(
+            convertedType, floatAttr.getValue(), allowLossyConversion);
+        if (failed(convertedValue))
+          return failure();
+        return FloatAttr::get(convertedType, convertedValue.value());
       }
     }
   }
 
-  return false;
+  return attribute;
+}
+
+template <TosaNarrowKind Kind>
+FailureOr<Attribute>
+convertDenseIntElementsAttr(ShapedType type, DenseIntElementsAttr attr,
+                            const TypeConverter &typeConverter,
+                            bool allowLossyConversion) {
+  if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
+    return attr;
+  // Attributes whose element type is not a narrowing source pass through.
+  const auto oldElementType = dyn_cast<IntegerType>(type.getElementType());
+  if (!oldElementType || !isSourceInteger<Kind>(oldElementType))
+    return attr;
+  // A missing or non-integer converted type is reported as a failure.
+  const auto newType =
+      dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
+  if (!newType)
+    return failure();
+
+  const auto newElementType = dyn_cast<IntegerType>(newType.getElementType());
+  if (!newElementType)
+    return failure();
+  // Lossless mode: reject the whole attribute if any element is out of range.
+  if (!allowLossyConversion) {
+    for (APInt value : attr.getValues<APInt>())
+      if (failed(convertIntegerConstant(newElementType, value,
+                                        /*allowLossyConversion=*/false)))
+        return failure();
+  }
+  // Every element fits (or loss is allowed), so saturating here is safe.
+  Attribute convertedAttr =
+      attr.mapValues(newElementType, [&](const APInt &value) -> APInt {
+        return convertIntegerConstant(newElementType, value,
+                                      /*allowLossyConversion=*/true)
+            .value();
+      });
+  return convertedAttr;
+}
+
+template <TosaNarrowKind Kind>
+FailureOr<Attribute>
+convertDenseFPElementsAttr(ShapedType type, DenseFPElementsAttr attr,
+                           const TypeConverter &typeConverter,
+                           bool allowLossyConversion) {
+  if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
+    return attr;
+  // Attributes whose element type is not a narrowing source pass through.
+  const auto oldElementType = dyn_cast<FloatType>(type.getElementType());
+  if (!oldElementType || !isSourceFloat<Kind>(oldElementType))
+    return attr;
+  // A missing or non-float converted type is reported as a failure.
+  const auto newType =
+      dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
+  if (!newType)
+    return failure();
+
+  const auto newElementType = dyn_cast<FloatType>(newType.getElementType());
+  if (!newElementType)
+    return failure();
+  // Lossless mode: reject the whole attribute if any element is inexact.
+  if (!allowLossyConversion) {
+    for (APFloat value : attr.getValues<APFloat>())
+      if (failed(convertFloatConstant(newElementType, value,
+                                      /*allowLossyConversion=*/false)))
+        return failure();
+  }
+  // Rounding cannot fail here: loss was either ruled out or is allowed.
+  Attribute convertedAttr =
+      attr.mapValues(newElementType, [&](const APFloat &value) -> APInt {
+        APFloat converted = convertFloatConstant(newElementType, value,
+                                                 /*allowLossyConversion=*/true)
+                                .value();
+        // DenseFPElementsAttr stores each float as raw bits, so emit the APInt
+        // representation that MLIR expects in the underlying buffer.
+        return converted.bitcastToAPInt();
+      });
+  return convertedAttr;
+}
 
 template <TosaNarrowKind Kind, typename AttrT>
-static LogicalResult
+FailureOr<Attribute>
 convertAttributeWithTypeConverter(AttrT attr, Type type,
-                                  const TypeConverter *typeConverter,
-                                  Attribute &resultAttr) {
-  if (!typeNeedsConversion<Kind>(type)) {
-    resultAttr = attr;
-    return success();
-  }
+                                  const TypeConverter *typeConverter) {
+  if (!typeNeedsConversion<Kind>(type))
+    return attr;
 
   const std::optional<Attribute> convertedAttribute =
       typeConverter->convertTypeAttribute(type, attr);
   if (!convertedAttribute)
     return failure();
 
-  resultAttr = convertedAttribute.value();
-  return success();
+  return convertedAttribute.value();
 }
 
 // Rejects cast rewrites that would lose precision (unless aggressive mode is
 // enabled).
 template <TosaNarrowKind Kind>
-static LogicalResult
+LogicalResult
 verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
                                ShapedType resultType,
                                ConversionPatternRewriter &rewriter) {
@@ -174,9 +269,7 @@ verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
         elementInputIntType.getWidth() > elementResultIntType.getWidth())
       return rewriter.notifyMatchFailure(
           op, "Narrowing cast may lead to data loss.");
-  }
-
-  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+  } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
     const auto elementInputFloatType =
         dyn_cast<FloatType>(inputType.getElementType());
     const auto elementResultFloatType =
@@ -191,37 +284,6 @@ verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
   return success();
 }
 
-template <TosaNarrowKind Kind, typename DenseAttrT, typename ElementTypeT,
-          typename MapValueFn>
-static Attribute convertDenseElementsAttr(ShapedType type, DenseAttrT attr,
-                                          const TypeConverter &typeConverter,
-                                          MapValueFn &&mapValueFn) {
-  const auto oldElementType = dyn_cast<ElementTypeT>(type.getElementType());
-  if (!oldElementType)
-    return attr;
-
-  if constexpr (std::is_same_v<ElementTypeT, IntegerType>) {
-    if (!isSourceInteger<Kind>(oldElementType))
-      return attr;
-  } else {
-    if (!isSourceFloat<Kind>(oldElementType))
-      return attr;
-  }
-
-  const auto newType =
-      dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
-  if (!newType)
-    return attr;
-
-  const auto newElementType = dyn_cast<ElementTypeT>(newType.getElementType());
-  if (!newElementType)
-    return attr;
-
-  return attr.mapValues(newElementType, [&](const auto &value) {
-    return mapValueFn(newElementType, value);
-  });
-}
-
 // ---------------------------------------------------------------------------
 // Conversion patterns
 // ---------------------------------------------------------------------------
@@ -231,7 +293,8 @@ static Attribute convertDenseElementsAttr(ShapedType type, DenseAttrT attr,
 template <TosaNarrowKind Kind>
 LogicalResult convertGenericOp(Operation *op, ValueRange operands,
                                ConversionPatternRewriter &rewriter,
-                               const TypeConverter *typeConverter) {
+                               const TypeConverter *typeConverter,
+                               bool allowLossyConversion) {
   SmallVector<Type, 4> newResults;
   if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
     return failure();
@@ -243,28 +306,37 @@ LogicalResult convertGenericOp(Operation *op, ValueRange operands,
   for (const NamedAttribute &namedAttribute : op->getAttrs()) {
     const Attribute attribute = namedAttribute.getValue();
 
-    Attribute convertedAttr;
-    if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr)) {
-      state.addAttribute(namedAttribute.getName(), convertedAttr);
+    if (isa<IntegerAttr>(attribute) || isa<FloatAttr>(attribute)) {
+      FailureOr<Attribute> convertedAttr =
+          tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
+      if (failed(convertedAttr))
+        return rewriter.notifyMatchFailure(
+            op, "Scalar attribute narrowing would lose precision; enable "
+                "aggressive rewrite to override.");
+      state.addAttribute(namedAttribute.getName(), convertedAttr.value());
       continue;
     }
 
     if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
-      if (failed(convertAttributeWithTypeConverter<Kind>(
-              typeAttr, typeAttr.getValue(), typeConverter, convertedAttr)))
+      FailureOr<Attribute> convertedAttr =
+          convertAttributeWithTypeConverter<Kind>(typeAttr, typeAttr.getValue(),
+                                                  typeConverter);
+      if (failed(convertedAttr))
         return rewriter.notifyMatchFailure(op,
                                            "Failed to convert type attribute.");
-      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      state.addAttribute(namedAttribute.getName(), convertedAttr.value());
       continue;
     }
 
     if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
-      if (failed(convertAttributeWithTypeConverter<Kind>(
-              denseElementsAttr, denseElementsAttr.getType(), typeConverter,
-              convertedAttr)))
+      FailureOr<Attribute> convertedAttr =
+          convertAttributeWithTypeConverter<Kind>(
+              denseElementsAttr, denseElementsAttr.getType(), typeConverter);
+      if (failed(convertedAttr))
         return rewriter.notifyMatchFailure(
-            op, "Failed to convert dense elements attribute.");
-      state.addAttribute(namedAttribute.getName(), convertedAttr);
+            op, "Failed to convert dense elements attribute without precision "
+                "loss; enable aggressive rewrite to override.");
+      state.addAttribute(namedAttribute.getName(), convertedAttr.value());
       continue;
     }
 
@@ -286,8 +358,10 @@ LogicalResult convertGenericOp(Operation *op, ValueRange operands,
 template <TosaNarrowKind Kind>
 class ConvertGenericOp : public ConversionPattern {
 public:
-  ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+  ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context,
+                   bool allowLossyConversion)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context),
+        allowLossyConversion(allowLossyConversion) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -297,19 +371,26 @@ class ConvertGenericOp : public ConversionPattern {
           op,
           "Support for operations other than TOSA has not been implemented.");
 
-    return convertGenericOp<Kind>(op, operands, rewriter, typeConverter);
+    return convertGenericOp<Kind>(op, operands, rewriter, typeConverter,
+                                  allowLossyConversion);
   }
+
+private:
+  const bool allowLossyConversion;
 };
 
 template <typename OpTy, TosaNarrowKind Kind>
 class ConvertTypedOp : public OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
+public:
+  ConvertTypedOp(TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<OpTy>(typeConverter, context) {}
 
   LogicalResult
   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
-                                  this->getTypeConverter());
+                                  this->getTypeConverter(),
+                                  /*allowLossyConversion=*/false);
   }
 };
 
@@ -317,7 +398,7 @@ class ConvertTypedOp : public OpConversionPattern<OpTy> {
 // Kind-specific helpers and patterns
 // ---------------------------------------------------------------------------
 
-// Casts get extra checking so we only narrow when it is provably safe.
+// Casts get extra checking so we only narrow when it is provably safe.
 template <TosaNarrowKind Kind>
 class ConvertCastOpWithBoundsChecking
     : public OpConversionPattern<tosa::CastOp> {
@@ -369,12 +450,57 @@ class ConvertArgMaxOpWithBoundsChecking
   }
 };
 
+template <TosaNarrowKind Kind>
+class ConvertClampOpWithBoundsChecking
+    : public OpConversionPattern<tosa::ClampOp> {
+  static_assert(Kind == TosaNarrowKind::Int64ToInt32,
+                "Clamp bounds checking only supported for integer narrowing");
+  using OpConversionPattern<tosa::ClampOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ClampOp op, typename tosa::ClampOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
+    auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
+    if (!minAttr || !maxAttr)
+      return rewriter.notifyMatchFailure(
+          op, "Clamp attributes must be integer constants.");
+
+    const int64_t min = minAttr.getInt();
+    const int64_t max = maxAttr.getInt();
+    if (min < std::numeric_limits<int32_t>::min() ||
+        max > std::numeric_limits<int32_t>::max())
+      return rewriter.notifyMatchFailure(
+          op, "Clamp bounds exceed int32 range. Narrowing may lose data.");
+
+    const Type resultType = op.getOutput().getType();
+    const Type newResultType =
+        this->getTypeConverter()->convertType(resultType);
+    const auto newResultShaped = dyn_cast<ShapedType>(newResultType);
+    if (!newResultShaped)
+      return failure();
+    const auto newElementType =
+        dyn_cast<IntegerType>(newResultShaped.getElementType());
+    if (!newElementType)
+      return failure();
+
+    const IntegerAttr newMinAttr = IntegerAttr::get(newElementType, min);
+    const IntegerAttr newMaxAttr = IntegerAttr::get(newElementType, max);
+
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
+                                               adaptor.getInput(), newMinAttr,
+                                               newMaxAttr, op.getNanModeAttr());
+    return success();
+  }
+};
+
 // Shared implementation for both narrowing passes; the mode decides which
 // element types and attribute payloads participate.
 template <TosaNarrowKind Kind>
 LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
                                bool convertFunctionBoundaries) {
   MLIRContext *context = op->getContext();
+  const bool allowLossyConversion = aggressiveRewrite;
 
   TypeConverter typeConverter;
   typeConverter.addConversion([](Type type) -> Type { return type; });
@@ -415,44 +541,47 @@ LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
 
   if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
     typeConverter.addTypeAttributeConversion(
-        [](IntegerType type, IntegerAttr attribute) -> Attribute {
-          Attribute convertedAttr;
-          if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr))
-            return convertedAttr;
-          return attribute;
+        [allowLossyConversion](IntegerType /*type*/, IntegerAttr attribute)
+            -> TypeConverter::AttributeConversionResult {
+          FailureOr<Attribute> converted =
+              tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
+          if (failed(converted))
+            return TypeConverter::AttributeConversionResult::abort();
+          return TypeConverter::AttributeConversionResult::result(
+              converted.value());
         });
-    typeConverter.addTypeAttributeConversion([&typeConverter](
-                                                 ShapedType type,
-                                                 DenseIntElementsAttr attr)
-                                                 -> Attribute {
-      return convertDenseElementsAttr<Kind, DenseIntElementsAttr, IntegerType>(
-          type, attr, typeConverter,
-          [](IntegerType newElementType, const APInt &value) {
-            return value.truncSSat(newElementType.getWidth());
-          });
-    });
-  }
-
-  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
     typeConverter.addTypeAttributeConversion(
-        [](FloatType type, FloatAttr attribute) -> Attribute {
-          Attribute convertedAttr;
-          if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr))
-            return convertedAttr;
-          return attribute;
+        [&typeConverter, allowLossyConversion](ShapedType type,
+                                               DenseIntElementsAttr attr)
+            -> TypeConverter::AttributeConversionResult {
+          FailureOr<Attribute> converted = convertDenseIntElementsAttr<Kind>(
+              type, attr, typeConverter, allowLossyConversion);
+          if (failed(converted))
+            return TypeConverter::AttributeConversionResult::abort();
+          return TypeConverter::AttributeConversionResult::result(
+              converted.value());
         });
+  } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
     typeConverter.addTypeAttributeConversion(
-        [&typeConverter](ShapedType type,
-                         DenseFPElementsAttr attr) -> Attribute {
-          return convertDenseElementsAttr<Kind, DenseFPElementsAttr, FloatType>(
-              type, attr, typeConverter,
-              [](FloatType newElementType, const APFloat &value) {
-                APFloat converted(value);
-                bool losesInfo = false;
-                converted.convert(newElementType.getFloatSemantics(),
-                                  APFloat::rmNearestTiesToEven, &losesInfo);
-                return converted.bitcastToAPInt();
-              });
+        [allowLossyConversion](FloatType /*type*/, FloatAttr attribute)
+            -> TypeConverter::AttributeConversionResult {
+          FailureOr<Attribute> converted =
+              tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
+          if (failed(converted))
+            return TypeConverter::AttributeConversionResult::abort();
+          return TypeConverter::AttributeConversionResult::result(
+              converted.value());
+        });
+    typeConverter.addTypeAttributeConversion(
+        [&typeConverter, allowLossyConversion](ShapedType type,
+                                               DenseFPElementsAttr attr)
+            -> TypeConverter::AttributeConversionResult {
+          FailureOr<Attribute> converted = convertDenseFPElementsAttr<Kind>(
+              type, attr, typeConverter, allowLossyConversion);
+          if (failed(converted))
+            return TypeConverter::AttributeConversionResult::abort();
+          return TypeConverter::AttributeConversionResult::result(
+              converted.value());
         });
   }
 
@@ -487,10 +616,14 @@ LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
     populateReturnOpTypeConversionPattern(patterns, typeConverter);
   }
   if (aggressiveRewrite) {
-    patterns.add<ConvertGenericOp<Kind>>(typeConverter, context);
+    patterns.add<ConvertGenericOp<Kind>>(typeConverter, context,
+                                         allowLossyConversion);
   } else {
-    if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+    if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
       patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
+      patterns.add<ConvertClampOpWithBoundsChecking<Kind>>(typeConverter,
+                                                           context);
+    }
     patterns.add<ConvertTypedOp<tosa::ConstOp, Kind>>(typeConverter, context);
     patterns.add<ConvertTypedOp<tosa::ConcatOp, Kind>>(typeConverter, context);
     patterns.add<ConvertTypedOp<tosa::PadOp, Kind>>(typeConverter, context);
@@ -505,8 +638,8 @@ LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
     patterns.add<ConvertCastOpWithBoundsChecking<Kind>>(typeConverter, context);
     patterns.add<ConvertTypedOp<tosa::IfOp, Kind>>(typeConverter, context);
     patterns.add<ConvertTypedOp<tosa::WhileOp, Kind>>(typeConverter, context);
+    patterns.add<ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
   }
-  patterns.add<ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
 
   if (failed(applyFullConversion(op, target, std::move(patterns))))
     return failure();
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
index 1a36177a37033..9848fe4abb345 100644
--- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
@@ -79,3 +79,12 @@ func.func @test_const() -> tensor<2xi64> {
   // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
   return %0 : tensor<2xi64>
 }
+
+// -----
+
+// CHECK-LABEL: test_clamp_trunc
+func.func @test_clamp_trunc(%arg0: tensor<100xi64>) -> tensor<100xi64> {
+  // COMMON: tosa.clamp %{{.*}} {max_val = 2147483647 : i32, min_val = -2147483648 : i32} : (tensor<100xi32>) -> tensor<100xi32>
+  %1 = tosa.clamp %arg0 {max_val = 3000000000 : i64, min_val = -2147483648 : i64} : (tensor<100xi64>) -> tensor<100xi64>
+  return %1 : tensor<100xi64>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
index fcaf53b0cc15c..809d10094be81 100644
--- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
@@ -172,3 +172,12 @@ func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> {
   // COMMON: return %[[OUT_CAST]] : tensor<1xi32>
   return %2 : tensor<1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_clamp
+func.func @test_clamp(%arg0: tensor<100xi64>) -> tensor<100xi64> {
+  // COMMON: tosa.clamp %{{.*}} {max_val = 2147483647 : i32, min_val = -2147483648 : i32} : (tensor<100xi32>) -> tensor<100xi32>
+  %1 = tosa.clamp %arg0 {max_val = 2147483647 : i64, min_val = -2147483648 : i64} : (tensor<100xi64>) -> tensor<100xi64>
+  return %1 : tensor<100xi64>
+}

>From fd6b302adec3576261aec501e6dc78168773bd4e Mon Sep 17 00:00:00 2001
From: Vitalii Shutov <vitalii.shutov at arm.com>
Date: Mon, 15 Dec 2025 10:47:58 +0000
Subject: [PATCH 3/3] add more lit

Change-Id: Ia4ae84838da0a6d4fc6ddadad5b13aabec891de0
---
 .../Dialect/Tosa/tosa-narrow-f64-to-f32.mlir  | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
index 0588adaf20c6b..1034ee67f65e2 100644
--- a/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
@@ -31,6 +31,26 @@ func.func @test_f64_const() -> tensor<2xf64> {
 
 // -----
 
+// CHECK-LABEL: test_f64_const_precision_loss
+func.func @test_f64_const_precision_loss() -> tensor<1xf64> {
+  // expected-error @+2 {{failed to legalize operation 'tosa.const'}}
+  // 2^24 + 1 fits in f64 but rounds to 2^24 in f32.
+  %0 = "tosa.const"() <{values = dense<16777217.0> : tensor<1xf64>}> : () -> tensor<1xf64>
+  return %0 : tensor<1xf64>
+}
+
+// -----
+
+// CHECK-LABEL: test_f64_const_precision_loss_small
+func.func @test_f64_const_precision_loss_small() -> tensor<1xf64> {
+  // expected-error @+2 {{failed to legalize operation 'tosa.const'}}
+  // Too small: underflows to zero when narrowed to f32.
+  %0 = "tosa.const"() <{values = dense<1.0e-46> : tensor<1xf64>}> : () -> tensor<1xf64>
+  return %0 : tensor<1xf64>
+}
+
+// -----
+
 // CHECK-LABEL: test_f64_concat
 // DEFAULT: %[[A0:.*]]: tensor<13x21x3xf64>, %[[A1:.*]]: tensor<13x21x3xf64>
 // FUNCBOUND: %[[A0:.*]]: tensor<13x21x3xf32>, %[[A1:.*]]: tensor<13x21x3xf32>



More information about the Mlir-commits mailing list