[Mlir-commits] [mlir] [mlir][tosa] Use traits to check output type aligns with input type (PR #193961)

Luke Hutton llvmlistbot at llvm.org
Wed Apr 29 02:04:59 PDT 2026


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/193961

>From d87c8a6f7337c83d38024c021d0c9e198d270882 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 24 Apr 2026 12:37:23 +0000
Subject: [PATCH 1/2] [mlir][tosa] Make tosa.reverse an elementwise unary op

Reduces code duplication and ensures the output shape and element type
align with those of the input.

Change-Id: I488975d064510a3580678fb168ea3b5d27a2ec9c
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td |  6 ++---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp         | 16 +------------
 mlir/test/Dialect/Tosa/ops.mlir              |  7 ++++++
 mlir/test/Dialect/Tosa/verifier.mlir         | 24 ++++++++++++++++++++
 4 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 207618adc1352..05f2df9610d0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -536,7 +536,7 @@ def Tosa_MaxPool2dAdaptiveOp
     This performs a max pooling over the given input tensor. A sliding window of
     size given by <kernel size> is passed over the input tensor, with the
     maximum value being placed in the output tensor.
-    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride, 
+    Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
     pad arguments as inputs rather than attributes.
   }];
 
@@ -2348,9 +2348,7 @@ def Tosa_ReshapeBlockScaledOp
 //===----------------------------------------------------------------------===//
 // Operator: reverse
 //===----------------------------------------------------------------------===//
-def Tosa_ReverseOp: Tosa_Op<"reverse", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_ReverseOp: Tosa_ElementwiseUnaryOp<"reverse", [Pure]> {
   let summary = "Reverse operator.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fa4bc120e9c1e..c9a61971a8134 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -5430,12 +5430,7 @@ LogicalResult WhileOp::verify() {
 }
 
 LogicalResult ReverseOp::verify() {
-  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
-                             /* outType = */ getOutput().getType())
-          .failed())
-    return failure();
   TensorType inputType = getInput1().getType();
-  TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
 
   if (reverseAxis < 0)
@@ -5449,16 +5444,7 @@ LogicalResult ReverseOp::verify() {
              << inputRank << ") to be larger than reverse axis (" << reverseAxis
              << ")";
   }
-  if (outputType.hasRank()) {
-    int64_t outputRank = outputType.getRank();
-    if (inputType.hasRank() && outputRank != inputType.getRank())
-      return emitOpError(
-          "expect output tensor rank to be equal to input tensor rank");
-    if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
-      return emitOpError("expect output tensor rank (")
-             << outputRank << ") to be larger than reverse axis ("
-             << reverseAxis << ")";
-  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a39f593b14263..5c368b3da4ff5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -886,6 +886,13 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: reverse_unranked
+func.func @test_reverse_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // -----
 // CHECK-LABEL: slice
 func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 1572df5357877..051797b2c69ff 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -127,6 +127,30 @@ func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
 
 // -----
 
+func.func @test_reverse_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x3xf32> {
+  // expected-error at +1 {{'tosa.reverse' op requires the same element type for all operands and results}}
+  %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+func.func @test_reverse_shape_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x4xi32> {
+  // expected-error at +1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+  %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<2x4xi32>
+  return %0 : tensor<2x4xi32>
+}
+
+// -----
+
+func.func @test_reverse_rank_mismatch(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
+  // expected-error at +1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+  %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<1x2x3xi32>
+  return %0 : tensor<1x2x3xi32>
+}
+
+// -----
+
 func.func @test_slice_invalid_output_rank() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>

>From 45c2c27c05b2128f9fc54e4ae0831ae44e4ba5e7 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 29 Apr 2026 08:33:29 +0000
Subject: [PATCH 2/2] Use "SameOperands*" traits directly rather than defining
 reverse as an elementwise op

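Defining tosa.reverse as an elementwise op breaks optimizations, because
reverse reorders elements along an axis rather than mapping each element
independently.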

Change-Id: I43a6bd6e5f081e879f1f5fe8cb5cc0f12af3e0a2
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 8 +++++++-
 mlir/test/Dialect/Tosa/verifier.mlir         | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 05f2df9610d0c..1d337c188fb8a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2348,7 +2348,13 @@ def Tosa_ReshapeBlockScaledOp
 //===----------------------------------------------------------------------===//
 // Operator: reverse
 //===----------------------------------------------------------------------===//
-def Tosa_ReverseOp: Tosa_ElementwiseUnaryOp<"reverse", [Pure]> {
+def Tosa_ReverseOp: Tosa_Op<"reverse", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
+      SameOperandsAndResultRank,
+      SameOperandsAndResultShape,
+      SameOperandsAndResultElementType,
+      Pure]> {
   let summary = "Reverse operator.";
 
   let description = [{
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 051797b2c69ff..57d23ffc659b5 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -144,7 +144,7 @@ func.func @test_reverse_shape_mismatch(%arg0: tensor<2x3xi32>) -> tensor<2x4xi32
 // -----
 
 func.func @test_reverse_rank_mismatch(%arg0: tensor<2x3xi32>) -> tensor<1x2x3xi32> {
-  // expected-error at +1 {{'tosa.reverse' op requires the same shape for all operands and results}}
+  // expected-error at +1 {{'tosa.reverse' op result type has different rank than operands}}
   %0 = tosa.reverse %arg0 {axis = 1 : i32} : (tensor<2x3xi32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }



More information about the Mlir-commits mailing list