[Mlir-commits] [mlir] [mlir][tosa] Add a pass to narrow i64 to i32 (PR #165581)
Luke Hutton
llvmlistbot at llvm.org
Wed Oct 29 08:24:30 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/165581
This pass aims to narrow i64 types on TOSA operations to i32. It comes with the following options:
- "aggressive-rewrite" - This option is typically able to narrow more values, but may impact numerical behaviour if not used carefully.
- "convert-function-boundaries" - If enabled, function parameters and results may be narrowed. Otherwise, casts are inserted to preserve the function's I/O types.
Currently, the non-aggressive mode is very limited: it targets an argmax -> cast sequence that has been observed during legalization, as well as some data layout operations that can always be narrowed safely. Support for more operations will be added in the future.
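As a rough illustration of what "safe" means in the non-aggressive mode, the sketch below restates the two bounds checks from the patch's `ConvertArgMaxOpWithBoundsChecking` and `ConvertCastOpWithBoundsChecking` patterns as plain C++ predicates. It is a standalone model, not MLIR code: `canNarrowArgMax` and `canRewriteCast` are hypothetical names, and integer widths are passed in directly instead of being read from `ShapedType`s.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical model of the argmax check: tosa.argmax produces indices in
// [0, axisDim), so an i32 result is safe when the axis dimension is static
// and below INT32_MAX (the patch rejects axisDim >= INT32_MAX).
bool canNarrowArgMax(std::optional<int64_t> staticAxisDim) {
  if (!staticAxisDim) // dynamic dimension: no bound on the index is known
    return false;
  return *staticAxisDim < std::numeric_limits<int32_t>::max();
}

// Hypothetical model of the cast check. `inputIntWidth` is the width of the
// cast's (already converted) integer input and `resultIntWidth` the width of
// its original integer result; std::nullopt stands for a non-integer type
// such as f32. A cast that would narrow an integer is left alone, since the
// narrowing itself could drop bits.
bool canRewriteCast(std::optional<unsigned> inputIntWidth,
                    std::optional<unsigned> resultIntWidth) {
  if (!inputIntWidth || !resultIntWidth)
    return true;
  return *inputIntWidth <= *resultIntWidth;
}
```

Under this model, the argmax over a 19-element axis in the tests can be narrowed, while the one over a 2147483650-element axis cannot, which matches the `failed to legalize operation 'tosa.argmax'` diagnostic expected in the non-aggressive test.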
Co-authored-by: Vitalii Shutov <vitalii.shutov at arm.com>
Co-authored-by: Shubham <shubham at arm.com>
Co-authored-by: Declan Flavin <declan.flavin at arm.com>
>From 36e6bd5f68548f80cc2ecec6dd0691ba68630ba6 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 14 Oct 2025 09:36:26 +0100
Subject: [PATCH] [mlir][tosa] Add a pass to narrow i64 to i32
This pass aims to narrow i64 types on TOSA operations
to i32. It comes with the following options:
- "aggressive-rewrite" - This option is typically able to
narrow more values, but may impact numerical behaviour
if not used carefully.
- "convert-function-boundaries" - If enabled, function
parameters and results may be narrowed. Otherwise, casts
are inserted to preserve the function's I/O types.
Currently, the non-aggressive mode is very limited: it targets
an argmax -> cast sequence that has been observed during
legalization, as well as some data layout operations that can
always be narrowed safely. Support for more operations will be added in
the future.
Co-authored-by: Vitalii Shutov <vitalii.shutov at arm.com>
Co-authored-by: Shubham <shubham at arm.com>
Co-authored-by: Declan Flavin <declan.flavin at arm.com>
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: Ia8a766c88e6f8e8d019bce6c47114ce3b8a06969
---
.../mlir/Dialect/Tosa/Transforms/Passes.td | 23 ++
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 +
.../Tosa/Transforms/TosaNarrowI64ToI32.cpp | 310 ++++++++++++++++++
.../tosa-narrow-i64-to-i32-aggressive.mlir | 81 +++++
.../Dialect/Tosa/tosa-narrow-i64-to-i32.mlir | 162 +++++++++
5 files changed, 577 insertions(+)
create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
create mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
create mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 14b00b04ccc18..420e58192b8fd 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -166,4 +166,27 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
}
+def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
+ let summary = "Narrow I64 TOSA operations to I32";
+ let description = [{
+ This pass narrows TOSA operations with 64-bit integer tensor types to
+ 32-bit integer tensor types. This can be useful for backends that do not
+ support the EXT-INT64 extension of TOSA.
+ }];
+
+ let options = [
+ Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
+ "If enabled, all TOSA operations are rewritten, regardless of whether the narrowing "
+ "is safe. This option may lead to data loss if not used carefully.">,
+ Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
+ "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
+ "be inserted at the I/O boundaries.">
+ ];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 41b338d6e7189..987ce4ed870c9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
+ TosaNarrowI64ToI32.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
new file mode 100644
index 0000000000000..ddaf7d8a5e033
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
@@ -0,0 +1,310 @@
+//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass narrows TOSA operations with 64-bit integer tensor types to
+// 32-bit integer tensor types. This can be useful for backends that do not
+// support the EXT-INT64 extension of TOSA. The pass has two options:
+//
+// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
+// regardless of whether the narrowing is safe. This option may lead to
+// data loss if not used carefully.
+// - convert-function-boundaries - If enabled, the pass will convert function
+// I/O types as well. Otherwise casts will be inserted at the I/O
+// boundaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+LogicalResult convertGenericOp(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter,
+ const TypeConverter *typeConverter) {
+ // Convert types of results
+ SmallVector<Type, 4> newResults;
+ if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
+ return failure();
+
+ // Create a new operation state
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, {}, op->getSuccessors());
+
+ for (const NamedAttribute &namedAttribute : op->getAttrs()) {
+ const Attribute attribute = namedAttribute.getValue();
+
+ // Convert integer attribute type
+ if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
+ Type type = typeAttr.getValue();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, attribute);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(op,
+ "Failed to convert type attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
+ const Type type = denseElementsAttr.getType();
+ const std::optional<Attribute> convertedAttribute =
+ typeConverter->convertTypeAttribute(type, denseElementsAttr);
+ if (!convertedAttribute)
+ return rewriter.notifyMatchFailure(
+ op, "Failed to convert dense elements attribute.");
+ state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
+ continue;
+ }
+
+ state.addAttribute(namedAttribute.getName(), attribute);
+ }
+
+ for (Region &region : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
+ return failure();
+ }
+
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
+}
+
+// ===========================
+// Aggressive rewrite patterns
+// ===========================
+
+class ConvertGenericOp : public ConversionPattern {
+public:
+ ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
+ : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!isa<tosa::TosaOp>(op))
+ return rewriter.notifyMatchFailure(
+ op,
+ "Support for operations other than TOSA has not been implemented.");
+
+ return convertGenericOp(op, operands, rewriter, typeConverter);
+ }
+};
+
+// ===============================
+// Bounds checked rewrite patterns
+// ===============================
+
+class ConvertArgMaxOpWithBoundsChecking
+ : public OpConversionPattern<tosa::ArgMaxOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // Output type can be narrowed based on the size of the axis dimension
+ const int32_t axis = op.getAxis();
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ if (!inputType || !inputType.isStaticDim(axis))
+ return rewriter.notifyMatchFailure(
+ op, "Requires a static axis dimension for bounds checking.");
+ const int64_t axisDim = inputType.getDimSize(axis);
+ if (axisDim >= std::numeric_limits<int32_t>::max())
+ return rewriter.notifyMatchFailure(
+ op, "Axis dimension is too large to narrow safely.");
+
+ const Type resultType = op.getOutput().getType();
+ const Type newResultType = typeConverter->convertType(resultType);
+ rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
+ adaptor.getInput(), axis);
+ return success();
+ }
+};
+
+class ConvertCastOpWithBoundsChecking
+ : public OpConversionPattern<tosa::CastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
+ const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+ if (!inputType || !resultType)
+ return failure();
+
+ const auto elementInputIntType =
+ dyn_cast<IntegerType>(inputType.getElementType());
+ const auto elementResultIntType =
+ dyn_cast<IntegerType>(resultType.getElementType());
+ if (elementInputIntType && elementResultIntType &&
+ elementInputIntType.getWidth() > elementResultIntType.getWidth())
+ return rewriter.notifyMatchFailure(
+ op, "Narrowing cast may lead to data loss.");
+
+ rewriter.replaceOpWithNewOp<tosa::CastOp>(
+ op, typeConverter->convertType(resultType), adaptor.getInput());
+ return success();
+ }
+};
+
+template <typename OpTy>
+class ConvertTypedOp : public OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ return convertGenericOp(op, adaptor.getOperands(), rewriter,
+ this->getTypeConverter());
+ }
+};
+
+struct TosaNarrowI64ToI32
+ : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
+public:
+ explicit TosaNarrowI64ToI32() = default;
+ explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
+ : TosaNarrowI64ToI32() {
+ this->aggressiveRewrite = options.aggressiveRewrite;
+ this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) -> Type { return type; });
+ typeConverter.addConversion([](IntegerType type) -> Type {
+ if (!type.isInteger(64))
+ return type;
+ return IntegerType::get(type.getContext(), 32);
+ });
+ typeConverter.addConversion(
+ [&typeConverter](RankedTensorType type) -> Type {
+ const Type elementType = type.getElementType();
+ if (!elementType.isInteger(64))
+ return type;
+ return RankedTensorType::get(type.getShape(),
+ typeConverter.convertType(elementType));
+ });
+
+ const auto materializeCast = [](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) -> Value {
+ if (inputs.size() != 1)
+ return Value();
+ return tosa::CastOp::create(builder, loc, resultType, inputs.front());
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ typeConverter.addTypeAttributeConversion(
+ [](IntegerType type, IntegerAttr attribute) -> Attribute {
+ const APInt value = attribute.getValue().truncSSat(32);
+ return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
+ value);
+ });
+ typeConverter.addTypeAttributeConversion(
+ [&typeConverter](ShapedType type,
+ DenseIntElementsAttr attr) -> Attribute {
+ const ShapedType newType =
+ cast<ShapedType>(typeConverter.convertType(type));
+ const auto oldElementType = cast<IntegerType>(type.getElementType());
+ const auto newElementType =
+ cast<IntegerType>(newType.getElementType());
+ if (oldElementType.getWidth() == newElementType.getWidth())
+ return attr;
+
+ DenseElementsAttr mapped =
+ attr.mapValues(newElementType, [&](const APInt &v) {
+ return v.truncSSat(newElementType.getWidth());
+ });
+ return mapped;
+ });
+
+ ConversionTarget target(*context);
+ target.addDynamicallyLegalDialect<tosa::TosaDialect>(
+ [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op->getResultTypes()) &&
+ typeConverter.isLegal(op->getOperandTypes());
+ });
+ if (convertFunctionBoundaries) {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [&typeConverter](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+ target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
+ const FunctionType funcType =
+ op->getParentOfType<func::FuncOp>().getFunctionType();
+ return llvm::equal(op.getOperandTypes(), funcType.getResults());
+ });
+ } else {
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return true; });
+ target.addDynamicallyLegalOp<func::ReturnOp>(
+ [](func::ReturnOp op) { return true; });
+ }
+
+ RewritePatternSet patterns(context);
+ if (convertFunctionBoundaries) {
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+ }
+ if (aggressiveRewrite) {
+ patterns.add<ConvertGenericOp>(typeConverter, context);
+ } else {
+ // Tensor
+ patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
+ // Data layout
+ patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
+ // Type conversion
+ patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
+ // Control flow
+ patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
+ patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
+ }
+
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
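The type-attribute conversions registered in `runOnOperation` above narrow i64 `IntegerAttr`s and dense integer constants with `APInt::truncSSat(32)`, which saturates rather than wraps. The standalone C++ below is a hypothetical model of that behaviour only (it does not use MLIR or LLVM; `truncSSat32` and `narrowConstant` are illustrative names). It says nothing about how runtime arithmetic on the narrowed tensors behaves.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Models APInt::truncSSat(32) on a signed 64-bit value: values that fit in
// int32_t are kept unchanged; anything outside the range is clamped to the
// nearest int32_t bound instead of wrapping.
int32_t truncSSat32(int64_t value) {
  constexpr int64_t lo = std::numeric_limits<int32_t>::min();
  constexpr int64_t hi = std::numeric_limits<int32_t>::max();
  if (value < lo)
    return static_cast<int32_t>(lo);
  if (value > hi)
    return static_cast<int32_t>(hi);
  return static_cast<int32_t>(value);
}

// Element-wise version, analogous to the DenseIntElementsAttr::mapValues
// call the pass uses for i64 constant tensors.
std::vector<int32_t> narrowConstant(const std::vector<int64_t> &values) {
  std::vector<int32_t> out;
  out.reserve(values.size());
  for (int64_t v : values)
    out.push_back(truncSSat32(v));
  return out;
}
```

In this model, `dense<[1, 2]> : tensor<2xi64>` maps to `[1, 2]`, matching the aggressive `test_const` expectation, whereas an i64 constant such as 4294967296 would silently become 2147483647. This is one concrete way the aggressive mode can change numerical results.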
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
new file mode 100644
index 0000000000000..1a36177a37033
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// CHECK-LABEL: test_i64_argmax_large_axis_dim
+func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
+ // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_convert_input_parameters
+// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64>
+// FUNCBOUND: %[[IN:.*]]: tensor<1x513x513x3xi32>
+func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> {
+ // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
+ // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
+ // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
+ %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
+
+ // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
+ %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
+ return %1 : tensor<1x513x513x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64>
+// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32>
+func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32>
+ // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32>
+ // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64>
+ // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64>
+ // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32>
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+// DEFAULT: %[[IN0:.*]]: tensor<i64>, %[[IN1:.*]]: tensor<i64>
+func.func @test_regions(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i1>) -> tensor<i64> {
+ // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<i64>) -> tensor<i32>
+ // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<i64>) -> tensor<i32>
+ // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<i64>) {
+ // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %1 = tosa.add %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ // COMMON: tosa.yield %[[ADD]] : tensor<i32>
+ tosa.yield %1 : tensor<i64>
+ } else {
+ // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %1 = tosa.sub %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ // COMMON: tosa.yield %[[SUB]] : tensor<i32>
+ tosa.yield %1 : tensor<i64>
+ }
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<i32>) -> tensor<i64>
+ // DEFAULT: return %[[OUT]] : tensor<i64>
+ // FUNCBOUND: return %[[IF_RESULT]] : tensor<i32>
+ return %0 : tensor<i64>
+}
+
+// -----
+
+// CHECK-LABEL: test_const
+func.func @test_const() -> tensor<2xi64> {
+ // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
+ // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64>
+ // DEFAULT: return %[[OUT]] : tensor<2xi64>
+ // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
+ return %0 : tensor<2xi64>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
new file mode 100644
index 0000000000000..a14483fcdd7b0
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
+// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
+
+// -----
+
+// CHECK-LABEL: test_i64_argmax
+func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> {
+ // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64>
+
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xi64>
+ // FUNCBOUND: return %[[ARGMAX]] : tensor<1x513x513xi32>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_i64_argmax_cast
+func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> {
+ // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64>
+ // COMMON: tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xf32>
+ %1 = tosa.cast %0 : (tensor<1x513x513xi64>) -> tensor<1x513x513xf32>
+ return %1 : tensor<1x513x513xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_i64_argmax_large_axis_dim
+func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.argmax'}}
+ %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
+ return %0 : tensor<1x513x513xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_add
+func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
+ %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+func.func @test_regions(%arg0: tensor<1x2xi32>, %arg1: tensor<1xi32>, %arg2: tensor<i1>) -> tensor<1xi32> {
+ // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32>
+ %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> {
+ // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi32>
+ %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi64>
+ // COMMON: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1xi32>) -> tensor<1xi32>
+ %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32>
+ // COMMON: tosa.yield %[[CAST]] : tensor<1xi32>
+ tosa.yield %2 : tensor<1xi32>
+ } else {
+ tosa.yield %arg1 : tensor<1xi32>
+ }
+ // COMMON: return %[[IF_RESULT]] : tensor<1xi32>
+ return %0 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_concat
+func.func @test_concat(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<26x21x3xi64> {
+ // COMMON: tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<26x21x3xi32>
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<26x21x3xi64>
+ return %0 : tensor<26x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_pad
+func.func @test_pad(%arg0: tensor<13x21x3xi64>, %arg1: tensor<1xi64>) -> tensor<15x23x5xi64> {
+ %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // COMMON: tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<15x23x5xi32>
+ %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xi64>, !tosa.shape<6>, tensor<1xi64>) -> tensor<15x23x5xi64>
+ return %1 : tensor<15x23x5xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_reshape
+func.func @test_reshape(%arg0: tensor<13x21x3xi64>) -> tensor<1x819xi64> {
+ %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // COMMON: tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<2>) -> tensor<1x819xi32>
+ %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi64>, !tosa.shape<2>) -> tensor<1x819xi64>
+ return %0 : tensor<1x819xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_reverse
+func.func @test_reverse(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
+ // COMMON: tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
+ return %0 : tensor<13x21x3xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice
+func.func @test_slice(%arg0: tensor<13x21x3xi64>) -> tensor<4x11x1xi64> {
+ %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // COMMON: tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi32>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi64>
+ return %2 : tensor<4x11x1xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_tile
+func.func @test_tile(%arg0: tensor<13x21x3xi64>) -> tensor<39x21x6xi64> {
+ %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ // COMMON: tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>) -> tensor<39x21x6xi32>
+ %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi64>, !tosa.shape<3>) -> tensor<39x21x6xi64>
+ return %0 : tensor<39x21x6xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_transpose
+func.func @test_transpose(%arg0: tensor<13x21x3xi64>) -> tensor<3x13x21xi64> {
+ // COMMON: tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi32>) -> tensor<3x13x21xi32>
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi64>) -> tensor<3x13x21xi64>
+ return %1 : tensor<3x13x21xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_transition_to_i64
+func.func @test_transition_to_i64(%arg0: tensor<1xi32>) -> tensor<1xi64> {
+ // COMMON: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+ %0 = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi64>
+ // COMMON: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32>
+ %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64>
+ // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32>
+ %2 = tosa.identity %1 : (tensor<1xi64>) -> tensor<1xi64>
+ // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi64>
+ // DEFAULT: return %[[OUT_CAST]] : tensor<1xi64>
+ // FUNCBOUND: return %[[IDENTITY2]] : tensor<1xi32>
+ return %2 : tensor<1xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_transition_from_i64
+func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> {
+ // DEFAULT: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi64>) -> tensor<1xi32>
+ // DEFAULT: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32>
+ // FUNCBOUND: %[[IDENTITY1:.*]] = tosa.identity %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+ %0 = tosa.identity %arg0 : (tensor<1xi64>) -> tensor<1xi64>
+ // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32>
+ %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64>
+ // COMMON: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi32>
+ %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32>
+ // COMMON: return %[[OUT_CAST]] : tensor<1xi32>
+ return %2 : tensor<1xi32>
+}