[Mlir-commits] [mlir] Revert "[mlir][tosa] Add a pass to narrow i64 to i32 (#165581)" (PR #168538)
Luke Hutton
llvmlistbot at llvm.org
Tue Nov 18 06:03:34 PST 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/168538
This reverts commit c61c5d29334c7ff044ba46bff17e1f3d57e230a3.
>From 62d75f687f81d44258a516ceb4760b3a100b96d1 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 18 Nov 2025 14:01:22 +0000
Subject: [PATCH] Revert "[mlir][tosa] Add a pass to narrow i64 to i32
(#165581)"
This reverts commit c61c5d29334c7ff044ba46bff17e1f3d57e230a3.
---
.../mlir/Dialect/Tosa/Transforms/Passes.td | 23 --
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 -
.../Tosa/Transforms/TosaNarrowI64ToI32.cpp | 310 ------------------
.../tosa-narrow-i64-to-i32-aggressive.mlir | 81 -----
.../Dialect/Tosa/tosa-narrow-i64-to-i32.mlir | 162 ---------
5 files changed, 577 deletions(-)
delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
delete mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
delete mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 420e58192b8fd..14b00b04ccc18 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -166,27 +166,4 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
}
-def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
- let summary = "Narrow I64 TOSA operations to I32";
- let description = [{
- This pass narrows TOSA operations with 64-bit integer tensor types to
- 32-bit integer tensor types. This can be useful for backends that do not
- support the EXT-INT64 extension of TOSA.
- }];
-
- let options = [
- Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false",
- "If enabled, all TOSA operations are rewritten, regardless or whether the narrowing"
- "is safe. This option may lead to data loss if not used carefully.">,
- Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false",
- "If enabled, the pass will convert function I/O types as well. Otherwise casts will"
- "be inserted at the I/O boundaries.">
- ];
-
- let dependentDialects = [
- "func::FuncDialect",
- "tosa::TosaDialect",
- ];
-}
-
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 987ce4ed870c9..41b338d6e7189 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -12,7 +12,6 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaTypeConverters.cpp
TosaProfileCompliance.cpp
TosaValidation.cpp
- TosaNarrowI64ToI32.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
deleted file mode 100644
index ddaf7d8a5e033..0000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp
+++ /dev/null
@@ -1,310 +0,0 @@
-//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This pass narrows TOSA operations with 64-bit integer tensor types to
-// 32-bit integer tensor types. This can be useful for backends that do not
-// support the EXT-INT64 extension of TOSA. The pass has two options:
-//
-// - aggressive-rewrite - If enabled, all TOSA operations are rewritten,
-// regardless or whether the narrowing is safe. This option may lead to
-// data loss if not used carefully.
-// - convert-function-boundaries - If enabled, the pass will convert function
-// I/O types as well. Otherwise casts will be inserted at the I/O
-// boundaries.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/IR/Verifier.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace tosa {
-#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
-#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
-} // namespace tosa
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-LogicalResult convertGenericOp(Operation *op, ValueRange operands,
- ConversionPatternRewriter &rewriter,
- const TypeConverter *typeConverter) {
- // Convert types of results
- SmallVector<Type, 4> newResults;
- if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
- return failure();
-
- // Create a new operation state
- OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
- newResults, {}, op->getSuccessors());
-
- for (const NamedAttribute &namedAttribute : op->getAttrs()) {
- const Attribute attribute = namedAttribute.getValue();
-
- // Convert integer attribute type
- if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
- const std::optional<Attribute> convertedAttribute =
- typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
- state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
- continue;
- }
-
- if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
- Type type = typeAttr.getValue();
- const std::optional<Attribute> convertedAttribute =
- typeConverter->convertTypeAttribute(type, attribute);
- if (!convertedAttribute)
- return rewriter.notifyMatchFailure(op,
- "Failed to convert type attribute.");
- state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
- continue;
- }
-
- if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
- const Type type = denseElementsAttr.getType();
- const std::optional<Attribute> convertedAttribute =
- typeConverter->convertTypeAttribute(type, denseElementsAttr);
- if (!convertedAttribute)
- return rewriter.notifyMatchFailure(
- op, "Failed to convert dense elements attribute.");
- state.addAttribute(namedAttribute.getName(), convertedAttribute.value());
- continue;
- }
-
- state.addAttribute(namedAttribute.getName(), attribute);
- }
-
- for (Region &region : op->getRegions()) {
- Region *newRegion = state.addRegion();
- rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
- if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
- return failure();
- }
-
- Operation *newOp = rewriter.create(state);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
-}
-
-// ===========================
-// Aggressive rewrite patterns
-// ===========================
-
-class ConvertGenericOp : public ConversionPattern {
-public:
- ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
- : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- if (!isa<tosa::TosaOp>(op))
- return rewriter.notifyMatchFailure(
- op,
- "Support for operations other than TOSA has not been implemented.");
-
- return convertGenericOp(op, operands, rewriter, typeConverter);
- }
-};
-
-// ===============================
-// Bounds checked rewrite patterns
-// ===============================
-
-class ConvertArgMaxOpWithBoundsChecking
- : public OpConversionPattern<tosa::ArgMaxOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- // Output type can be narrowed based on the size of the axis dimension
- const int32_t axis = op.getAxis();
- const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
- if (!inputType || !inputType.isStaticDim(axis))
- return rewriter.notifyMatchFailure(
- op, "Requires a static axis dimension for bounds checking.");
- const int64_t axisDim = inputType.getDimSize(axis);
- if (axisDim >= std::numeric_limits<int32_t>::max())
- return rewriter.notifyMatchFailure(
- op, "Axis dimension is too large to narrow safely.");
-
- const Type resultType = op.getOutput().getType();
- const Type newResultType = typeConverter->convertType(resultType);
- rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
- adaptor.getInput(), axis);
- return success();
- }
-};
-
-class ConvertCastOpWithBoundsChecking
- : public OpConversionPattern<tosa::CastOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
- const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
- if (!inputType || !resultType)
- return failure();
-
- const auto elementInputIntType =
- dyn_cast<IntegerType>(inputType.getElementType());
- const auto elementResultIntType =
- dyn_cast<IntegerType>(resultType.getElementType());
- if (elementInputIntType && elementResultIntType &&
- elementInputIntType.getWidth() > elementResultIntType.getWidth())
- return rewriter.notifyMatchFailure(
- op, "Narrowing cast may lead to data loss.");
-
- rewriter.replaceOpWithNewOp<tosa::CastOp>(
- op, typeConverter->convertType(resultType), adaptor.getInput());
- return success();
- }
-};
-
-template <typename OpTy>
-class ConvertTypedOp : public OpConversionPattern<OpTy> {
- using OpConversionPattern<OpTy>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- return convertGenericOp(op, adaptor.getOperands(), rewriter,
- this->getTypeConverter());
- }
-};
-
-struct TosaNarrowI64ToI32
- : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
-public:
- explicit TosaNarrowI64ToI32() = default;
- explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options)
- : TosaNarrowI64ToI32() {
- this->aggressiveRewrite = options.aggressiveRewrite;
- this->convertFunctionBoundaries = options.convertFunctionBoundaries;
- }
-
- void runOnOperation() override {
- MLIRContext *context = &getContext();
-
- TypeConverter typeConverter;
- typeConverter.addConversion([](Type type) -> Type { return type; });
- typeConverter.addConversion([](IntegerType type) -> Type {
- if (!type.isInteger(64))
- return type;
- return IntegerType::get(type.getContext(), 32);
- });
- typeConverter.addConversion(
- [&typeConverter](RankedTensorType type) -> Type {
- const Type elementType = type.getElementType();
- if (!elementType.isInteger(64))
- return type;
- return RankedTensorType::get(type.getShape(),
- typeConverter.convertType(elementType));
- });
-
- const auto materializeCast = [](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) -> Value {
- if (inputs.size() != 1)
- return Value();
- return tosa::CastOp::create(builder, loc, resultType, inputs.front());
- };
- typeConverter.addSourceMaterialization(materializeCast);
- typeConverter.addTargetMaterialization(materializeCast);
-
- typeConverter.addTypeAttributeConversion(
- [](IntegerType type, IntegerAttr attribute) -> Attribute {
- const APInt value = attribute.getValue().truncSSat(32);
- return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
- value);
- });
- typeConverter.addTypeAttributeConversion(
- [&typeConverter](ShapedType type,
- DenseIntElementsAttr attr) -> Attribute {
- const ShapedType newType =
- cast<ShapedType>(typeConverter.convertType(type));
- const auto oldElementType = cast<IntegerType>(type.getElementType());
- const auto newElementType =
- cast<IntegerType>(newType.getElementType());
- if (oldElementType.getWidth() == newElementType.getWidth())
- return attr;
-
- DenseElementsAttr mapped =
- attr.mapValues(newElementType, [&](const APInt &v) {
- return v.truncSSat(newElementType.getWidth());
- });
- return mapped;
- });
-
- ConversionTarget target(*context);
- target.addDynamicallyLegalDialect<tosa::TosaDialect>(
- [&typeConverter](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes()) &&
- typeConverter.isLegal(op->getOperandTypes());
- });
- if (convertFunctionBoundaries) {
- target.addDynamicallyLegalOp<func::FuncOp>(
- [&typeConverter](func::FuncOp op) {
- return typeConverter.isSignatureLegal(op.getFunctionType()) &&
- typeConverter.isLegal(&op.getBody());
- });
- target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
- const FunctionType funcType =
- op->getParentOfType<func::FuncOp>().getFunctionType();
- return llvm::equal(op.getOperandTypes(), funcType.getResults());
- });
- } else {
- target.addDynamicallyLegalOp<func::FuncOp>(
- [](func::FuncOp op) { return true; });
- target.addDynamicallyLegalOp<func::ReturnOp>(
- [](func::ReturnOp op) { return true; });
- }
-
- RewritePatternSet patterns(context);
- if (convertFunctionBoundaries) {
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
- patterns, typeConverter);
- populateReturnOpTypeConversionPattern(patterns, typeConverter);
- }
- if (aggressiveRewrite) {
- patterns.add<ConvertGenericOp>(typeConverter, context);
- } else {
- // Tensor
- patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
- // Data layout
- patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
- // Type conversion
- patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
- // Controlflow
- patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
- patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
- }
-
- if (failed(
- applyFullConversion(getOperation(), target, std::move(patterns))))
- signalPassFailure();
- }
-};
-
-} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
deleted file mode 100644
index 1a36177a37033..0000000000000
--- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
-// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
-
-// CHECK-LABEL: test_i64_argmax_large_axis_dim
-func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
- // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32>
- %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
- return %0 : tensor<1x513x513xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_convert_input_parameters
-// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64>
-// FUNCBOUND: %[[IN:.*]]: tensor<1x513x513x3xi32>
-func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> {
- // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
- // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
- // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32>
- %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32>
-
- // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
- %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32>
- return %1 : tensor<1x513x513x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_add
-// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64>
-// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32>
-func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
- // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32>
- // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32>
- // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
- // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64>
- // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64>
- // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
- // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32>
- %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
- return %0 : tensor<13x21x3xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_regions
-// DEFAULT: %[[IN0:.*]]: tensor<i64>, %[[IN1:.*]]: tensor<i64>
-func.func @test_regions(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i1>) -> tensor<i64> {
- // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<i64>) -> tensor<i32>
- // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<i64>) -> tensor<i32>
- // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if
- %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<i64>) {
- // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
- // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
- %1 = tosa.add %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
- // COMMON: tosa.yield %[[ADD]] : tensor<i32>
- tosa.yield %1 : tensor<i64>
- } else {
- // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
- // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
- %1 = tosa.sub %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64>
- // COMMON: tosa.yield %[[SUB]] : tensor<i32>
- tosa.yield %1 : tensor<i64>
- }
- // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<i32>) -> tensor<i64>
- // DEFAULT: return %[[OUT]] : tensor<i64>
- // FUNCBOUND: return %[[IF_RESULT]] : tensor<i32>
- return %0 : tensor<i64>
-}
-
-// -----
-
-// CHECK-LABEL: test_const
-func.func @test_const() -> tensor<2xi64> {
- // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32>
- %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64>
- // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64>
- // DEFAULT: return %[[OUT]] : tensor<2xi64>
- // FUNCBOUND: return %[[CONST]] : tensor<2xi32>
- return %0 : tensor<2xi64>
-}
diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
deleted file mode 100644
index a14483fcdd7b0..0000000000000
--- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir
+++ /dev/null
@@ -1,162 +0,0 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT
-// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND
-
-// -----
-
-// CHECK-LABEL: test_i64_argmax
-func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> {
- // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
- %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64>
-
- // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xi64>
- // FUNCBOUND: return %[[ARGMAX]] : tensor<1x513x513xi32>
- return %0 : tensor<1x513x513xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_i64_argmax_cast
-func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> {
- // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32>
- %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64>
- // COMMON: tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xf32>
- %1 = tosa.cast %0 : (tensor<1x513x513xi64>) -> tensor<1x513x513xf32>
- return %1 : tensor<1x513x513xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_i64_argmax_large_axis_dim
-func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> {
- // expected-error @+1 {{failed to legalize operation 'tosa.argmax'}}
- %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64>
- return %0 : tensor<1x513x513xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_add
-func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
- // expected-error @+1 {{failed to legalize operation 'tosa.add'}}
- %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
- return %0 : tensor<13x21x3xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_regions
-func.func @test_regions(%arg0: tensor<1x2xi32>, %arg1: tensor<1xi32>, %arg2: tensor<i1>) -> tensor<1xi32> {
- // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32>
- %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> {
- // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi32>
- %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi64>
- // COMMON: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1xi32>) -> tensor<1xi32>
- %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32>
- // COMMON: tosa.yield %[[CAST]] : tensor<1xi32>
- tosa.yield %2 : tensor<1xi32>
- } else {
- tosa.yield %arg1 : tensor<1xi32>
- }
- // COMMON: return %[[IF_RESULT]] : tensor<1xi32>
- return %0 : tensor<1xi32>
-}
-
-// -----
-
-// CHECK-LABEL: test_concat
-func.func @test_concat(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<26x21x3xi64> {
- // COMMON: tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<26x21x3xi32>
- %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<26x21x3xi64>
- return %0 : tensor<26x21x3xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_pad
-func.func @test_pad(%arg0: tensor<13x21x3xi64>, %arg1: tensor<1xi64>) -> tensor<15x23x5xi64> {
- %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
- // COMMON: tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<15x23x5xi32>
- %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xi64>, !tosa.shape<6>, tensor<1xi64>) -> tensor<15x23x5xi64>
- return %1 : tensor<15x23x5xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_reshape
-func.func @test_reshape(%arg0: tensor<13x21x3xi64>) -> tensor<1x819xi64> {
- %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
- // COMMON: tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<2>) -> tensor<1x819xi32>
- %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi64>, !tosa.shape<2>) -> tensor<1x819xi64>
- return %0 : tensor<1x819xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_reverse
-func.func @test_reverse(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
- // COMMON: tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
- %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
- return %0 : tensor<13x21x3xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_slice
-func.func @test_slice(%arg0: tensor<13x21x3xi64>) -> tensor<4x11x1xi64> {
- %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
- %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
- // COMMON: tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi32>
- %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi64>
- return %2 : tensor<4x11x1xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_tile
-func.func @test_tile(%arg0: tensor<13x21x3xi64>) -> tensor<39x21x6xi64> {
- %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
- // COMMON: tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>) -> tensor<39x21x6xi32>
- %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi64>, !tosa.shape<3>) -> tensor<39x21x6xi64>
- return %0 : tensor<39x21x6xi64>
-}
-
-// -----
-
-// CHECK-LABEL: transpose
-func.func @test_transpose(%arg0: tensor<13x21x3xi64>) -> tensor<3x13x21xi64> {
- // COMMON: tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi32>) -> tensor<3x13x21xi32>
- %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi64>) -> tensor<3x13x21xi64>
- return %1 : tensor<3x13x21xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_transition_to_i64
-func.func @test_transition_to_i64(%arg0: tensor<1xi32>) -> tensor<1xi64> {
- // COMMON: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi32>
- %0 = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi64>
- // COMMON: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32>
- %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64>
- // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32>
- %2 = tosa.identity %1 : (tensor<1xi64>) -> tensor<1xi64>
- // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi64>
- // DEFAULT: return %[[OUT_CAST]] : tensor<1xi64>
- // FUNCBOUND: return %[[IDENTITY2]] : tensor<1xi32>
- return %2 : tensor<1xi64>
-}
-
-// -----
-
-// CHECK-LABEL: test_transition_from_i64
-func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> {
- // DEFAULT: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi64>) -> tensor<1xi32>
- // DEFAULT: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32>
- // FUNCBOUND: %[[IDENTITY1:.*]] = tosa.identity %arg0 : (tensor<1xi32>) -> tensor<1xi32>
- %0 = tosa.identity %arg0 : (tensor<1xi64>) -> tensor<1xi64>
- // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32>
- %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64>
- // COMMON: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi32>
- %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32>
- // COMMON: return %[[OUT_CAST]] : tensor<1xi32>
- return %2 : tensor<1xi32>
-}
More information about the Mlir-commits
mailing list