[Mlir-commits] [mlir] [mlir][tosa] Extend narrowing pass (PR #170712)
Luke Hutton
llvmlistbot at llvm.org
Tue Dec 9 07:55:59 PST 2025
================
@@ -0,0 +1,558 @@
+//===- TosaNarrowTypes.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TOSA narrowing passes that rewrite tensor element
+// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "llvm/ADT/APFloat.h"
+
+#include <limits>
+#include <type_traits>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+// Pull in the generated base classes for both narrowing passes.
+namespace mlir::tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace mlir::tosa
+
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// Selects which element-type narrowing a pass instance performs.
+enum class TosaNarrowKind {
+  Int64ToInt32,     // Rewrite i64 element types to i32.
+  Float64ToFloat32, // Rewrite f64 element types to f32.
+};
+
+// ---------------------------------------------------------------------------
+// Shared helpers
+// ---------------------------------------------------------------------------
+
+// True when `type` is the integer type narrowed by mode `Kind`. Only the
+// i64 -> i32 mode has an integer source type (i64); every other mode leaves
+// integers alone.
+template <TosaNarrowKind Kind>
+static bool isSourceInteger(IntegerType type) {
+  if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
+    return false;
+  else
+    return type.isInteger(64);
+}
+
+// True when `type` is the float type narrowed by mode `Kind`. Only the
+// f64 -> f32 mode has a float source type (f64); every other mode leaves
+// floats alone.
+template <TosaNarrowKind Kind>
+static bool isSourceFloat(FloatType type) {
+  if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
+    return false;
+  else
+    return type.isF64();
+}
+
+// Maps an integer type to its narrowed counterpart: i64 becomes i32 in the
+// i64 -> i32 mode. Any other integer type, or any integer in another mode, is
+// returned unchanged.
+template <TosaNarrowKind Kind>
+static Type convertInteger(IntegerType type) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    if (isSourceInteger<Kind>(type))
+      return IntegerType::get(type.getContext(), 32);
+  }
+  return type;
+}
+
+// Maps a float type to its narrowed counterpart: f64 becomes f32 in the
+// f64 -> f32 mode. Any other float type, or any float in another mode, is
+// returned unchanged.
+template <TosaNarrowKind Kind>
+static Type convertFloat(FloatType type) {
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    if (isSourceFloat<Kind>(type))
+      return Float32Type::get(type.getContext());
+  }
+  return type;
+}
+
+// True when scalar `type` is an integer or float type that mode `Kind`
+// narrows; false for every other kind of type.
+template <TosaNarrowKind Kind>
+static bool isSourceElement(Type type) {
+  if (const auto integerType = dyn_cast<IntegerType>(type))
+    return isSourceInteger<Kind>(integerType);
+  const auto floatType = dyn_cast<FloatType>(type);
+  return floatType && isSourceFloat<Kind>(floatType);
+}
+
+// Narrows a scalar element type according to `Kind`, dispatching on whether
+// it is an integer or a float type.
+template <TosaNarrowKind Kind>
+static Type convertElement(Type type) {
+  if (const auto integerType = dyn_cast<IntegerType>(type))
+    return convertInteger<Kind>(integerType);
+  if (const auto floatType = dyn_cast<FloatType>(type))
+    return convertFloat<Kind>(floatType);
+  // Neither integer nor float: nothing to narrow.
+  return type;
+}
+
+// True when `type` (or, for a shaped type, its element type) is narrowed by
+// mode `Kind`.
+template <TosaNarrowKind Kind>
+static bool typeNeedsConversion(Type type) {
+  const auto shapedType = dyn_cast<ShapedType>(type);
+  const Type elementType = shapedType ? shapedType.getElementType() : type;
+  return isSourceElement<Kind>(elementType);
+}
+
+// Narrows a scalar IntegerAttr/FloatAttr whose type is a narrowing source so
+// that it keeps matching the converted element type.
+//  - i64 -> i32: the value is truncated with signed saturation
+//    (APInt::truncSSat).
+//  - f64 -> f32: the value is rounded to nearest, ties to even.
+//    NOTE(review): unlike the integer path this does not saturate; f64 values
+//    outside the f32 range presumably become +/-infinity -- confirm intended.
+// Returns true and writes `resultAttr` when a conversion was made; otherwise
+// returns false and leaves `resultAttr` untouched.
+template <TosaNarrowKind Kind>
+static bool tryConvertScalarAttribute(Attribute attribute,
+                                      Attribute &resultAttr) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    const auto intAttr = dyn_cast<IntegerAttr>(attribute);
+    if (!intAttr)
+      return false;
+    const auto sourceType = dyn_cast<IntegerType>(intAttr.getType());
+    if (!sourceType || !isSourceInteger<Kind>(sourceType))
+      return false;
+
+    const auto targetType =
+        cast<IntegerType>(convertInteger<Kind>(sourceType));
+    const APInt saturated =
+        intAttr.getValue().truncSSat(targetType.getWidth());
+    resultAttr = IntegerAttr::get(targetType, saturated);
+    return true;
+  } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    const auto floatAttr = dyn_cast<FloatAttr>(attribute);
+    if (!floatAttr)
+      return false;
+    const auto sourceType = dyn_cast<FloatType>(floatAttr.getType());
+    if (!sourceType || !isSourceFloat<Kind>(sourceType))
+      return false;
+
+    const auto targetType = cast<FloatType>(convertFloat<Kind>(sourceType));
+    APFloat narrowed = floatAttr.getValue();
+    // APFloat::convert requires this out-parameter; precision loss is
+    // accepted here, so its value is not inspected.
+    bool losesInfo = false;
+    narrowed.convert(targetType.getFloatSemantics(),
+                     APFloat::rmNearestTiesToEven, &losesInfo);
+    resultAttr = FloatAttr::get(targetType, narrowed);
+    return true;
+  } else {
+    return false;
+  }
+}
+
+// Produces the attribute to use in place of `attr`, whose associated type is
+// `type`. When `type` involves no narrowing source element type, `attr` is
+// forwarded unchanged. Otherwise the result comes from
+// `typeConverter->convertTypeAttribute(type, attr)`. Fails when that call
+// yields no attribute.
+template <TosaNarrowKind Kind, typename AttrT>
+static LogicalResult
+convertAttributeWithTypeConverter(AttrT attr, Type type,
+                                  const TypeConverter *typeConverter,
+                                  Attribute &resultAttr) {
+  if (typeNeedsConversion<Kind>(type)) {
+    std::optional<Attribute> converted =
+        typeConverter->convertTypeAttribute(type, attr);
+    if (!converted.has_value())
+      return failure();
+    resultAttr = *converted;
+  } else {
+    resultAttr = attr;
+  }
+  return success();
+}
+
+// Reports a match failure when a cast narrows the element kind this mode
+// converts: integer -> narrower integer in i64 -> i32 mode, float -> narrower
+// float in f64 -> f32 mode. Casts between different element kinds, or
+// non-narrowing casts, succeed.
+// NOTE(review): this helper does not check any aggressive-mode option itself;
+// presumably callers skip it when aggressive narrowing is enabled -- confirm
+// at the call sites.
+template <TosaNarrowKind Kind>
+static LogicalResult
+verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
+                               ShapedType resultType,
+                               ConversionPatternRewriter &rewriter) {
+  const Type inputElement = inputType.getElementType();
+  const Type resultElement = resultType.getElementType();
+
+  bool narrowsConvertedKind = false;
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    narrowsConvertedKind = isa<IntegerType>(inputElement) &&
+                           isa<IntegerType>(resultElement) &&
+                           cast<IntegerType>(inputElement).getWidth() >
+                               cast<IntegerType>(resultElement).getWidth();
+  } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    narrowsConvertedKind = isa<FloatType>(inputElement) &&
+                           isa<FloatType>(resultElement) &&
+                           inputElement.getIntOrFloatBitWidth() >
+                               resultElement.getIntOrFloatBitWidth();
+  }
+
+  if (narrowsConvertedKind)
+    return rewriter.notifyMatchFailure(
+        op, "Narrowing cast may lead to data loss.");
+  return success();
+}
+
+// Rebuilds dense attribute `attr` (of shaped type `type`) with the narrowed
+// element type, transforming every element through
+// `mapValueFn(newElementType, element)`. The result of that call is handed to
+// `attr.mapValues`. `attr` is returned unchanged when any of these holds:
+//  - the element type is not an `ElementTypeT`;
+//  - the element type is not a narrowing source for `Kind`;
+//  - the type converter does not produce a shaped type with an
+//    `ElementTypeT` element.
+template <TosaNarrowKind Kind, typename DenseAttrT, typename ElementTypeT,
+          typename MapValueFn>
+static Attribute convertDenseElementsAttr(ShapedType type, DenseAttrT attr,
+                                          const TypeConverter &typeConverter,
+                                          MapValueFn &&mapValueFn) {
+  const auto sourceElementType =
+      dyn_cast<ElementTypeT>(type.getElementType());
+  if (!sourceElementType)
+    return attr;
+
+  bool isNarrowingSource = false;
+  if constexpr (std::is_same_v<ElementTypeT, IntegerType>)
+    isNarrowingSource = isSourceInteger<Kind>(sourceElementType);
+  else
+    isNarrowingSource = isSourceFloat<Kind>(sourceElementType);
+  if (!isNarrowingSource)
+    return attr;
+
+  const auto narrowedType =
+      dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
+  if (!narrowedType)
+    return attr;
+
+  const auto narrowedElementType =
+      dyn_cast<ElementTypeT>(narrowedType.getElementType());
+  if (!narrowedElementType)
+    return attr;
+
+  return attr.mapValues(narrowedElementType, [&](const auto &element) {
+    return mapValueFn(narrowedElementType, element);
+  });
+}
+
+// ---------------------------------------------------------------------------
+// Conversion patterns
+// ---------------------------------------------------------------------------
+
+// Applies the narrowing TypeConverter to a single TOSA op, including its
+// attributes and nested regions.
+//
+// The op is rebuilt from the (already converted) `operands`, the converted
+// result types, and narrowed attributes. Its regions are then moved into the
+// new op and type-converted, and `op` is replaced. Attribute handling:
+//  - IntegerAttr/FloatAttr whose type is a narrowing source are narrowed by
+//    tryConvertScalarAttribute;
+//  - TypeAttr and DenseElementsAttr go through
+//    convertAttributeWithTypeConverter;
+//  - all other attributes are copied unchanged.
+// Every failure path reports a match-failure reason on `op`.
+template <TosaNarrowKind Kind>
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+                               ConversionPatternRewriter &rewriter,
+                               const TypeConverter *typeConverter) {
+  SmallVector<Type, 4> newResults;
+  if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+    return rewriter.notifyMatchFailure(op, "Failed to convert result types.");
+
+  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                       newResults, {}, op->getSuccessors());
+
+  // Keep attribute payloads consistent with the converted element types.
+  for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+    const Attribute attribute = namedAttribute.getValue();
+
+    Attribute convertedAttr;
+    if (tryConvertScalarAttribute<Kind>(attribute, convertedAttr)) {
+      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      continue;
+    }
+
+    if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+      if (failed(convertAttributeWithTypeConverter<Kind>(
+              typeAttr, typeAttr.getValue(), typeConverter, convertedAttr)))
+        return rewriter.notifyMatchFailure(op,
+                                           "Failed to convert type attribute.");
+      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      continue;
+    }
+
+    if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+      if (failed(convertAttributeWithTypeConverter<Kind>(
+              denseElementsAttr, denseElementsAttr.getType(), typeConverter,
+              convertedAttr)))
+        return rewriter.notifyMatchFailure(
+            op, "Failed to convert dense elements attribute.");
+      state.addAttribute(namedAttribute.getName(), convertedAttr);
+      continue;
+    }
+
+    state.addAttribute(namedAttribute.getName(), attribute);
+  }
+
+  // Move each region into the new op, then convert its block argument types.
+  for (Region &region : op->getRegions()) {
+    Region *newRegion = state.addRegion();
+    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+    if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+      return rewriter.notifyMatchFailure(op, "Failed to convert region types.");
+  }
+
+  Operation *newOp = rewriter.create(state);
+  rewriter.replaceOp(op, newOp->getResults());
+  return success();
+}
+
+// Catch-all pattern (MatchAnyOpTypeTag, benefit 0). It narrows any TOSA
+// operation through convertGenericOp and reports a match failure for
+// non-TOSA operations.
+template <TosaNarrowKind Kind>
+class ConvertGenericOp : public ConversionPattern {
+public:
+  ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, /*benefit=*/0,
+                          context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (isa<tosa::TosaOp>(op))
+      return convertGenericOp<Kind>(op, operands, rewriter,
+                                    getTypeConverter());
+
+    return rewriter.notifyMatchFailure(
+        op,
+        "Support for operations other than TOSA has not been implemented.");
+  }
+};
+
+// Op-specific counterpart of ConvertGenericOp. It matches only `OpTy` and
+// forwards the adaptor's converted operands to convertGenericOp.
+template <typename OpTy, TosaNarrowKind Kind>
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    const TypeConverter *converter = this->getTypeConverter();
+    return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
+                                  converter);
+  }
+};
+
+// ---------------------------------------------------------------------------
+// Kind-specific helpers and patterns
+// ---------------------------------------------------------------------------
+
+// Casts get extra checking so we only narrow when it is provably safe.
----------------
lhutton1 wrote:
nit: s/provably/probably
https://github.com/llvm/llvm-project/pull/170712
More information about the Mlir-commits
mailing list