[Mlir-commits] [mlir] 8dea784 - [mlir][tosa] Add tosa shape inference with InferReturnTypeComponent

Rob Suderman llvmlistbot at llvm.org
Thu Jul 1 16:07:22 PDT 2021


Author: Rob Suderman
Date: 2021-07-01T16:04:26-07:00
New Revision: 8dea784b3ed7df3edd9e3b59b1e1b58d2a4ac175

URL: https://github.com/llvm/llvm-project/commit/8dea784b3ed7df3edd9e3b59b1e1b58d2a4ac175
DIFF: https://github.com/llvm/llvm-project/commit/8dea784b3ed7df3edd9e3b59b1e1b58d2a4ac175.diff

LOG: [mlir][tosa] Add tosa shape inference with InferReturnTypeComponent

Added InferReturnTypeComponents for NAry operations, reshape, and reverse.
With the additional tosa-infer-shapes pass, we can infer/propagate shapes
across a set of TOSA operations. The current version does not modify the
FuncOp type; instead, it inserts an unrealized conversion cast prior to
any new non-matching returns.

Differential Revision: https://reviews.llvm.org/D105312

Added: 
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
    mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 0af72312af28f..b0d5eb79fbfcf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3a1f9d26be118..06867ef199e11 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -17,6 +17,7 @@
 include "mlir/IR/OpBase.td"
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
 
@@ -284,7 +285,10 @@ def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: clamp
 //===----------------------------------------------------------------------===//
-def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_ClampOp : Tosa_Op<"clamp", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes clamp(features, min, max).";
 
   let description = [{
@@ -309,7 +313,10 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: reluN
 //===----------------------------------------------------------------------===//
-def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_ReluNOp : Tosa_Op<"reluN", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes rectified linear: `max(features, N)`.";
 
   let description = [{
@@ -330,8 +337,10 @@ def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: sigmoid
 //===----------------------------------------------------------------------===//
-def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect,
-                             SameOperandsAndResultType]> {
+def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes elementwise sigmoid of input.";
 
   let description = [{
@@ -354,7 +363,10 @@ def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: tanh
 //===----------------------------------------------------------------------===//
-def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_TanhOp : Tosa_Op<"tanh", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Computes elementwise hyperbolic tangent of input";
 
   let description = [{
@@ -382,8 +394,10 @@ def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: add
 //===----------------------------------------------------------------------===//
-def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect,
-                         Commutative]> {
+def Tosa_AddOp : Tosa_Op<"add", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Elementwise addition operator";
 
   let description = [{
@@ -404,9 +418,10 @@ def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: arithmetic_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift",
-                                          [ResultsBroadcastableShape,
-                                          NoSideEffect]> {
+def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Elementwise Arithmetic Right Shift";
 
   let description = [{
@@ -429,8 +444,10 @@ def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift",
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_and
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape,
-                                NoSideEffect, Commutative]> {
+def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Bitwise AND operator";
 
   let description = [{
@@ -451,8 +468,10 @@ def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_or
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape,
-                                              NoSideEffect, Commutative]> {
+def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Bitwise OR operator";
 
   let description = [{
@@ -473,8 +492,10 @@ def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_xor
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape,
-                                                NoSideEffect, Commutative]> {
+def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Bitwise XOR operator";
 
   let description = [{
@@ -495,8 +516,10 @@ def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: div
 //===----------------------------------------------------------------------===//
-def Tosa_DivOp : Tosa_Op<"div", [ResultsBroadcastableShape,
-                                 NoSideEffect]> {
+def Tosa_DivOp : Tosa_Op<"div", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Integer divide operator";
 
   let description = [{
@@ -517,8 +540,10 @@ def Tosa_DivOp : Tosa_Op<"div", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: logical_and
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape,
-                                                Commutative, NoSideEffect]> {
+def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of x AND y element-wise.";
 
   let description = [{
@@ -539,9 +564,10 @@ def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: logical_left_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift",
-                                      [ResultsBroadcastableShape,
-                                       NoSideEffect]> {
+def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Elementwise Logical Left Shift";
 
   let description = [{
@@ -562,9 +588,9 @@ def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift",
 //===----------------------------------------------------------------------===//
 // Operator: logical_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift",
-                                       [ResultsBroadcastableShape,
-                                        NoSideEffect]> {
+def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>, ResultsBroadcastableShape,
+    NoSideEffect]> {
   let summary = "Elementwise Logical Right Shift";
 
   let description = [{
@@ -586,8 +612,10 @@ def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift",
 //===----------------------------------------------------------------------===//
 // Operator: logical_or
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape,
-                                              Commutative, NoSideEffect]> {
+def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of x OR y element-wise.";
 
   let description = [{
@@ -608,8 +636,10 @@ def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: logical_xor
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape,
-                                                Commutative, NoSideEffect]> {
+def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of x XOR y element-wise.";
 
   let description = [{
@@ -630,8 +660,10 @@ def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: maximum
 //===----------------------------------------------------------------------===//
-def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape,
-                                         NoSideEffect, Commutative]> {
+def Tosa_MaximumOp : Tosa_Op<"maximum", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Elementwise Maximum";
 
   let description = [{
@@ -652,8 +684,10 @@ def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: minimum
 //===----------------------------------------------------------------------===//
-def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape,
-                                         NoSideEffect, Commutative]> {
+def Tosa_MinimumOp : Tosa_Op<"minimum", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Elementwise Minimum";
 
   let description = [{
@@ -674,8 +708,10 @@ def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect,
-                                 Commutative]> {
+def Tosa_MulOp : Tosa_Op<"mul", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative]> {
   let summary = "Multiplication operator";
 
   let description = [{
@@ -698,7 +734,10 @@ def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: pow
 //===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> {
+def Tosa_PowOp : Tosa_Op<"pow", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Computes the power of one value to another.";
 
   let description = [{
@@ -720,7 +759,10 @@ def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: sub
 //===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
+def Tosa_SubOp : Tosa_Op<"sub", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Elementwise subtraction operator";
 
   let description = [{
@@ -781,7 +823,10 @@ def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: abs
 //===----------------------------------------------------------------------===//
-def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_AbsOp : Tosa_Op<"abs", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise abs op";
 
   let description = [{
@@ -800,8 +845,10 @@ def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_not
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape,
-                                                NoSideEffect]> {
+def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Bitwise NOT operator";
 
   let description = [{
@@ -820,7 +867,10 @@ def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: ceil
 //===----------------------------------------------------------------------===//
-def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_CeilOp : Tosa_Op<"ceil", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise ceil op";
 
   let description = [{
@@ -839,7 +889,10 @@ def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: clz
 //===----------------------------------------------------------------------===//
-def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_ClzOp : Tosa_Op<"clz", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise count leading zero op";
 
   let description = [{
@@ -858,7 +911,10 @@ def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: exp
 //===----------------------------------------------------------------------===//
-def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_ExpOp : Tosa_Op<"exp", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise exp op";
 
   let description = [{
@@ -877,7 +933,10 @@ def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: floor
 //===----------------------------------------------------------------------===//
-def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_FloorOp : Tosa_Op<"floor", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise floor op";
 
   let description = [{
@@ -896,7 +955,10 @@ def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: log
 //===----------------------------------------------------------------------===//
-def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_LogOp : Tosa_Op<"log", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise log op";
 
   let description = [{
@@ -915,8 +977,10 @@ def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_not
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect,
-                                                SameOperandsAndResultType]> {
+def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Returns the truth value of NOT x element-wise.";
 
   let description = [{
@@ -935,8 +999,10 @@ def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: negate
 //===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect,
-                                       SameOperandsAndResultType]> {
+def Tosa_NegateOp : Tosa_Op<"negate", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise negate op";
 
   let description = [{
@@ -958,8 +1024,10 @@ def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: reciprocal
 //===----------------------------------------------------------------------===//
-def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect,
-                                               SameOperandsAndResultType]> {
+def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise reciprocal op";
 
   let description = [{
@@ -979,7 +1047,10 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect,
 //===----------------------------------------------------------------------===//
 // Operator: rsqrt
 //===----------------------------------------------------------------------===//
-def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
+def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect, SameOperandsAndResultType]> {
   let summary = "Elementwise 1/sqrt op";
 
   let description = [{
@@ -1005,7 +1076,9 @@ def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
 //===----------------------------------------------------------------------===//
 // Operator: select
 //===----------------------------------------------------------------------===//
-def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> {
+def Tosa_SelectOp : Tosa_Op<"select", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, NoSideEffect]> {
   let summary = "Elementwise select operator";
 
   let description = [{
@@ -1031,8 +1104,10 @@ def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative,
-                                     NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1052,8 +1127,10 @@ def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative,
 //===----------------------------------------------------------------------===//
 // Operator: greater
 //===----------------------------------------------------------------------===//
-def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape,
-                                         NoSideEffect]> {
+def Tosa_GreaterOp : Tosa_Op<"greater", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Returns the truth value of (x > y) element-wise.";
 
   let description = [{
@@ -1073,8 +1150,10 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: greater_equal
 //===----------------------------------------------------------------------===//
-def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape,
-                                                    NoSideEffect]> {
+def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, NoSideEffect]> {
   let summary = "Returns the truth value of (x >= y) element-wise.";
 
   let description = [{
@@ -1269,7 +1348,9 @@ def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> {
 // Operator: reshape
 //===----------------------------------------------------------------------===//
 def Tosa_ReshapeOp: Tosa_Op<"reshape", [
-   NoSideEffect]> {
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    NoSideEffect]> {
   let summary = "Reshape operator";
 
   let description = [{
@@ -1291,7 +1372,9 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
 //===----------------------------------------------------------------------===//
 // Operator: reverse
 //===----------------------------------------------------------------------===//
-def Tosa_ReverseOp: Tosa_Op<"reverse", [NoSideEffect]> {
+def Tosa_ReverseOp: Tosa_Op<"reverse", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>, NoSideEffect]> {
   let summary = "Reverse operator";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index b9032dfd351e0..b00b161aef156 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -13,11 +13,13 @@
 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 namespace tosa {
 
+std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 

diff  --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index a29a1676c2647..dfa7b1f8582e3 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -15,6 +15,21 @@
 
 include "mlir/Pass/PassBase.td"
 
+def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
+  let summary = "Propagate shapes across TOSA operations";
+  let description = [{
+    Pass that uses operand types and propagates shapes to TOSA operations.
+    This includes legalizing rankless and dynamic shapes towards static.
+  }];
+
+  let constructor = "createTosaInferShapesPass()";
+  let dependentDialects = [
+    "StandardOpsDialect",
+    "tensor::TensorDialect",
+    "tosa::TosaDialect",
+  ];
+}
+
 def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
   let summary = "TOSA rank Reshape to enable Broadcasting";
   let description = [{

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 39b864ff62c02..fd744372fceab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -291,6 +291,148 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA Operator Return Type Inference.
+//===----------------------------------------------------------------------===//
+
+static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
+  for (auto it : arrayAttr) {
+    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
+  }
+}
+
+LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapedType type = operands.front().getType().cast<ShapedType>();
+
+  auto newShape = attributes.get("new_shape").cast<ArrayAttr>();
+  llvm::SmallVector<int64_t> newShapeValue;
+  getI64Values(newShape, newShapeValue);
+
+  // We cannot infer from the total number of elements so we must take the
+  // shape attribute as exact.
+  if (!type.hasRank() || !type.hasStaticShape()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+    return success();
+  }
+
+  // Determine the number of elements covered by the slice of all static
+  // dimensions. This allows us to infer the length of the remaining dynamic
+  // dimension.
+  int64_t numElements = type.getNumElements();
+  int64_t staticMul = 1;
+  for (auto val : newShapeValue) {
+    if (val != -1) {
+      staticMul *= val;
+    }
+  }
+
+  // Determine the length of the dynamic dimension.
+  for (auto &val : newShapeValue) {
+    if (val == -1)
+      val = numElements / staticMul;
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
+  return success();
+}
+
+static LogicalResult resolveBroadcastShape(ValueRange operands,
+                                           SmallVector<int64_t> &outShape) {
+  int64_t outRank = 0;
+  for (auto operand : operands) {
+    auto type = operand.getType().cast<ShapedType>();
+    if (!type.hasRank())
+      return failure();
+    outRank = std::max<int64_t>(outRank, type.getRank());
+  }
+
+  outShape.resize(outRank, 1);
+
+  for (auto operand : operands) {
+    auto type = operand.getType().cast<ShapedType>();
+    auto shape = type.getShape();
+    auto rankDiff = outShape.size() - shape.size();
+
+    for (size_t i = 0; i < shape.size(); i++) {
+      auto dim1 = outShape[i + rankDiff];
+      auto dim2 = shape[i];
+      auto resolvedDim = dim1;
+
+      if (dim1 == 1) {
+        resolvedDim = dim2;
+      } else if (dim2 == 1) {
+        resolvedDim = dim1;
+      } else if (dim1 != dim2) {
+        return failure();
+      }
+      outShape[i + rankDiff] = resolvedDim;
+    }
+  }
+
+  return success();
+}
+
+static LogicalResult NAryInferReturnTypes(
+    ValueRange operands,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(operands, outShape).failed()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+  } else {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  }
+  return success();
+}
+
+#define NARY_SHAPE_INFER(OP)                                                   \
+  LogicalResult OP::inferReturnTypeComponents(                                 \
+      MLIRContext *context, ::llvm::Optional<Location> location,               \
+      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
+    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
+  }
+
+NARY_SHAPE_INFER(tosa::AbsOp)
+NARY_SHAPE_INFER(tosa::AddOp)
+NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
+NARY_SHAPE_INFER(tosa::BitwiseAndOp)
+NARY_SHAPE_INFER(tosa::BitwiseOrOp)
+NARY_SHAPE_INFER(tosa::BitwiseXorOp)
+NARY_SHAPE_INFER(tosa::BitwiseNotOp)
+NARY_SHAPE_INFER(tosa::CeilOp)
+NARY_SHAPE_INFER(tosa::ClampOp)
+NARY_SHAPE_INFER(tosa::ClzOp)
+NARY_SHAPE_INFER(tosa::DivOp)
+NARY_SHAPE_INFER(tosa::EqualOp)
+NARY_SHAPE_INFER(tosa::ExpOp)
+NARY_SHAPE_INFER(tosa::FloorOp)
+NARY_SHAPE_INFER(tosa::GreaterEqualOp)
+NARY_SHAPE_INFER(tosa::GreaterOp)
+NARY_SHAPE_INFER(tosa::LogOp)
+NARY_SHAPE_INFER(tosa::LogicalAndOp)
+NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
+NARY_SHAPE_INFER(tosa::LogicalNotOp)
+NARY_SHAPE_INFER(tosa::LogicalOrOp)
+NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
+NARY_SHAPE_INFER(tosa::LogicalXorOp)
+NARY_SHAPE_INFER(tosa::MaximumOp)
+NARY_SHAPE_INFER(tosa::MinimumOp)
+NARY_SHAPE_INFER(tosa::MulOp)
+NARY_SHAPE_INFER(tosa::NegateOp)
+NARY_SHAPE_INFER(tosa::PowOp)
+NARY_SHAPE_INFER(tosa::ReciprocalOp)
+NARY_SHAPE_INFER(tosa::ReluNOp)
+NARY_SHAPE_INFER(tosa::ReverseOp)
+NARY_SHAPE_INFER(tosa::RsqrtOp)
+NARY_SHAPE_INFER(tosa::SelectOp)
+NARY_SHAPE_INFER(tosa::SubOp)
+NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::SigmoidOp)
+#undef NARY_SHAPE_INFER
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 04acbf6425b75..f466b1ab85389 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRTosaTransforms
+  TosaInferShapes.cpp
   TosaMakeBroadcastable.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
new file mode 100644
index 0000000000000..eca63e1e8ab39
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -0,0 +1,247 @@
+//===- TosaInferShapes.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Propagate shapes forward along TOSA operations to resolve dynamic shape
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+// -----------------------------------------------------------------------------
+// Analysis.
+// -----------------------------------------------------------------------------
+
+static Type joinElementTypes(Type lhs, Type rhs) {
+  return lhs == rhs ? lhs : Type();
+}
+
+namespace {
+// Statically known information for a particular Value.
+//
+// This struct currently tracks only information relevant for tensor/array-like
+// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
+// type as long as it is in the default "no knowledge" state returned by
+// `getPessimisticValueState`. The important invariant is that we cannot
+// claim to know something about a value which is false.
+//
+// This class could also be called "dataflow facts", "lattice value", etc.
+struct ValueKnowledge {
+  ValueKnowledge() = delete;
+  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
+      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
+    assert(sizes.size() == 0 || hasSizes);
+  }
+
+  // Get the static knowledge intrinsic to `type`.
+  static ValueKnowledge getKnowledgeFromType(Type type) {
+    ValueKnowledge result = getPessimisticValueState(type.getContext());
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (shapedType.hasRank()) {
+        result.hasSizes = true;
+        result.sizes = shapedType.getShape();
+      }
+      result.dtype = shapedType.getElementType();
+    }
+    return result;
+  }
+
+  // Return a pessimistic/conservative value state without assuming any knowledge
+  // about the IR.
+  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
+    return ValueKnowledge(false, {}, Type());
+  }
+
+  Type getType() const {
+    if (hasSizes) {
+      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
+    }
+    return UnrankedTensorType::get(dtype);
+  }
+
+  bool operator==(const ValueKnowledge &rhs) const {
+    return std::make_tuple(hasSizes, sizes, dtype) ==
+           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
+  }
+
+  // Given two pieces of static knowledge, calculate conservatively the
+  // information we can be sure about.
+  static ValueKnowledge join(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    // Mental model: All conditions are checking how to change from the safe "no
+    // knowledge" default-initialized state to a state with more knowledge
+    // consistent with lhs and rhs.
+    ValueKnowledge result = getPessimisticValueState(nullptr);
+
+    if (lhs.hasSizes && !rhs.hasSizes) {
+      result.hasSizes = true;
+      result.sizes = lhs.sizes;
+    } else if (!lhs.hasSizes && rhs.hasSizes) {
+      result.hasSizes = true;
+      result.sizes = rhs.sizes;
+    } else if (lhs.hasSizes && rhs.hasSizes &&
+               lhs.sizes.size() == rhs.sizes.size()) {
+      result.hasSizes = true;
+      result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+      for (int i = 0, e = result.sizes.size(); i != e; i++) {
+        int64_t lhsSize = lhs.sizes[i];
+        int64_t rhsSize = rhs.sizes[i];
+        int64_t &resultSize = result.sizes[i];
+        if (lhsSize == ShapedType::kDynamicSize) {
+          resultSize = rhsSize;
+        } else if (rhsSize == ShapedType::kDynamicSize) {
+          resultSize = lhsSize;
+        } else if (lhsSize == rhsSize) {
+          resultSize = lhsSize;
+        }
+      }
+    }
+
+    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
+    return result;
+  }
+
+  // Whether the Value is known to have a list of sizes.
+  bool hasSizes;
+  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
+  // `ShapedType::kDynamicSize`.
+  std::vector<int64_t> sizes;
+  // The dtype of a tensor.
+  // This is equal to nullptr if we don't know that it is a specific concrete
+  // type.
+  Type dtype;
+};
+
+} // namespace
+
+/// Pass that statically infers shapes from operand types and propagates them
+/// forward across TOSA operations, refining result types in place.
+struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getOperation();
+
+    IRRewriter rewriter(func.getContext());
+
+    func.walk([&](Operation *op) {
+      if (op->getDialect()->getNamespace() !=
+          tosa::TosaDialect::getDialectNamespace())
+        return;
+      InferShapedTypeOpInterface shapeInterface =
+          dyn_cast<InferShapedTypeOpInterface>(op);
+      if (!shapeInterface)
+        return;
+
+      SmallVector<ShapedTypeComponents> returnedShapes;
+      if (shapeInterface
+              .inferReturnTypeComponents(
+                  op->getContext(), op->getLoc(), op->getOperands(),
+                  op->getAttrDictionary(), op->getRegions(), returnedShapes)
+              .succeeded()) {
+        for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
+          Value result = std::get<0>(it);
+          ShapedTypeComponents predictedShape = std::get<1>(it);
+
+          // Check whether this use case is replaceable. We define an op as
+          // being replaceable if it is used by a ReturnOp or a TosaOp.
+          bool replaceable = true;
+          for (auto user : result.getUsers()) {
+            if (isa<ReturnOp>(user))
+              continue;
+            if (user->getDialect()->getNamespace() ==
+                tosa::TosaDialect::getDialectNamespace())
+              continue;
+
+            replaceable = false;
+          }
+
+          // Determine the knowledge based on the output type.
+          Type resultTy = result.getType();
+          auto currentKnowledge =
+              ValueKnowledge::getKnowledgeFromType(resultTy);
+
+          // Compute the knowledge based on the inferred type.
+          auto inferredKnowledge =
+              ValueKnowledge::getPessimisticValueState(op->getContext());
+          inferredKnowledge.dtype =
+              resultTy.cast<ShapedType>().getElementType();
+          inferredKnowledge.hasSizes = predictedShape.hasRank();
+          if (predictedShape.hasRank()) {
+            for (auto dim : predictedShape.getDims()) {
+              inferredKnowledge.sizes.push_back(dim);
+            }
+          }
+
+          if (!replaceable)
+            continue;
+
+          // Compute the new type based on the joined version.
+          auto newKnowledge =
+              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+          result.setType(newKnowledge.getType());
+        }
+      }
+    });
+
+    // Insert UnrealizedConversionCasts to guarantee ReturnOp agrees with
+    // the FuncOp type.
+    func.walk([&](ReturnOp op) {
+      FuncOp parent = dyn_cast<FuncOp>(op->getParentOp());
+      if (!parent)
+        return;
+
+      rewriter.setInsertionPoint(op);
+      FunctionType funcTy = func.getType();
+      auto resultTys = funcTy.getResults();
+
+      bool castAdded = false;
+      SmallVector<Value> castedValues;
+      for (auto it : llvm::zip(op->getOperands(), resultTys)) {
+        auto operand = std::get<0>(it);
+        auto currentTy = operand.getType();
+        auto castTy = std::get<1>(it);
+        if (currentTy == castTy) {
+          castedValues.push_back(operand);
+          continue;
+        }
+
+        castedValues.push_back(
+            rewriter.create<tensor::CastOp>(op.getLoc(), castTy, operand)
+                .getResult());
+
+        castAdded = true;
+      }
+
+      if (castAdded) {
+        rewriter.replaceOpWithNewOp<ReturnOp>(op, castedValues);
+      }
+    });
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
+  return std::make_unique<TosaInferShapes>();
+}

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 60bc6357eb3c8..e850e1f517d2f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR//TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6f381be8e7202..44cfda613d313 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -500,7 +500,7 @@ func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
 // CHECK-LABEL: @test_reshape_downrank_6D
 func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
   // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
   return %0 : tensor<6x5x77xf32>
 }
 

diff  --git a/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir b/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir
new file mode 100644
index 0000000000000..e73c79cb3ceef
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir
@@ -0,0 +1,278 @@
+// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s
+
+// CHECK-LABEL: @test_return
+func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
+  // CHECK: [[LOG:%.+]] = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_multiple
+func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+  // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+
+  // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_f32
+func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
+  // CHECK: "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %1 = "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %3 = "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %4 = "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %5 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %6 = "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32>
+  %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32>
+  %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor<?xf32>
+
+  // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_i32
+func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
+  // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
+  %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+  %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32>
+  %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32>
+  %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_unary_i1
+func @test_unary_i1(%arg0 : tensor<4xi1>) -> () {
+  // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1>
+  %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_scalar_f32
+func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<4xi1>
+  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<f32>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_broadcast_f32
+func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1>
+  %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_i32
+func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
+  // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_and"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %1 = "tosa.bitwise_and"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_or"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %2 = "tosa.bitwise_or"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %3 = "tosa.bitwise_xor"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
+  %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
+  %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi1>
+  %6 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi1>
+
+  // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK:  "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_binary_i1
+func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor<i1>) -> () {
+  // CHECK: "tosa.logical_and"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
+  %0 = "tosa.logical_and"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+
+  // CHECK: "tosa.logical_or"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
+  %1 = "tosa.logical_or"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+
+  // CHECK: "tosa.logical_xor"(%arg0, %arg1) : (tensor<4xi1>, tensor<i1>) -> tensor<4xi1>
+  %2 = "tosa.logical_xor"(%arg0, %arg1): (tensor<4xi1>, tensor<i1>) -> tensor<*xi1>
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_select_i32
+func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor<i32>, %arg2 : tensor<4xi32>) -> () {
+  // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<4xi32>
+  %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor<i32>, tensor<4xi32>) -> tensor<*xi32>
+
+  return
+}
+
+// -----
+
+func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<16xi32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>)  -> tensor<?xi32>
+
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<16xi32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>)  -> tensor<?xi32>
+
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<2x8xi32>
+  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>)  -> tensor<?x?xi32>
+
+  return
+}
+// -----
+
+func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
+  // CHECK: %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<16xi32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>)  -> tensor<?xi32>
+
+  // CHECK: %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor<?xi32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>)  -> tensor<?xi32>
+
+  // CHECK: %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<2x?xi32>
+  %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>)  -> tensor<?x?xi32>
+
+  return
+}
+

diff  --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index a9bb40c76db5a..99bf14b6b59fd 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -11,7 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tosa/IR//TosaOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"


        


More information about the Mlir-commits mailing list