[Mlir-commits] [mlir] a08a55c - [mlir][tosa] Extend narrowing pass (#170712)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 16 02:11:18 PST 2025
Author: Vitalii Shutov
Date: 2025-12-16T10:11:13Z
New Revision: a08a55cab9ee7e36c1ce7c3286c6fc0957f648a8
URL: https://github.com/llvm/llvm-project/commit/a08a55cab9ee7e36c1ce7c3286c6fc0957f648a8
DIFF: https://github.com/llvm/llvm-project/commit/a08a55cab9ee7e36c1ce7c3286c6fc0957f648a8.diff
LOG: [mlir][tosa] Extend narrowing pass (#170712)
- unify the i64->i32 and f64->f32 narrowing logic inside the shared
implementation
- register tosa::ConstOp in the non-aggressive rewrite set so standalone
constants are narrowed
---------
Signed-off-by: Vitalii Shutov <vitalii.shutov at arm.com>
Co-authored-by: Luke Hutton <Luke.Hutton at arm.com>
Added:
mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir
mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
Modified:
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
Removed:
mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 12f520297b702..4a5f283bc66c8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -198,4 +198,29 @@ def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
];
}
+def TosaNarrowF64ToF32Pass : Pass<"tosa-narrow-f64-to-f32", "func::FuncOp"> {
+ let summary = "Narrow F64 TOSA operations to F32";
+ let description = [{
+ This pass narrows TOSA operations with 64-bit floating-point tensor types to
+ 32-bit floating-point tensor types. While TOSA itself has no double
+ precision support, upstream conversions or frontends may still materialize
+ F64 tensors temporarily, so this pass removes them before handing off to a
+ TOSA backend.
+ }];
+
+ let options = [
+ Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
+ "If enabled, all TOSA operations are rewritten, regardless of whether the narrowing "
+ "is safe. This option may lead to data loss if not used carefully.">,
+ Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
+ "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
+ "be inserted at the I/O boundaries.">
+ ];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 091b481d6394b..0ff68b1bb54f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
- TosaNarrowI64ToI32.cpp
+ TosaNarrowTypes.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
deleted file mode 100644
index be442cc4f88ca..0000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
+++ /dev/null
@@ -1,348 +0,0 @@
-//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This pass narrows TOSA operations with 64-bit integer tensor types to
-// 32-bit integer tensor types. This can be useful for backends that do not
-// support the EXT-INT64 extension of TOSA. The pass has two options:
-//
-// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
-// regardless or whether the narrowing is safe. This option may lead to
-// data loss if not used carefully.
-// - convert-function-boundaries - If enabled, the pass will convert function
-// I/O types as well. Otherwise casts will be inserted at the I/O
-// boundaries.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/IR/Verifier.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace tosa {
-#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
-#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
-} // namespace tosa
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-LogicalResult convertGenericOp(Operation *op, ValueRange operands,
- ConversionPatternRewriter &rewriter,
- const TypeConverter *typeConverter) {
- // Convert types of results
- SmallVector<Type, 4> newResults;
- if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
- return failure();
-
- // Create a new operation state
- OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
- newResults, {}, op->getSuccessors());
-
- for (const NamedAttribute &namedAttribute : op->getAttrs()) {
- const Attribute attribute = namedAttribute.getValue();
-
- // Convert integer attribute type
- if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
- const std::optional<Attribute> convertedAttribute =
- typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
- state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
- continue;
- }
-
- if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
- Type type = typeAttr.getValue();
- const std::optional<Attribute> convertedAttribute =
- typeConverter->convertTypeAttribute(type, attribute);
- if (!convertedAttribute)
- return rewriter.notifyMatchFailure(op,
- "Failed to convert type attribute.");
- state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
- continue;
- }
-
- if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
- const Type type = denseElementsAttr.getType();
- const std::optional<Attribute> convertedAttribute =
- typeConverter->convertTypeAttribute(type, denseElementsAttr);
- if (!convertedAttribute)
- return rewriter.notifyMatchFailure(
- op, "Failed to convert dense elements attribute.");
- state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
- continue;
- }
-
- state.addAttribute(namedAttribute.getName(), attribute);
- }
-
- for (Region &region : op->getRegions()) {
- Region *newRegion = state.addRegion();
- rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
- if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
- return failure();
- }
-
- Operation *newOp = rewriter.create(state);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
-}
-
-// ===========================
-// Aggressive rewrite patterns
-// ===========================
-
-class ConvertGenericOp : public ConversionPattern {
-public:
- ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
- : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- if (!isa<tosa::TosaOp>(op))
- return rewriter.notifyMatchFailure(
- op,
- "Support for operations other than TOSA has not been implemented.");
-
- return convertGenericOp(op, operands, rewriter, typeConverter);
- }
-};
-
-// ===============================
-// Bounds checked rewrite patterns
-// ===============================
-
-class ConvertArgMaxOpWithBoundsChecking
- : public OpConversionPattern<tosa::ArgMaxOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- // Output type can be narrowed based on the size of the axis dimension
- const int32_t axis = op.getAxis();
- const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
- if (!inputType || !inputType.isStaticDim(axis))
- return rewriter.notifyMatchFailure(
- op, "Requires a static axis dimension for bounds checking.");
- const int64_t axisDim = inputType.getDimSize(axis);
- if (axisDim >= std::numeric_limits<int32_t>::max())
- return rewriter.notifyMatchFailure(
- op, "Axis dimension is too large to narrow safely.");
-
- const Type resultType = op.getOutput().getType();
- const Type newResultType = typeConverter->convertType(resultType);
- rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
- adaptor.getInput(), axis);
- return success();
- }
-};
-
-class ConvertCastOpWithBoundsChecking
- : public OpConversionPattern<tosa::CastOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
- const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
- if (!inputType || !resultType)
- return failure();
-
- const auto elementInputIntType =
- dyn_cast<IntegerType>(inputType.getElementType());
- const auto elementResultIntType =
- dyn_cast<IntegerType>(resultType.getElementType());
- if (elementInputIntType && elementResultIntType &&
- elementInputIntType.getWidth() > elementResultIntType.getWidth())
- return rewriter.notifyMatchFailure(
- op, "Narrowing cast may lead to data loss.");
-
- rewriter.replaceOpWithNewOp<tosa::CastOp>(
- op, typeConverter->convertType(resultType), adaptor.getInput());
- return success();
- }
-};
-
-class ConvertClampOpWithBoundsChecking
- : public OpConversionPattern<tosa::ClampOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ClampOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- const auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
- const auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
- if (!minAttr || !maxAttr)
- return failure();
-
- const int64_t min = minAttr.getInt();
- const int64_t max = maxAttr.getInt();
-
- if (min < std::numeric_limits<int32_t>::min() ||
- max > std::numeric_limits<int32_t>::max())
- return rewriter.notifyMatchFailure(
- op, "Clamp bounds exceed int32 range. Narrowing cast may lead to "
- "data loss.");
-
- const Type resultType = op.getOutput().getType();
- const Type newResultType = typeConverter->convertType(resultType);
-
- const IntegerType int32Type = IntegerType::get(rewriter.getContext(), 32);
- const IntegerAttr newMinAttr =
- rewriter.getIntegerAttr(int32Type, static_cast<int32_t>(min));
- const IntegerAttr newMaxAttr =
- rewriter.getIntegerAttr(int32Type, static_cast<int32_t>(max));
- rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
- adaptor.getInput(), newMinAttr,
- newMaxAttr, op.getNanModeAttr());
- return success();
- }
-};
-
-template <typename OpTy>
-class ConvertTypedOp : public OpConversionPattern<OpTy> {
- using OpConversionPattern<OpTy>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- return convertGenericOp(op, adaptor.getOperands(), rewriter,
- this->getTypeConverter());
- }
-};
-
-struct TosaNarrowI64ToI32
- : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
-public:
- explicit TosaNarrowI64ToI32() = default;
- explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
- : TosaNarrowI64ToI32() {
- this->aggressiveRewrite = options.aggressiveRewrite;
- this->convertFunctionBoundaries = options.convertFunctionBoundaries;
- }
-
- void runOnOperation() override {
- MLIRContext *context = &getContext();
-
- TypeConverter typeConverter;
- typeConverter.addConversion([](Type type) -> Type { return type; });
- typeConverter.addConversion([](IntegerType type) -> Type {
- if (!type.isInteger(64))
- return type;
- return IntegerType::get(type.getContext(), 32);
- });
- typeConverter.addConversion(
- [&typeConverter](RankedTensorType type) -> Type {
- const Type elementType = type.getElementType();
- if (!elementType.isInteger(64))
- return type;
- return RankedTensorType::get(type.getShape(),
- typeConverter.convertType(elementType));
- });
-
- const auto materializeCast = [](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) -> Value {
- if (inputs.size() != 1)
- return Value();
- return tosa::CastOp::create(builder, loc, resultType, inputs.front());
- };
- typeConverter.addSourceMaterialization(materializeCast);
- typeConverter.addTargetMaterialization(materializeCast);
-
- typeConverter.addTypeAttributeConversion(
- [](IntegerType type, IntegerAttr attribute) -> Attribute {
- const APInt value = attribute.getValue().truncSSat(32);
- return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
- value);
- });
- typeConverter.addTypeAttributeConversion(
- [&typeConverter](ShapedType type,
- DenseIntElementsAttr attr) -> Attribute {
- const ShapedType newType =
- cast<ShapedType>(typeConverter.convertType(type));
- const auto oldElementType = cast<IntegerType>(type.getElementType());
- const auto newElementType =
- cast<IntegerType>(newType.getElementType());
- if (oldElementType.getWidth() == newElementType.getWidth())
- return attr;
-
- DenseElementsAttr mapped =
- attr.mapValues(newElementType, [&](const APInt &v) {
- return v.truncSSat(newElementType.getWidth());
- });
- return mapped;
- });
-
- ConversionTarget target(*context);
- target.addDynamicallyLegalDialect<tosa::TosaDialect>(
- [&typeConverter](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes()) &&
- typeConverter.isLegal(op->getOperandTypes());
- });
- if (convertFunctionBoundaries) {
- target.addDynamicallyLegalOp<func::FuncOp>(
- [&typeConverter](func::FuncOp op) {
- return typeConverter.isSignatureLegal(op.getFunctionType()) &&
- typeConverter.isLegal(&op.getBody());
- });
- target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
- const FunctionType funcType =
- op->getParentOfType<func::FuncOp>().getFunctionType();
- return llvm::equal(op.getOperandTypes(), funcType.getResults());
- });
- } else {
- target.addDynamicallyLegalOp<func::FuncOp>(
- [](func::FuncOp op) { return true; });
- target.addDynamicallyLegalOp<func::ReturnOp>(
- [](func::ReturnOp op) { return true; });
- }
-
- RewritePatternSet patterns(context);
- if (convertFunctionBoundaries) {
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
- patterns, typeConverter);
- populateReturnOpTypeConversionPattern(patterns, typeConverter);
- }
- if (aggressiveRewrite) {
- patterns.add<ConvertGenericOp>(typeConverter, context);
- } else {
- // Tensor
- patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
- // Activation functions
- patterns.add<ConvertClampOpWithBoundsChecking>(typeConverter, context);
- // Data layout
- patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
- // Type conversion
- patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
- // Controlflow
- patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
- }
-
- if (failed(
- applyFullConversion(getOperation(), target, std::move(patterns))))
- signalPassFailure();
- }
-};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
new file mode 100644
index 0000000000000..d9651f7321269
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
@@ -0,0 +1,691 @@
+//===- TosaNarrowTypes.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TOSA narrowing passes that rewrite tensor element
+// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "llvm/ADT/APFloat.h"
+
+#include <limits>
+#include <type_traits>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// Narrowing mode for this pass.
+enum class TosaNarrowKind { Int64ToInt32, Float64ToFloat32 };
+
+// ---------------------------------------------------------------------------
+// Shared helpers
+// ---------------------------------------------------------------------------
+
+template <TosaNarrowKind Kind>
+bool isSourceInteger(IntegerType type) {
+ if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+ return type.isInteger(64);
+ return false;
+}
+
+template <TosaNarrowKind Kind>
+bool isSourceFloat(FloatType type) {
+ if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
+ return type.isF64();
+ return false;
+}
+
+template <TosaNarrowKind Kind>
+Type convertInteger(IntegerType type) {
+ if (!isSourceInteger<Kind>(type))
+ return type;
+ if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+ return IntegerType::get(type.getContext(), 32);
+ return type;
+}
+
+template <TosaNarrowKind Kind>
+Type convertFloat(FloatType type) {
+ if (!isSourceFloat<Kind>(type))
+ return type;
+ if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
+ return Float32Type::get(type.getContext());
+ return type;
+}
+
+template <TosaNarrowKind Kind>
+bool isSourceElement(Type type) {
+ if (auto intTy = dyn_cast<IntegerType>(type))
+ return isSourceInteger<Kind>(intTy);
+ if (auto floatTy = dyn_cast<FloatType>(type))
+ return isSourceFloat<Kind>(floatTy);
+ return false;
+}
+
+template <TosaNarrowKind Kind>
+Type convertElement(Type type) {
+ if (auto intTy = dyn_cast<IntegerType>(type))
+ return convertInteger<Kind>(intTy);
+ if (auto floatTy = dyn_cast<FloatType>(type))
+ return convertFloat<Kind>(floatTy);
+ return type;
+}
+
+template <TosaNarrowKind Kind>
+bool typeNeedsConversion(Type type) {
+ if (auto shaped = dyn_cast<ShapedType>(type))
+ return isSourceElement<Kind>(shaped.getElementType());
+ return isSourceElement<Kind>(type);
+}
+
+FailureOr<APInt> convertIntegerConstant(IntegerType targetType,
+ const APInt &value,
+ bool allowLossyConversion) {
+ const unsigned targetWidth = targetType.getWidth();
+ if (!allowLossyConversion && !value.isSignedIntN(targetWidth))
+ return failure();
+
+ if (allowLossyConversion)
+ return value.truncSSat(targetWidth);
+ return value.sextOrTrunc(targetWidth);
+}
+
+FailureOr<APFloat> convertFloatConstant(FloatType targetType,
+ const APFloat &value,
+ bool allowLossyConversion) {
+ APFloat converted(value);
+ bool losesInfo = false;
+ converted.convert(targetType.getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &losesInfo);
+ if (!allowLossyConversion && losesInfo)
+ return failure();
+ return converted;
+}
+
+// Narrows scalar constant attributes so they keep matching the converted
+// element types.
+template <TosaNarrowKind Kind>
+FailureOr<Attribute> tryConvertScalarAttribute(Attribute attribute,
+ bool allowLossyConversion) {
+ if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+ if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+ if (const auto intType = dyn_cast<IntegerType>(intAttr.getType());
+ intType && isSourceInteger<Kind>(intType)) {
+ const auto convertedType =
+ cast<IntegerType>(convertInteger<Kind>(intType));
+ FailureOr<APInt> convertedValue = convertIntegerConstant(
+ convertedType, intAttr.getValue(), allowLossyConversion);
+ if (failed(convertedValue))
+ return failure();
+ return IntegerAttr::get(convertedType, convertedValue.value());
+ }
+ }
+ } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+ if (const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
+ if (const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
+ floatType && isSourceFloat<Kind>(floatType)) {
+ const auto convertedType =
+ cast<FloatType>(convertFloat<Kind>(floatType));
+ FailureOr<APFloat> convertedValue = convertFloatConstant(
+ convertedType, floatAttr.getValue(), allowLossyConversion);
+ if (failed(convertedValue))
+ return failure();
+ return FloatAttr::get(convertedType, convertedValue.value());
+ }
+ }
+ }
+
+ return attribute;
+}
+
+template <TosaNarrowKind Kind>
+FailureOr<Attribute>
+convertDenseIntElementsAttr(ShapedType type, DenseIntElementsAttr attr,
+ const TypeConverter &typeConverter,
+ bool allowLossyConversion) {
+ if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
+ return attr;
+
+ const auto oldElementType = dyn_cast<IntegerType>(type.getElementType());
+ if (!oldElementType || !isSourceInteger<Kind>(oldElementType))
+ return attr;
+
+ const auto newType =
+ dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
+ if (!newType)
+ return failure();
+
+ const auto newElementType = dyn_cast<IntegerType>(newType.getElementType());
+ if (!newElementType)
+ return failure();
+
+ if (!allowLossyConversion) {
+ for (APInt value : attr.getValues<APInt>())
+ if (failed(convertIntegerConstant(newElementType, value,
+ /*allowLossyConversion=*/false)))
+ return failure();
+ }
+
+ Attribute convertedAttr =
+ attr.mapValues(newElementType, [&](const APInt &value) -> APInt {
+ return convertIntegerConstant(newElementType, value,
+ /*allowLossyConversion=*/true)
+ .value();
+ });
+ return convertedAttr;
+}
+
+template <TosaNarrowKind Kind>
+FailureOr<Attribute>
+convertDenseFPElementsAttr(ShapedType type, DenseFPElementsAttr attr,
+ const TypeConverter &typeConverter,
+ bool allowLossyConversion) {
+ if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
+ return attr;
+
+ const auto oldElementType = dyn_cast<FloatType>(type.getElementType());
+ if (!oldElementType || !isSourceFloat<Kind>(oldElementType))
+ return attr;
+
+ const auto newType =
+ dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
+ if (!newType)
+ return failure();
+
+ const auto newElementType = dyn_cast<FloatType>(newType.getElementType());
+ if (!newElementType)
+ return failure();
+
+ if (!allowLossyConversion) {
+ for (APFloat value : attr.getValues<APFloat>())
+ if (failed(convertFloatConstant(newElementType, value,
+ /*allowLossyConversion=*/false)))
+ return failure();
+ }
+
+ Attribute convertedAttr =
+ attr.mapValues(newElementType, [&](const APFloat &value) -> APInt {
+ APFloat converted = convertFloatConstant(newElementType, value,
+ /*allowLossyConversion=*/true)
+ .value();
+ // DenseFPElementsAttr stores each float as raw bits, so emit the APInt
+ // representation that MLIR expects in the underlying buffer.
+ return converted.bitcastToAPInt();
+ });
+ return convertedAttr;
+}
+
+template <TosaNarrowKind Kind, typename AttrT>
+FailureOr<Attribute>
+convertAttributeWithTypeConverter(AttrT attr, Type type,
+ const TypeConverter *typeConverter) {
+ if (!typeNeedsConversion<Kind>(type))
+ return attr;
+
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, attr);
+ if (!convertedAttribute)
+ return failure();
+
+ return convertedAttribute.value();
+}
+
+// Rejects cast rewrites that would lose precision (unless aggressive mode is
+// enabled).
+template <TosaNarrowKind Kind>
+LogicalResult
+verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
+ ShapedType resultType,
+ ConversionPatternRewriter &rewriter) {
+ if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+ const auto elementInputIntType =
+ dyn_cast<IntegerType>(inputType.getElementType());
+ const auto elementResultIntType =
+ dyn_cast<IntegerType>(resultType.getElementType());
+ if (elementInputIntType && elementResultIntType &&
+ elementInputIntType.getWidth() > elementResultIntType.getWidth())
+ return rewriter.notifyMatchFailure(
+ op, "Narrowing cast may lead to data loss.");
+ } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+ const auto elementInputFloatType =
+ dyn_cast<FloatType>(inputType.getElementType());
+ const auto elementResultFloatType =
+ dyn_cast<FloatType>(resultType.getElementType());
+ if (elementInputFloatType && elementResultFloatType &&
+ elementInputFloatType.getIntOrFloatBitWidth() >
+ elementResultFloatType.getIntOrFloatBitWidth())
+ return rewriter.notifyMatchFailure(
+ op, "Narrowing cast may lead to data loss.");
+ }
+
+ return success();
+}
+
+// ---------------------------------------------------------------------------
+// Conversion patterns
+// ---------------------------------------------------------------------------
+
+// Applies the narrowing TypeConverter to a single TOSA op, including its
+// attributes and nested regions.
+template <TosaNarrowKind Kind>
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter *typeConverter,
+ bool allowLossyConversion) {
+ SmallVector<Type, 4> newResults;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+ return failure();
+
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, {}, op->getSuccessors());
+
+ // Keep attribute payloads consistent with the converted element types.
+ for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+ const Attribute attribute = namedAttribute.getValue();
+
+ if (isa<IntegerAttr>(attribute) || isa<FloatAttr>(attribute)) {
+ FailureOr<Attribute> convertedAttr =
+ tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
+ if (failed(convertedAttr))
+ return rewriter.notifyMatchFailure(
+ op, "Scalar attribute narrowing would lose precision; enable "
+ "aggressive rewrite to override.");
+ state.addAttribute(namedAttribute.getName(), convertedAttr.value());
+ continue;
+ }
+
+ if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+ FailureOr<Attribute> convertedAttr =
+ convertAttributeWithTypeConverter<Kind>(typeAttr, typeAttr.getValue(),
+ typeConverter);
+ if (failed(convertedAttr))
+ return rewriter.notifyMatchFailure(op,
+ "Failed to convert type attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttr.value());
+ continue;
+ }
+
+ if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+ FailureOr<Attribute> convertedAttr =
+ convertAttributeWithTypeConverter<Kind>(
+ denseElementsAttr, denseElementsAttr.getType(), typeConverter);
+ if (failed(convertedAttr))
+ return rewriter.notifyMatchFailure(
+ op, "Failed to convert dense elements attribute without precision "
+ "loss; enable aggressive rewrite to override.");
+ state.addAttribute(namedAttribute.getName(), convertedAttr.value());
+ continue;
+ }
+
+ state.addAttribute(namedAttribute.getName(), attribute);
+ }
+
+ for (Region &region : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+ return failure();
+ }
+
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+}
+
+template <TosaNarrowKind Kind>
+class ConvertGenericOp : public ConversionPattern {
+public:
+ ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context,
+ bool allowLossyConversion)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context),
+ allowLossyConversion(allowLossyConversion) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!isa<tosa::TosaOp>(op))
+ return rewriter.notifyMatchFailure(
+ op,
+ "Support for operations other than TOSA has not been implemented.");
+
+ return convertGenericOp<Kind>(op, operands, rewriter, typeConverter,
+ allowLossyConversion);
+ }
+
+private:
+ const bool allowLossyConversion;
+};
+
+template <typename OpTy, TosaNarrowKind Kind>
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+public:
+ ConvertTypedOp(TypeConverter &typeConverter, MLIRContext *context)
+ : OpConversionPattern<OpTy>(typeConverter, context) {}
+
+ LogicalResult
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
+ this->getTypeConverter(),
+ /*allowLossyConversion=*/false);
+ }
+};
+
+// ---------------------------------------------------------------------------
+// Kind-specific helpers and patterns
+// ---------------------------------------------------------------------------
+
+// Casts get extra checking so we only narrow when it is probably safe.
+template <TosaNarrowKind Kind>
+class ConvertCastOpWithBoundsChecking
+ : public OpConversionPattern<tosa::CastOp> {
+ using OpConversionPattern<tosa::CastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::CastOp op, typename tosa::CastOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+ if (!inputType || !resultType)
+ return failure();
+
+ const TypeConverter *typeConverter = this->getTypeConverter();
+ if (failed(verifyCastDoesNotLosePrecision<Kind>(op, inputType, resultType,
+ rewriter)))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::CastOp>(
+ op, typeConverter->convertType(resultType), adaptor.getInput());
+ return success();
+ }
+};
+
+// Narrowing pattern for tosa.argmax. The op produces indices along `axis`, so
+// the result is only rewritten to the converted (narrower) type when the axis
+// extent is statically known and below the int32 maximum; otherwise the match
+// fails with a diagnostic and the op is left alone.
+class ConvertArgMaxOpWithBoundsChecking
+    : public OpConversionPattern<tosa::ArgMaxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ArgMaxOp op, typename tosa::ArgMaxOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    const int32_t axis = op.getAxis();
+    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+    // Per-dimension queries are only valid on ranked shapes with an in-range
+    // index, so reject unranked inputs and out-of-range axes before asking
+    // whether the axis dimension is static.
+    if (!inputType || !inputType.hasRank() || axis < 0 ||
+        axis >= inputType.getRank() || !inputType.isStaticDim(axis))
+      return rewriter.notifyMatchFailure(
+          op, "Requires a static axis dimension for bounds checking.");
+    const int64_t axisDim = inputType.getDimSize(axis);
+    if (axisDim >= std::numeric_limits<int32_t>::max())
+      return rewriter.notifyMatchFailure(
+          op, "Axis dimension is too large to narrow safely.");
+
+    const Type resultType = op.getOutput().getType();
+    const Type newResultType =
+        this->getTypeConverter()->convertType(resultType);
+    rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
+                                                adaptor.getInput(), axis);
+    return success();
+  }
+};
+
+// Narrowing pattern for tosa.clamp, integer narrowing only (enforced by the
+// static_assert). The op is rebuilt on the converted result type with its
+// min/max attributes re-created on the new integer element type, provided the
+// lower bound is not below the int32 minimum and the upper bound is not above
+// the int32 maximum.
+template <TosaNarrowKind Kind>
+class ConvertClampOpWithBoundsChecking
+    : public OpConversionPattern<tosa::ClampOp> {
+  static_assert(Kind == TosaNarrowKind::Int64ToInt32,
+                "Clamp bounds checking only supported for integer narrowing");
+  using OpConversionPattern<tosa::ClampOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ClampOp op, typename tosa::ClampOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    const auto lowerAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
+    const auto upperAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
+    if (!(lowerAttr && upperAttr))
+      return rewriter.notifyMatchFailure(
+          op, "Clamp attributes must be integer constants.");
+
+    using Int32Limits = std::numeric_limits<int32_t>;
+    const int64_t lower = lowerAttr.getInt();
+    const int64_t upper = upperAttr.getInt();
+    const bool boundsFitInt32 =
+        lower >= Int32Limits::min() && upper <= Int32Limits::max();
+    if (!boundsFitInt32)
+      return rewriter.notifyMatchFailure(
+          op, "Clamp bounds exceed int32 range. Narrowing may lose data.");
+
+    // The converted result must be a shaped type with an integer element type
+    // so the bounds can be re-materialized on that element type.
+    const Type narrowedType =
+        this->getTypeConverter()->convertType(op.getOutput().getType());
+    const auto narrowedShaped = dyn_cast<ShapedType>(narrowedType);
+    if (!narrowedShaped)
+      return failure();
+    const auto narrowedElement =
+        dyn_cast<IntegerType>(narrowedShaped.getElementType());
+    if (!narrowedElement)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, narrowedType, adaptor.getInput(),
+        IntegerAttr::get(narrowedElement, lower),
+        IntegerAttr::get(narrowedElement, upper), op.getNanModeAttr());
+    return success();
+  }
+};
+
+// Common driver for both narrowing passes. `Kind` selects which element types
+// and attribute payloads are narrowed. The function builds the type converter
+// (scalar, ranked and unranked tensor rules plus tosa.cast materializations),
+// the attribute conversions for the selected kind, the legality target and the
+// pattern set, then runs a full dialect conversion rooted at `op`.
+//   aggressiveRewrite         - use the generic pattern and allow lossy
+//                               attribute conversion.
+//   convertFunctionBoundaries - also convert func.func signatures and
+//                               func.return operands.
+template <TosaNarrowKind Kind>
+LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
+                               bool convertFunctionBoundaries) {
+  MLIRContext *context = op->getContext();
+  // Lossy attribute conversion is tied to the aggressive mode.
+  const bool allowLossyConversion = aggressiveRewrite;
+
+  // ---- Type conversion rules ----------------------------------------------
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) -> Type { return type; });
+  typeConverter.addConversion(
+      [](IntegerType type) -> Type { return convertInteger<Kind>(type); });
+  typeConverter.addConversion(
+      [](FloatType type) -> Type { return convertFloat<Kind>(type); });
+
+  // Yields the converted element type, or a null Type when the element is not
+  // a narrowing source, fails to convert, or converts to itself.
+  const auto narrowedElementOrNull =
+      [&typeConverter](Type elementType) -> Type {
+    if (!isSourceElement<Kind>(elementType))
+      return Type();
+    const Type converted = typeConverter.convertType(elementType);
+    if (converted == elementType)
+      return Type();
+    return converted;
+  };
+  typeConverter.addConversion(
+      [narrowedElementOrNull](RankedTensorType type) -> Type {
+        if (const Type narrowed = narrowedElementOrNull(type.getElementType()))
+          return RankedTensorType::get(type.getShape(), narrowed,
+                                       type.getEncoding());
+        return type;
+      });
+  typeConverter.addConversion(
+      [narrowedElementOrNull](UnrankedTensorType type) -> Type {
+        if (const Type narrowed = narrowedElementOrNull(type.getElementType()))
+          return UnrankedTensorType::get(narrowed);
+        return type;
+      });
+
+  // Source and target materializations both bridge with a single tosa.cast.
+  const auto castMaterializer = [](OpBuilder &builder, Type resultType,
+                                   ValueRange inputs, Location loc) -> Value {
+    if (inputs.size() == 1)
+      return tosa::CastOp::create(builder, loc, resultType, inputs.front());
+    return Value();
+  };
+  typeConverter.addSourceMaterialization(castMaterializer);
+  typeConverter.addTargetMaterialization(castMaterializer);
+
+  // ---- Attribute conversion rules -----------------------------------------
+  // Maps a FailureOr<Attribute> onto the converter's result protocol: a failed
+  // conversion aborts, a successful one is forwarded.
+  const auto toAttrResult = [](FailureOr<Attribute> converted)
+      -> TypeConverter::AttributeConversionResult {
+    using Result = TypeConverter::AttributeConversionResult;
+    if (failed(converted))
+      return Result::abort();
+    return Result::result(converted.value());
+  };
+
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    typeConverter.addTypeAttributeConversion(
+        [toAttrResult, allowLossyConversion](IntegerType /*type*/,
+                                             IntegerAttr attribute)
+            -> TypeConverter::AttributeConversionResult {
+          return toAttrResult(
+              tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion));
+        });
+    typeConverter.addTypeAttributeConversion(
+        [&typeConverter, toAttrResult,
+         allowLossyConversion](ShapedType type, DenseIntElementsAttr attr)
+            -> TypeConverter::AttributeConversionResult {
+          return toAttrResult(convertDenseIntElementsAttr<Kind>(
+              type, attr, typeConverter, allowLossyConversion));
+        });
+  } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    typeConverter.addTypeAttributeConversion(
+        [toAttrResult, allowLossyConversion](FloatType /*type*/,
+                                             FloatAttr attribute)
+            -> TypeConverter::AttributeConversionResult {
+          return toAttrResult(
+              tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion));
+        });
+    typeConverter.addTypeAttributeConversion(
+        [&typeConverter, toAttrResult,
+         allowLossyConversion](ShapedType type, DenseFPElementsAttr attr)
+            -> TypeConverter::AttributeConversionResult {
+          return toAttrResult(convertDenseFPElementsAttr<Kind>(
+              type, attr, typeConverter, allowLossyConversion));
+        });
+  }
+
+  // ---- Legality and function-boundary handling ----------------------------
+  ConversionTarget target(*context);
+  RewritePatternSet patterns(context);
+
+  // A TOSA op is legal once all of its result and operand types are legal.
+  target.addDynamicallyLegalDialect<tosa::TosaDialect>(
+      [&typeConverter](Operation *tosaOp) {
+        return typeConverter.isLegal(tosaOp->getResultTypes()) &&
+               typeConverter.isLegal(tosaOp->getOperandTypes());
+      });
+
+  if (convertFunctionBoundaries) {
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [&typeConverter](func::FuncOp funcOp) {
+          return typeConverter.isSignatureLegal(funcOp.getFunctionType()) &&
+                 typeConverter.isLegal(&funcOp.getBody());
+        });
+    // A return is legal when its operand types equal the enclosing function's
+    // declared result types.
+    target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
+      auto enclosingFunc = returnOp->getParentOfType<func::FuncOp>();
+      const FunctionType signature = enclosingFunc.getFunctionType();
+      return llvm::equal(returnOp.getOperandTypes(), signature.getResults());
+    });
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+  } else {
+    // Function signatures and returns are left untouched.
+    target.addDynamicallyLegalOp<func::FuncOp, func::ReturnOp>(
+        [](Operation *) { return true; });
+  }
+
+  // ---- Rewrite patterns ---------------------------------------------------
+  if (aggressiveRewrite) {
+    patterns.add<ConvertGenericOp<Kind>>(typeConverter, context,
+                                         allowLossyConversion);
+  } else {
+    if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+      patterns.add<ConvertArgMaxOpWithBoundsChecking,
+                   ConvertClampOpWithBoundsChecking<Kind>>(typeConverter,
+                                                           context);
+    }
+    patterns.add<ConvertTypedOp<tosa::ConstOp, Kind>,
+                 ConvertTypedOp<tosa::ConcatOp, Kind>,
+                 ConvertTypedOp<tosa::PadOp, Kind>,
+                 ConvertTypedOp<tosa::ReshapeOp, Kind>,
+                 ConvertTypedOp<tosa::ReverseOp, Kind>,
+                 ConvertTypedOp<tosa::SliceOp, Kind>,
+                 ConvertTypedOp<tosa::TileOp, Kind>,
+                 ConvertTypedOp<tosa::TransposeOp, Kind>,
+                 ConvertTypedOp<tosa::IdentityOp, Kind>,
+                 ConvertCastOpWithBoundsChecking<Kind>,
+                 ConvertTypedOp<tosa::IfOp, Kind>,
+                 ConvertTypedOp<tosa::WhileOp, Kind>,
+                 ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
+  }
+
+  return applyFullConversion(op, target, std::move(patterns));
+}
+
+// ---------------------------------------------------------------------------
+// Pass adapters that forward to the shared implementation
+// ---------------------------------------------------------------------------
+
+// Pass adapter for the i64 -> i32 narrowing: copies the two pass options and
+// forwards to runTosaNarrowing with the Int64ToInt32 kind, signalling pass
+// failure when the conversion fails.
+struct TosaNarrowI64ToI32
+    : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
+  using Base = tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32>;
+
+  TosaNarrowI64ToI32() = default;
+
+  explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) {
+    aggressiveRewrite = options.aggressiveRewrite;
+    convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+
+  void runOnOperation() override {
+    const LogicalResult status =
+        runTosaNarrowing<TosaNarrowKind::Int64ToInt32>(
+            getOperation(), aggressiveRewrite, convertFunctionBoundaries);
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+// Pass adapter for the f64 -> f32 narrowing: copies the two pass options and
+// forwards to runTosaNarrowing with the Float64ToFloat32 kind, signalling pass
+// failure when the conversion fails.
+struct TosaNarrowF64ToF32
+    : public tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32> {
+  using Base = tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32>;
+
+  TosaNarrowF64ToF32() = default;
+
+  explicit TosaNarrowF64ToF32(const TosaNarrowF64ToF32PassOptions &options) {
+    aggressiveRewrite = options.aggressiveRewrite;
+    convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+
+  void runOnOperation() override {
+    const LogicalResult status =
+        runTosaNarrowing<TosaNarrowKind::Float64ToFloat32>(
+            getOperation(), aggressiveRewrite, convertFunctionBoundaries);
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir
new file mode 100644
index 0000000000000..69547194dee3f
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// -----
+
+// Aggressive mode: tosa.add is rewritten to operate on f32 tensors. Without
+// function-boundary conversion the f64 arguments and result are bridged with
+// tosa.cast; with it the function signature itself becomes f32.
+// NOTE(review): `CHECK` is not one of the prefixes enabled on the RUN lines,
+// so the label line below looks inert - confirm whether a label on an enabled
+// prefix was intended.
+// CHECK-LABEL: test_f64_add
+// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xf64>, %[[IN1:.*]]: tensor<13x21x3xf64>
+// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xf32>, %[[IN1:.*]]: tensor<13x21x3xf32>
+func.func @test_f64_add(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> {
+ // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xf64>) -> tensor<13x21x1xf32>
+ // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+ // COMMON: %[[ADD:.*]] = tosa.add %{{.*}}, %{{.*}} : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf64>, tensor<13x21x3xf64>) -> tensor<13x21x3xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf64>
+ // DEFAULT: return %[[OUT]] : tensor<13x21x3xf64>
+ // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf64>
+}
+
+// -----
+
+// Aggressive mode: narrowing reaches into both regions of tosa.cond_if, so the
+// add/sub inside the branches and the cond_if result all become f32.
+// CHECK-LABEL: test_f64_regions
+// DEFAULT: %[[IN0:.*]]: tensor<1xf64>, %[[IN1:.*]]: tensor<1xf64>
+func.func @test_f64_regions(%arg0: tensor<1xf64>, %arg1: tensor<1xf64>, %arg2: tensor<i1>) -> tensor<1xf64> {
+ // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<1xf64>) -> tensor<1xf32>
+ // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<1xf64>) -> tensor<1xf32>
+ // COMMON: %[[IF:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf64> {
+ // COMMON: %[[ADD:.*]] = tosa.add %{{.*}}, %{{.*}} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %1 = tosa.add %arg0, %arg1 : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64>
+ tosa.yield %1 : tensor<1xf64>
+ } else {
+ // COMMON: %[[SUB:.*]] = tosa.sub %{{.*}}, %{{.*}} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %1 = tosa.sub %arg0, %arg1 : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64>
+ tosa.yield %1 : tensor<1xf64>
+ }
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF]] : (tensor<1xf32>) -> tensor<1xf64>
+ // DEFAULT: return %[[OUT]] : tensor<1xf64>
+ // FUNCBOUND: return %[[IF]] : tensor<1xf32>
+ return %0 : tensor<1xf64>
+}
+
+// -----
+
+// Aggressive mode: an f64 argument feeding an explicit f64 -> f32 cast. After
+// narrowing the user-written cast becomes an f32 -> f32 cast, and the function
+// result (already f32) is returned unchanged in both runs.
+// CHECK-LABEL: test_convert_input_parameters
+// DEFAULT: %[[IN:.*]]: tensor<1x3xf64>
+// FUNCBOUND: %[[IN:.*]]: tensor<1x3xf32>
+func.func @test_convert_input_parameters(%arg0: tensor<1x3xf64>) -> tensor<1x3xf32> {
+ // DEFAULT: %[[CAST_IN:.*]] = tosa.cast %[[IN]] : (tensor<1x3xf64>) -> tensor<1x3xf32>
+ // DEFAULT: %[[IDENTITY:.*]] = tosa.identity %[[CAST_IN]] : (tensor<1x3xf32>) -> tensor<1x3xf32>
+ // FUNCBOUND: %[[IDENTITY:.*]] = tosa.identity %[[IN]] : (tensor<1x3xf32>) -> tensor<1x3xf32>
+ %0 = tosa.identity %arg0 : (tensor<1x3xf64>) -> tensor<1x3xf64>
+ // COMMON: %[[TO_F32:.*]] = tosa.cast %[[IDENTITY]] : (tensor<1x3xf32>) -> tensor<1x3xf32>
+ %1 = tosa.cast %0 : (tensor<1x3xf64>) -> tensor<1x3xf32>
+ // DEFAULT: return %[[TO_F32]] : tensor<1x3xf32>
+ // FUNCBOUND: return %[[TO_F32]] : tensor<1x3xf32>
+ return %1 : tensor<1x3xf32>
+}
+
+// -----
+
+// Aggressive mode: a standalone f64 tosa.const is re-emitted with f32 values
+// and an f32 result type.
+// CHECK-LABEL: test_f64_const
+func.func @test_f64_const() -> tensor<2xf64> {
+ // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+ %0 = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}> : () -> tensor<2xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xf32>) -> tensor<2xf64>
+ // DEFAULT: return %[[OUT]] : tensor<2xf64>
+ // FUNCBOUND: return %[[CONST]] : tensor<2xf32>
+ return %0 : tensor<2xf64>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
new file mode 100644
index 0000000000000..1034ee67f65e2
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir
@@ -0,0 +1,180 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-f64-to-f32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// -----
+
+// Two chained tosa.identity ops are both narrowed to f32, with the second one
+// consuming the narrowed result of the first directly (no cast in between).
+// NOTE(review): `CHECK` is not one of the prefixes enabled on the RUN lines,
+// so the label line below looks inert - confirm whether a label on an enabled
+// prefix was intended.
+// CHECK-LABEL: test_f64_identity_chain
+func.func @test_f64_identity_chain(%arg0: tensor<1xf64>) -> tensor<1xf64> {
+ // DEFAULT: %[[CAST_IN:.*]] = tosa.cast %arg0 : (tensor<1xf64>) -> tensor<1xf32>
+ // DEFAULT: %[[ID1:.*]] = tosa.identity %[[CAST_IN]] : (tensor<1xf32>) -> tensor<1xf32>
+ // FUNCBOUND: %[[ID1:.*]] = tosa.identity %arg0 : (tensor<1xf32>) -> tensor<1xf32>
+ %0 = tosa.identity %arg0 : (tensor<1xf64>) -> tensor<1xf64>
+ // COMMON: %[[ID2:.*]] = tosa.identity %[[ID1]] : (tensor<1xf32>) -> tensor<1xf32>
+ %1 = tosa.identity %0 : (tensor<1xf64>) -> tensor<1xf64>
+ // DEFAULT: %[[CAST_OUT:.*]] = tosa.cast %[[ID2]] : (tensor<1xf32>) -> tensor<1xf64>
+ // DEFAULT: return %[[CAST_OUT]] : tensor<1xf64>
+ // FUNCBOUND: return %[[ID2]] : tensor<1xf32>
+ return %1 : tensor<1xf64>
+}
+
+// -----
+
+// Non-aggressive mode: a standalone f64 tosa.const whose values are exactly
+// representable in f32 is narrowed.
+// CHECK-LABEL: test_f64_const
+func.func @test_f64_const() -> tensor<2xf64> {
+ // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32>
+ %0 = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}> : () -> tensor<2xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xf32>) -> tensor<2xf64>
+ // DEFAULT: return %[[OUT]] : tensor<2xf64>
+ // FUNCBOUND: return %[[CONST]] : tensor<2xf32>
+ return %0 : tensor<2xf64>
+}
+
+// -----
+
+// Negative test: in non-aggressive mode a constant that would be rounded by
+// the f64 -> f32 conversion must make the pass fail to legalize tosa.const.
+// CHECK-LABEL: test_f64_const_precision_loss
+func.func @test_f64_const_precision_loss() -> tensor<1xf64> {
+ // expected-error @+2 {{failed to legalize operation 'tosa.const'}}
+ // 2^24 + 1 fits in f64 but rounds to 2^24 in f32.
+ %0 = "tosa.const"() <{values = dense<16777217.0> : tensor<1xf64>}> : () -> tensor<1xf64>
+ return %0 : tensor<1xf64>
+}
+
+// -----
+
+// Negative test: in non-aggressive mode a constant too small in magnitude for
+// f32 must make the pass fail to legalize tosa.const.
+// CHECK-LABEL: test_f64_const_precision_loss_small
+func.func @test_f64_const_precision_loss_small() -> tensor<1xf64> {
+ // expected-error @+2 {{failed to legalize operation 'tosa.const'}}
+ // Too small: underflows to zero when narrowed to f32.
+ %0 = "tosa.const"() <{values = dense<1.0e-46> : tensor<1xf64>}> : () -> tensor<1xf64>
+ return %0 : tensor<1xf64>
+}
+
+// -----
+
+// tosa.concat on f64 tensors is narrowed to f32; the axis attribute is kept.
+// CHECK-LABEL: test_f64_concat
+// DEFAULT: %[[A0:.*]]: tensor<13x21x3xf64>, %[[A1:.*]]: tensor<13x21x3xf64>
+// FUNCBOUND: %[[A0:.*]]: tensor<13x21x3xf32>, %[[A1:.*]]: tensor<13x21x3xf32>
+func.func @test_f64_concat(%arg0: tensor<13x21x3xf64>, %arg1: tensor<13x21x3xf64>) -> tensor<26x21x3xf64> {
+ // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[A0]] : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+ // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[A1]] : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+ // COMMON: %[[CONCAT:.*]] = tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf64>, tensor<13x21x3xf64>) -> tensor<26x21x3xf64>
+ // DEFAULT: %[[CAST_OUT:.*]] = tosa.cast %[[CONCAT]] : (tensor<26x21x3xf32>) -> tensor<26x21x3xf64>
+ // DEFAULT: return %[[CAST_OUT]] : tensor<26x21x3xf64>
+ // FUNCBOUND: return %[[CONCAT]] : tensor<26x21x3xf32>
+ return %0 : tensor<26x21x3xf64>
+}
+
+// -----
+
+// tosa.pad: both the padded tensor and the pad-value tensor are narrowed to
+// f32, while the !tosa.shape padding operand keeps its type.
+// CHECK-LABEL: test_f64_pad
+func.func @test_f64_pad(%arg0: tensor<13x21x3xf64>, %arg1: tensor<1xf64>) -> tensor<15x23x5xf64> {
+ %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // DEFAULT-DAG: %[[IN_CAST:.*]] = tosa.cast %arg0 : (tensor<13x21x3xf64>) -> tensor<13x21x3xf32>
+ // DEFAULT-DAG: %[[PAD_CAST:.*]] = tosa.cast %arg1 : (tensor<1xf64>) -> tensor<1xf32>
+ // COMMON: %[[PAD:.*]] = tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<15x23x5xf32>
+ %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xf64>, !tosa.shape<6>, tensor<1xf64>) -> tensor<15x23x5xf64>
+ // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[PAD]] : (tensor<15x23x5xf32>) -> tensor<15x23x5xf64>
+ // DEFAULT: return %[[OUT_CAST]] : tensor<15x23x5xf64>
+ // FUNCBOUND: return %[[PAD]] : tensor<15x23x5xf32>
+ return %1 : tensor<15x23x5xf64>
+}
+
+// -----
+
+// tosa.reshape on an f64 tensor is narrowed to f32; the !tosa.shape operand
+// keeps its type.
+// CHECK-LABEL: test_f64_reshape
+func.func @test_f64_reshape(%arg0: tensor<13x21x3xf64>) -> tensor<1x819xf64> {
+ %shape = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // COMMON: %[[RESHAPE:.*]] = tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32>
+ %0 = tosa.reshape %arg0, %shape : (tensor<13x21x3xf64>, !tosa.shape<2>) -> tensor<1x819xf64>
+ // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[RESHAPE]] : (tensor<1x819xf32>) -> tensor<1x819xf64>
+ // DEFAULT: return %[[OUT_CAST]] : tensor<1x819xf64>
+ // FUNCBOUND: return %[[RESHAPE]] : tensor<1x819xf32>
+ return %0 : tensor<1x819xf64>
+}
+
+// -----
+
+// tosa.reverse on an f64 tensor is narrowed to f32; the axis attribute is kept.
+// CHECK-LABEL: test_f64_reverse
+func.func @test_f64_reverse(%arg0: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> {
+ // COMMON: %[[REV:.*]] = tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf64>) -> tensor<13x21x3xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[REV]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf64>
+ // DEFAULT: return %[[OUT]] : tensor<13x21x3xf64>
+ // FUNCBOUND: return %[[REV]] : tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf64>
+}
+
+// -----
+
+// tosa.slice on an f64 tensor is narrowed to f32; both !tosa.shape operands
+// keep their types.
+// CHECK-LABEL: test_f64_slice
+func.func @test_f64_slice(%arg0: tensor<13x21x3xf64>) -> tensor<4x11x1xf64> {
+ %size = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %start = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // COMMON: %[[SLICE:.*]] = tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32>
+ %0 = tosa.slice %arg0, %size, %start : (tensor<13x21x3xf64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[SLICE]] : (tensor<4x11x1xf32>) -> tensor<4x11x1xf64>
+ // DEFAULT: return %[[OUT]] : tensor<4x11x1xf64>
+ // FUNCBOUND: return %[[SLICE]] : tensor<4x11x1xf32>
+ return %0 : tensor<4x11x1xf64>
+}
+
+// -----
+
+// tosa.tile on an f64 tensor is narrowed to f32; the !tosa.shape multipliers
+// operand keeps its type.
+// CHECK-LABEL: test_f64_tile
+func.func @test_f64_tile(%arg0: tensor<13x21x3xf64>) -> tensor<39x21x6xf64> {
+ %multipliers = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ // COMMON: %[[TILE:.*]] = tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
+ %0 = tosa.tile %arg0, %multipliers : (tensor<13x21x3xf64>, !tosa.shape<3>) -> tensor<39x21x6xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[TILE]] : (tensor<39x21x6xf32>) -> tensor<39x21x6xf64>
+ // DEFAULT: return %[[OUT]] : tensor<39x21x6xf64>
+ // FUNCBOUND: return %[[TILE]] : tensor<39x21x6xf32>
+ return %0 : tensor<39x21x6xf64>
+}
+
+// -----
+
+// tosa.transpose on an f64 tensor is narrowed to f32; the perms attribute is
+// kept.
+// CHECK-LABEL: test_f64_transpose
+func.func @test_f64_transpose(%arg0: tensor<13x21x3xf64>) -> tensor<3x13x21xf64> {
+ // COMMON: %[[TRANSPOSE:.*]] = tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf32>) -> tensor<3x13x21xf32>
+ %0 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf64>) -> tensor<3x13x21xf64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[TRANSPOSE]] : (tensor<3x13x21xf32>) -> tensor<3x13x21xf64>
+ // DEFAULT: return %[[OUT]] : tensor<3x13x21xf64>
+ // FUNCBOUND: return %[[TRANSPOSE]] : tensor<3x13x21xf32>
+ return %0 : tensor<3x13x21xf64>
+}
+
+// -----
+
+// Non-aggressive mode: tosa.cond_if, the tosa.identity ops inside both of its
+// regions and the tosa.yield terminators are all narrowed to f32.
+module {
+// CHECK-LABEL: test_f64_regions
+func.func @test_f64_regions(%arg0: tensor<1xf64>, %arg1: tensor<1xf64>, %arg2: tensor<i1>) -> tensor<1xf64> {
+ // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xf64> {
+ // COMMON: %[[ID0:.*]] = tosa.identity %{{.*}} : (tensor<1xf32>) -> tensor<1xf32>
+ %1 = tosa.identity %arg0 : (tensor<1xf64>) -> tensor<1xf64>
+ // COMMON: tosa.yield %[[ID0]] : tensor<1xf32>
+ tosa.yield %1 : tensor<1xf64>
+ } else {
+ // COMMON: %[[ID1:.*]] = tosa.identity %{{.*}} : (tensor<1xf32>) -> tensor<1xf32>
+ %1 = tosa.identity %arg1 : (tensor<1xf64>) -> tensor<1xf64>
+ // COMMON: tosa.yield %[[ID1]] : tensor<1xf32>
+ tosa.yield %1 : tensor<1xf64>
+ }
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<1xf32>) -> tensor<1xf64>
+ // DEFAULT: return %[[OUT]] : tensor<1xf64>
+ // FUNCBOUND: return %[[IF_RESULT]] : tensor<1xf32>
+ return %0 : tensor<1xf64>
+}
+}
+
+// -----
+
+// Negative test: in non-aggressive mode an f64 tosa.add must make the pass
+// fail to legalize the op rather than narrow it.
+module {
+// CHECK-LABEL: test_f64_add_diagnostic
+func.func @test_f64_add_diagnostic(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xf64>, tensor<13x21x3xf64>) -> tensor<13x21x3xf64>
+ return %0 : tensor<13x21x3xf64>
+}
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
index 16265bd21a0d9..42e63346d8c33 100644
--- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
@@ -15,6 +15,18 @@ func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi
// -----
+// A standalone i64 tosa.const is narrowed to i32; the result is cast back to
+// i64 for the return unless function boundaries are converted as well.
+// CHECK-LABEL: test_i64_const
+func.func @test_i64_const() -> tensor<2xi64> {
+ // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64>
+ // DEFAULT: return %[[OUT]] : tensor<2xi64>
+ // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
+ return %0 : tensor<2xi64>
+}
+
+// -----
+
// CHECK-LABEL: test_i64_argmax_cast
func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> {
// COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
More information about the Mlir-commits
mailing list