[Mlir-commits] [mlir] [mlir][tosa] Extend narrowing pass (PR #170712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 4 10:14:05 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Vitalii Shutov (Lallapallooza)

<details>
<summary>Changes</summary>

- unify the i64->i32 and f64->f32 narrowing logic inside the shared implementation
- register tosa::ConstOp in the non-aggressive rewrite set so standalone constants are narrowed

---

Patch is 50.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170712.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+25) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1-1) 
- (removed) mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp (-310) 
- (added) mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp (+558) 
- (added) mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32-aggressive.mlir (+70) 
- (added) mlir/test/Dialect/Tosa/tosa-narrow-f64-to-f32.mlir (+160) 
- (modified) mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir (+12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 12f520297b702..4a5f283bc66c8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -198,4 +198,29 @@ def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
   ];
 }
 
+def TosaNarrowF64ToF32Pass : Pass<"tosa-narrow-f64-to-f32", "func::FuncOp"> {
+  let summary = "Narrow F64 TOSA operations to F32";
+  let description = [{
+    This pass narrows TOSA operations with 64-bit floating-point tensor types to
+    32-bit floating-point tensor types. While TOSA itself has no double
+    precision support, upstream conversions or frontends may still materialize
+    F64 tensors temporarily, so this pass removes them before handing off to a
+    TOSA backend.
+  }];
+
+  let options = [
+    Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
+      "If enabled, all TOSA operations are rewritten, regardless of whether the narrowing "
+      "is safe. This option may lead to data loss if not used carefully.">,
+    Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
+      "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
+      "be inserted at the I/O boundaries.">
+  ];
+
+  let dependentDialects = [
+    "func::FuncDialect",
+    "tosa::TosaDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 091b481d6394b..0ff68b1bb54f4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -13,7 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaTypeConverters.cpp
   TosaProfileCompliance.cpp
   TosaValidation.cpp
-  TosaNarrowI64ToI32.cpp
+  TosaNarrowTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
deleted file mode 100644
index ddaf7d8a5e033..0000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
+++ /dev/null
@@ -1,310 +0,0 @@
-//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This pass narrows TOSA operations with 64-bit integer tensor types to
-// 32-bit integer tensor types. This can be useful for backends that do not
-// support the EXT-INT64 extension of TOSA. The pass has two options:
-//
-// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
-//     regardless or whether the narrowing is safe. This option may lead to
-//     data loss if not used carefully.
-// - convert-function-boundaries - If enabled, the pass will convert function
-//     I/O types as well. Otherwise casts will be inserted at the I/O
-//     boundaries.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/IR/Verifier.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace tosa {
-#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
-#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
-} // namespace tosa
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-LogicalResult convertGenericOp(Operation *op, ValueRange operands,
-                               ConversionPatternRewriter &rewriter,
-                               const TypeConverter *typeConverter) {
-  // Convert types of results
-  SmallVector<Type, 4> newResults;
-  if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
-    return failure();
-
-  // Create a new operation state
-  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                       newResults, {}, op->getSuccessors());
-
-  for (const NamedAttribute &namedAttribute : op->getAttrs()) {
-    const Attribute attribute = namedAttribute.getValue();
-
-    // Convert integer attribute type
-    if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
-      const std::optional<Attribute> convertedAttribute =
-          typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
-      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
-      continue;
-    }
-
-    if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
-      Type type = typeAttr.getValue();
-      const std::optional<Attribute> convertedAttribute =
-          typeConverter->convertTypeAttribute(type, attribute);
-      if (!convertedAttribute)
-        return rewriter.notifyMatchFailure(op,
-                                           "Failed to convert type attribute.");
-      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
-      continue;
-    }
-
-    if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
-      const Type type = denseElementsAttr.getType();
-      const std::optional<Attribute> convertedAttribute =
-          typeConverter->convertTypeAttribute(type, denseElementsAttr);
-      if (!convertedAttribute)
-        return rewriter.notifyMatchFailure(
-            op, "Failed to convert dense elements attribute.");
-      state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
-      continue;
-    }
-
-    state.addAttribute(namedAttribute.getName(), attribute);
-  }
-
-  for (Region &region : op->getRegions()) {
-    Region *newRegion = state.addRegion();
-    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-    if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
-      return failure();
-  }
-
-  Operation *newOp = rewriter.create(state);
-  rewriter.replaceOp(op, newOp->getResults());
-  return success();
-}
-
-// ===========================
-// Aggressive rewrite patterns
-// ===========================
-
-class ConvertGenericOp : public ConversionPattern {
-public:
-  ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
-      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (!isa<tosa::TosaOp>(op))
-      return rewriter.notifyMatchFailure(
-          op,
-          "Support for operations other than TOSA has not been implemented.");
-
-    return convertGenericOp(op, operands, rewriter, typeConverter);
-  }
-};
-
-// ===============================
-// Bounds checked rewrite patterns
-// ===============================
-
-class ConvertArgMaxOpWithBoundsChecking
-    : public OpConversionPattern<tosa::ArgMaxOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    // Output type can be narrowed based on the size of the axis dimension
-    const int32_t axis = op.getAxis();
-    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
-    if (!inputType || !inputType.isStaticDim(axis))
-      return rewriter.notifyMatchFailure(
-          op, "Requires a static axis dimension for bounds checking.");
-    const int64_t axisDim = inputType.getDimSize(axis);
-    if (axisDim >= std::numeric_limits<int32_t>::max())
-      return rewriter.notifyMatchFailure(
-          op, "Axis dimension is too large to narrow safely.");
-
-    const Type resultType = op.getOutput().getType();
-    const Type newResultType = typeConverter->convertType(resultType);
-    rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
-                                                adaptor.getInput(), axis);
-    return success();
-  }
-};
-
-class ConvertCastOpWithBoundsChecking
-    : public OpConversionPattern<tosa::CastOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
-    const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
-    if (!inputType || !resultType)
-      return failure();
-
-    const auto elementInputIntType =
-        dyn_cast<IntegerType>(inputType.getElementType());
-    const auto elementResultIntType =
-        dyn_cast<IntegerType>(resultType.getElementType());
-    if (elementInputIntType && elementResultIntType &&
-        elementInputIntType.getWidth() > elementResultIntType.getWidth())
-      return rewriter.notifyMatchFailure(
-          op, "Narrowing cast may lead to data loss.");
-
-    rewriter.replaceOpWithNewOp<tosa::CastOp>(
-        op, typeConverter->convertType(resultType), adaptor.getInput());
-    return success();
-  }
-};
-
-template <typename OpTy>
-class ConvertTypedOp : public OpConversionPattern<OpTy> {
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    return convertGenericOp(op, adaptor.getOperands(), rewriter,
-                            this->getTypeConverter());
-  }
-};
-
-struct TosaNarrowI64ToI32
-    : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
-public:
-  explicit TosaNarrowI64ToI32() = default;
-  explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
-      : TosaNarrowI64ToI32() {
-    this->aggressiveRewrite = options.aggressiveRewrite;
-    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
-  }
-
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-
-    TypeConverter typeConverter;
-    typeConverter.addConversion([](Type type) -> Type { return type; });
-    typeConverter.addConversion([](IntegerType type) -> Type {
-      if (!type.isInteger(64))
-        return type;
-      return IntegerType::get(type.getContext(), 32);
-    });
-    typeConverter.addConversion(
-        [&typeConverter](RankedTensorType type) -> Type {
-          const Type elementType = type.getElementType();
-          if (!elementType.isInteger(64))
-            return type;
-          return RankedTensorType::get(type.getShape(),
-                                       typeConverter.convertType(elementType));
-        });
-
-    const auto materializeCast = [](OpBuilder &builder, Type resultType,
-                                    ValueRange inputs, Location loc) -> Value {
-      if (inputs.size() != 1)
-        return Value();
-      return tosa::CastOp::create(builder, loc, resultType, inputs.front());
-    };
-    typeConverter.addSourceMaterialization(materializeCast);
-    typeConverter.addTargetMaterialization(materializeCast);
-
-    typeConverter.addTypeAttributeConversion(
-        [](IntegerType type, IntegerAttr attribute) -> Attribute {
-          const APInt value = attribute.getValue().truncSSat(32);
-          return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
-                                  value);
-        });
-    typeConverter.addTypeAttributeConversion(
-        [&typeConverter](ShapedType type,
-                         DenseIntElementsAttr attr) -> Attribute {
-          const ShapedType newType =
-              cast<ShapedType>(typeConverter.convertType(type));
-          const auto oldElementType = cast<IntegerType>(type.getElementType());
-          const auto newElementType =
-              cast<IntegerType>(newType.getElementType());
-          if (oldElementType.getWidth() == newElementType.getWidth())
-            return attr;
-
-          DenseElementsAttr mapped =
-              attr.mapValues(newElementType, [&](const APInt &v) {
-                return v.truncSSat(newElementType.getWidth());
-              });
-          return mapped;
-        });
-
-    ConversionTarget target(*context);
-    target.addDynamicallyLegalDialect<tosa::TosaDialect>(
-        [&typeConverter](Operation *op) {
-          return typeConverter.isLegal(op->getResultTypes()) &&
-                 typeConverter.isLegal(op->getOperandTypes());
-        });
-    if (convertFunctionBoundaries) {
-      target.addDynamicallyLegalOp<func::FuncOp>(
-          [&typeConverter](func::FuncOp op) {
-            return typeConverter.isSignatureLegal(op.getFunctionType()) &&
-                   typeConverter.isLegal(&op.getBody());
-          });
-      target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
-        const FunctionType funcType =
-            op->getParentOfType<func::FuncOp>().getFunctionType();
-        return llvm::equal(op.getOperandTypes(), funcType.getResults());
-      });
-    } else {
-      target.addDynamicallyLegalOp<func::FuncOp>(
-          [](func::FuncOp op) { return true; });
-      target.addDynamicallyLegalOp<func::ReturnOp>(
-          [](func::ReturnOp op) { return true; });
-    }
-
-    RewritePatternSet patterns(context);
-    if (convertFunctionBoundaries) {
-      populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-          patterns, typeConverter);
-      populateReturnOpTypeConversionPattern(patterns, typeConverter);
-    }
-    if (aggressiveRewrite) {
-      patterns.add<ConvertGenericOp>(typeConverter, context);
-    } else {
-      // Tensor
-      patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
-      // Data layout
-      patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
-      // Type conversion
-      patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
-      // Controlflow
-      patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
-      patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
-    }
-
-    if (failed(
-            applyFullConversion(getOperation(), target, std::move(patterns))))
-      signalPassFailure();
-  }
-};
-
-} // namespace
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
new file mode 100644
index 0000000000000..0895bb3905d67
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowTypes.cpp
@@ -0,0 +1,558 @@
+//===- TosaNarrowTypes.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TOSA narrowing passes that rewrite tensor element
+// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "llvm/ADT/APFloat.h"
+
+#include <limits>
+#include <type_traits>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+// Emit the generated pass base classes for both narrowing passes; both are
+// implemented on top of the shared helpers in this file.
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// Selects which element-type narrowing the shared helpers below perform:
+// i64 -> i32 (Int64ToInt32) or f64 -> f32 (Float64ToFloat32).
+enum class TosaNarrowKind { Int64ToInt32, Float64ToFloat32 };
+
+// ---------------------------------------------------------------------------
+// Shared helpers
+// ---------------------------------------------------------------------------
+
+// Returns true if `type` is the integer type narrowed under `Kind` (i64 for
+// Int64ToInt32). Always false for kinds that do not narrow integers.
+template <TosaNarrowKind Kind>
+static bool isSourceInteger(IntegerType type) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+    return type.isInteger(64);
+  return false;
+}
+
+// Returns true if `type` is the float type narrowed under `Kind` (f64 for
+// Float64ToFloat32). Always false for kinds that do not narrow floats.
+template <TosaNarrowKind Kind>
+static bool isSourceFloat(FloatType type) {
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
+    return type.isF64();
+  return false;
+}
+
+// Maps a narrowable integer type to its narrowed form (i64 -> i32 under
+// Int64ToInt32). Any integer type that is not narrowed is returned unchanged.
+template <TosaNarrowKind Kind>
+static Type convertInteger(IntegerType type) {
+  if (!isSourceInteger<Kind>(type))
+    return type;
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
+    return IntegerType::get(type.getContext(), 32);
+  return type;
+}
+
+// Maps a narrowable float type to its narrowed form (f64 -> f32 under
+// Float64ToFloat32). Any float type that is not narrowed is returned unchanged.
+template <TosaNarrowKind Kind>
+static Type convertFloat(FloatType type) {
+  if (!isSourceFloat<Kind>(type))
+    return type;
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
+    return Float32Type::get(type.getContext());
+  return type;
+}
+
+// Returns true if the scalar element type `type` is narrowed under `Kind`.
+// Dispatches to the integer/float predicates; other types are never narrowed.
+template <TosaNarrowKind Kind>
+static bool isSourceElement(Type type) {
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    return isSourceInteger<Kind>(intTy);
+  if (auto floatTy = dyn_cast<FloatType>(type))
+    return isSourceFloat<Kind>(floatTy);
+  return false;
+}
+
+// Returns the narrowed form of the scalar element type `type` under `Kind`.
+// If `type` is not narrowed (including non-integer, non-float types), it is
+// returned unchanged.
+template <TosaNarrowKind Kind>
+static Type convertElement(Type type) {
+  if (auto intTy = dyn_cast<IntegerType>(type))
+    return convertInteger<Kind>(intTy);
+  if (auto floatTy = dyn_cast<FloatType>(type))
+    return convertFloat<Kind>(floatTy);
+  return type;
+}
+
+// Returns true if `type` must be rewritten under `Kind`. For shaped types only
+// the element type is inspected; any other type is checked directly as a
+// scalar element type.
+template <TosaNarrowKind Kind>
+static bool typeNeedsConversion(Type type) {
+  if (auto shaped = dyn_cast<ShapedType>(type))
+    return isSourceElement<Kind>(shaped.getElementType());
+  return isSourceElement<Kind>(type);
+}
+
+// Narrows scalar constant attributes so they keep matching the converted
+// element types.
+//
+// Returns true and sets `resultAttr` when `attribute` holds a scalar of the
+// source type for `Kind`:
+//   - Int64ToInt32: an i64 IntegerAttr is rebuilt as i32. Out-of-range values
+//     saturate (APInt::truncSSat) instead of wrapping.
+//   - Float64ToFloat32: an f64 FloatAttr is rounded to f32 with
+//     round-to-nearest-ties-to-even. The `losesInfo` result is not inspected,
+//     so precision loss is not reported.
+// Returns false, leaving `resultAttr` untouched, for any other attribute.
+// NOTE(review): unlike the integer path, out-of-range f64 values are not
+// clamped here; confirm that APFloat's overflow result (+/-inf) is intended.
+template <TosaNarrowKind Kind>
+static bool tryConvertScalarAttribute(Attribute attribute,
+                                      Attribute &resultAttr) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+      if (const auto intType = dyn_cast<IntegerType>(intAttr.getType());
+          intType && isSourceInteger<Kind>(intType)) {
+        const auto convertedType =
+            cast<IntegerType>(convertInteger<Kind>(intType));
+        const APInt truncated =
+            intAttr.getValue().truncSSat(convertedType.getWidth());
+        resultAttr = IntegerAttr::get(convertedType, truncated);
+        return true;
+      }
+    }
+  }
+
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    if (const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
+      if (const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
+          floatType && isSourceFloat<Kind>(floatType)) {
+        const auto convertedType =
+            cast<FloatType>(convertFloat<Kind>(floatType));
+        APFloat value = floatAttr.getValue();
+        bool losesInfo = false;
+        value.convert(convertedType.getFloatSemantics(),
+                      APFloat::rmNearestTiesToEven, &losesInfo);
+        resultAttr = FloatAttr::get(convertedType, value);
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+template <TosaNarrowKind Kind, typename AttrT>
+static LogicalResult
+convertAttributeWithTypeConverter(AttrT attr, Type type,
+                                  const TypeConverter *typeConverter,
+                                  Attribute &resultAttr) {
+  if (!typeNeedsConversion<Kind>(type)) {
+    resultAttr = attr;
+    return success();
+  }
+
+  cons...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/170712


More information about the Mlir-commits mailing list