[Mlir-commits] [mlir] [mlir][tosa] Make tosa.reverse an elementwise unary op (PR #193961)
Luke Hutton
llvmlistbot at llvm.org
Fri Apr 24 05:55:16 PDT 2026
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/193961
Reduces code duplication by relying on the elementwise unary op traits for the element type and shape checks, and ensures the output shape matches the input shape.
From af6521120caab4dacf5ca9265c0efd6267879c04 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 24 Apr 2026 12:37:23 +0000
Subject: [PATCH] [mlir][tosa] Make tosa.reverse an elementwise unary op
Reduces code duplication by relying on the elementwise unary op
traits for the element type and shape checks, and ensures the
output shape matches the input shape.
Change-Id: I488975d064510a3580678fb168ea3b5d27a2ec9c
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 6 ++---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 15 +-----------
mlir/test/Dialect/Tosa/ops.mlir | 7 ++++++
mlir/test/Dialect/Tosa/verifier.mlir | 24 ++++++++++++++++++++
4 files changed, 34 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..05f2df9610d0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -536,7 +536,7 @@ def Tosa_MaxPool2dAdaptiveOp
This performs a max pooling over the given input tensor. A sliding window of
size given by <kernel size> is passed over the input tensor, with the
maximum value being placed in the output tensor.
- Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
+ Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
pad arguments as inputs rather than attributes.
}];
@@ -2348,9 +2348,7 @@ def Tosa_ReshapeBlockScaledOp
//===----------------------------------------------------------------------===//
// Operator: reverse
//===----------------------------------------------------------------------===//
-def Tosa_ReverseOp: Tosa_Op<"reverse", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_ReverseOp: Tosa_ElementwiseUnaryOp<"reverse", [Pure]> {
let summary = "Reverse operator.";
let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fa4bc120e9c1e..596b232040466 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -5430,10 +5430,6 @@ LogicalResult WhileOp::verify() {
}
LogicalResult ReverseOp::verify() {
- if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
- /* outType = */ getOutput().getType())
- .failed())
- return failure();
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
int32_t reverseAxis = getAxis();
@@ -5449,16 +5445,7 @@ LogicalResult ReverseOp::verify() {
<< inputRank << ") to be larger than reverse axis (" << reverseAxis
<< ")";
}
- if (outputType.hasRank()) {
- int64_t outputRank = outputType.getRank();
- if (inputType.hasRank() && outputRank != inputType.getRank())
- return emitOpError(
- "expect output tensor rank to be equal to input tensor rank");
- if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
- return emitOpError("expect output tensor rank (")
- << outputRank << ") to be larger than reverse axis ("
- << reverseAxis << ")";
- }
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a39f593b14263..5c368b3da4ff5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -886,6 +886,13 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: reverse_unranked
+func.func @test_reverse_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+ %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// -----
// CHECK-LABEL: slice
func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 1572df5357877..051797b2c69ff 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -127,6 +127,30 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
// -----
+func.func @test_reverse_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x3xf32> {
+ // expected-error at +1 {{'tosa.reverse' op requires the same element type for all operands and results}}
+ %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2x3xf32>
+ return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+func.func @test_reverse_shape_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x4xi32> {
+ // expected-error at +1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+ %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2x4xi32>
+ return %0 : tensor<2x4xi32>
+}
+
+// -----
+
+func.func @test_reverse_rank_mismatch(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+ // expected-error at +1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+ %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<1x2x3xi32>
+ return %0 : tensor<1x2x3xi32>
+}
+
+// -----
+
func.func @test_slice_invalid_output_rank() {
%0 = tensor.empty() : tensor<4x31x31xf32>
%start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
More information about the Mlir-commits
mailing list