[Mlir-commits] [mlir] [mlir][tosa] Make tosa.reverse an elementwise unary op (PR #193961)

Luke Hutton llvmlistbot at llvm.org
Fri Apr 24 05:55:16 PDT 2026


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/193961

Reduces code duplication by defining tosa.reverse through Tosa_ElementwiseUnaryOp, and ensures the output shape and element type match those of the input.

>From af6521120caab4dacf5ca9265c0efd6267879c04 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 24 Apr 2026 12:37:23 +0000
Subject: [PATCH] [mlir][tosa] Make tosa.reverse an elementwise unary op

Reduces code duplication by defining tosa.reverse through
Tosa_ElementwiseUnaryOp, and ensures the output shape and element type
match those of the input.

Change-Id: I488975d064510a3580678fb168ea3b5d27a2ec9c
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  6 ++---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 15 +-----------
 mlir/test/Dialect/Tosa/ops.mlir              |  7 ++++++
 mlir/test/Dialect/Tosa/verifier.mlir         | 24 ++++++++++++++++++++
 4 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..05f2df9610d0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -536,7 +536,7 @@ def Tosa_MaxPool2dAdaptiveOp
     This performs a max pooling over the given input tensor. A sliding window of
     size given by <kernel size> is passed over the input tensor, with the
     maximum value being placed in the output tensor.
-    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride, 
+    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
     pad arguments as inputs rather than attributes.
   }];
 
@@ -2348,9 +2348,7 @@ def Tosa_ReshapeBlockScaledOp
 //===----------------------------------------------------------------------===//
 // Operator: reverse
 //===----------------------------------------------------------------------===//
-def Tosa_ReverseOp: Tosa_Op<"reverse", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_ReverseOp: Tosa_ElementwiseUnaryOp<"reverse", [Pure]> {
   let summary = "Reverse operator.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fa4bc120e9c1e..596b232040466 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -5430,10 +5430,6 @@ LogicalResult WhileOp::verify() {
 }
 
 LogicalResult ReverseOp::verify() {
-  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
-                             /* outType = */ getOutput().getType())
-          .failed())
-    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -5449,16 +5445,7 @@ LogicalResult ReverseOp::verify() {
              << inputRank << ") to be larger than reverse axis (" << reverseAxis
              << ")";
   }
-  if (outputType.hasRank()) {
-    int64_t outputRank = outputType.getRank();
-    if (inputType.hasRank() && outputRank != inputType.getRank())
-      return emitOpError(
-          "expect output tensor rank to be equal to input tensor rank");
-    if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
-      return emitOpError("expect output tensor rank (")
-             << outputRank << ") to be larger than reverse axis ("
-             << reverseAxis << ")";
-  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a39f593b14263..5c368b3da4ff5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -886,6 +886,13 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: reverse_unranked
+func.func @test_reverse_unranked(%input: tensor<*xf32>) -> tensor<*xf32> {
+  %reversed = tosa.reverse %input {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %reversed : tensor<*xf32>
+}
+
 // -----
 // CHECK-LABEL: slice
 func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 1572df5357877..051797b2c69ff 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -127,6 +127,30 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
 
 // -----
 
+func.func @test_reverse_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x3xf32> {
+  // expected-error@+1 {{'tosa.reverse' op requires the same element type for all operands and results}}
+  %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+func.func @test_reverse_shape_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x4xi32> {
+  // expected-error@+1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+  %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2x4xi32>
+  return %0 : tensor<2x4xi32>
+}
+
+// -----
+
+func.func @test_reverse_rank_mismatch(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+  // expected-error@+1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+  %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<1x2x3xi32>
+  return %0 : tensor<1x2x3xi32>
+}
+
+// -----
+
 func.func @test_slice_invalid_output_rank() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>



More information about the Mlir-commits mailing list