[Mlir-commits] [mlir] [mlir][tosa] Extend narrowing pass (PR #170712)
Luke Hutton
llvmlistbot at llvm.org
Tue Dec 9 07:55:59 PST 2025
================
@@ -0,0 +1,558 @@
+//===- TosaNarrowTypes.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the TOSA narrowing passes that rewrite tensor element
+// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "llvm/ADT/APFloat.h"
+
+#include <limits>
+#include <type_traits>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+// Emit the tablegen-generated definitions for the two narrowing passes into
+// the mlir::tosa namespace.
+namespace mlir::tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace mlir::tosa
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// Selects which element-type narrowing a pass instantiation performs.
+enum class TosaNarrowKind {
+  Int64ToInt32,     // 64-bit integers -> i32
+  Float64ToFloat32, // f64 -> f32
+};
+
+// ---------------------------------------------------------------------------
+// Shared helpers
+// ---------------------------------------------------------------------------
+
+// Returns true when `type` is an integer element type that this narrowing
+// mode rewrites. Only Int64ToInt32 narrows integers, and it matches any
+// 64-bit integer (`isInteger(64)` checks the width only).
+template <TosaNarrowKind Kind>
+static bool isSourceInteger(IntegerType type) {
+  if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
+    return false;
+  else
+    return type.isInteger(64);
+}
+
+// Returns true when `type` is a float element type that this narrowing mode
+// rewrites. Only Float64ToFloat32 narrows floats, and only f64 qualifies.
+template <TosaNarrowKind Kind>
+static bool isSourceFloat(FloatType type) {
+  if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
+    return false;
+  else
+    return type.isF64();
+}
+
+// Narrows `type` to a 32-bit integer when it is a source integer for this
+// mode (Int64ToInt32 only); any other integer type is returned unchanged.
+//
+// isSourceInteger matches every 64-bit integer regardless of signedness, so
+// the source signedness is carried over: si64/ui64 become si32/ui32 instead
+// of silently collapsing to signless i32 and changing interpretation.
+template <TosaNarrowKind Kind>
+static Type convertInteger(IntegerType type) {
+  if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
+    if (isSourceInteger<Kind>(type))
+      return IntegerType::get(type.getContext(), 32, type.getSignedness());
+  }
+  return type;
+}
+
+// Returns f32 in place of `type` when it is a source float for this mode
+// (f64 under Float64ToFloat32); every other float type passes through.
+template <TosaNarrowKind Kind>
+static Type convertFloat(FloatType type) {
+  if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
+    if (isSourceFloat<Kind>(type))
+      return Float32Type::get(type.getContext());
+  }
+  return type;
+}
+
+// Returns true when the scalar `type` is an integer or float that this
+// narrowing mode rewrites; any other kind of type is never a source.
+template <TosaNarrowKind Kind>
+static bool isSourceElement(Type type) {
+  if (auto asFloat = dyn_cast<FloatType>(type))
+    return isSourceFloat<Kind>(asFloat);
+  auto asInt = dyn_cast<IntegerType>(type);
+  return asInt && isSourceInteger<Kind>(asInt);
+}
+
+// Maps a scalar element type to its narrowed counterpart by dispatching to
+// the float or integer converter; non-integer, non-float types are returned
+// as-is.
+template <TosaNarrowKind Kind>
+static Type convertElement(Type type) {
+  if (auto asFloat = dyn_cast<FloatType>(type))
+    return convertFloat<Kind>(asFloat);
+  if (auto asInt = dyn_cast<IntegerType>(type))
+    return convertInteger<Kind>(asInt);
+  return type;
+}
+
+// Returns true when `type` (a shaped type, judged by its element type, or
+// a plain scalar type) contains an element this narrowing mode rewrites.
+template <TosaNarrowKind Kind>
+static bool typeNeedsConversion(Type type) {
+  auto shapedTy = dyn_cast<ShapedType>(type);
+  Type elementTy = shapedTy ? shapedTy.getElementType() : type;
+  return isSourceElement<Kind>(elementTy);
+}
+
+// Narrows scalar constant attributes so they keep matching the converted
+// element types.
+template <TosaNarrowKind Kind>
+static bool tryConvertScalarAttribute(Attribute attribute,
----------------
lhutton1 wrote:
Rather than returning `bool`, it would be slightly cleaner to return `FailureOr<Attribute>`. That way the success check is more explicit — `if (succeeded(tryConvertScalarAttribute(...)))` — and we can return `resultAttr` directly here.
https://github.com/llvm/llvm-project/pull/170712
More information about the Mlir-commits
mailing list