[Mlir-commits] [mlir] b2d76a0 - TOSA-to-Linalg lowering for element-wise ops

Eric Kunze llvmlistbot at llvm.org
Fri Jul 21 15:09:42 PDT 2023


Author: Rafael Ubal Tena
Date: 2023-07-21T22:08:33Z
New Revision: b2d76a063dd7fb681c98a10d8e7f54fd6d25dd27

URL: https://github.com/llvm/llvm-project/commit/b2d76a063dd7fb681c98a10d8e7f54fd6d25dd27
DIFF: https://github.com/llvm/llvm-project/commit/b2d76a063dd7fb681c98a10d8e7f54fd6d25dd27.diff

LOG: TOSA-to-Linalg lowering for element-wise ops

- Wrote complete documentation for the `Broadcastable` op trait. This is mostly meant as a thorough description of its previous behavior, with the exception of minor feature updates.

- Restricted legality criteria for a `Broadcastable` op in order to simplify current and future lowering passes and increase efficiency of code generated by those passes. New restrictions are: 1) A dynamic dimension in an inferred result is not compatible with a static dimension in the actual result. 2) Broadcast semantics are restricted to input operands and not supported between inferred and actual result shapes.

- Implemented TOSA-to-Linalg lowering support for unary, binary, and ternary element-wise ops. This support is complete for all legal cases described in the `Broadcastable` trait documentation.

- Added unit tests for `tosa.abs`, `tosa.add`, and `tosa.select` as examples of unary, binary, and ternary ops.

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D153291

Added: 
    mlir/docs/Traits/Broadcastable.md
    mlir/docs/Traits/_index.md

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Traits.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/traits.mlir

Removed: 
    mlir/docs/Traits.md
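
The two new restrictions listed in the log can be illustrated with a small Python sketch. This is only an illustration, not code from the commit: the function name `check_result_dim` and the use of `None` to stand for a dynamic ("?") dimension are assumptions made here.

```python
from typing import Optional

# Assumption for this sketch: None represents a dynamic ("?") dimension.
Dim = Optional[int]


def check_result_dim(inferred: Dim, actual: Dim) -> bool:
    """Return True if an actual result dimension is legal for an inferred one."""
    if actual is None:
        # A dynamic actual dimension is accepted for any inferred dimension.
        return True
    if inferred is None:
        # Restriction 1: a dynamic inferred dimension is not compatible
        # with a static actual dimension.
        return False
    # Restriction 2: no broadcasting between inferred and actual result
    # shapes, so two static sizes must be exactly equal.
    return inferred == actual
```

Under these assumptions, `check_result_dim(1, 4)` is rejected even though 1 would broadcast to 4 between two input operands, which is the point of the second restriction.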


################################################################################
diff --git a/mlir/docs/Traits/Broadcastable.md b/mlir/docs/Traits/Broadcastable.md
new file mode 100644
index 00000000000000..3dd0c3998c2ad3
--- /dev/null
+++ b/mlir/docs/Traits/Broadcastable.md
@@ -0,0 +1,197 @@
+# The `Broadcastable` Trait
+
+[TOC]
+
+## Description
+
+The `Broadcastable` trait enforces the following properties on an operation:
+
+- The operation has at least one input operand.
+
+- The operation has exactly one result.
+
+- All input operands and result are of type `tensor` or `vector`.
+
+- A shape inference mechanism is able to compute the result shape solely based on input operand shapes.
+
+- Input operands have broadcast-compatible shapes, according to the verification rules presented below.
+
+- The operation's result shape is compatible with —though not necessarily identical to— the shape inferred from its input operands, according to the verification rules presented below.
+
+
+## Dimension inference
+
+Given an operation with two input operands, the size of dimension `i` of its result can be inferred from dimension `i` of the operands according to the table below. Here, `dim0` and `dim1` represent dimension `i` of the input operands in an interchangeable order, while `inferredDim` represents the inferred size for dimension `i` of the operation result. Dimensions are classified in three categories: dynamic ("?"), static equal to 1 ("1"), and static greater than 1 (">1").
+
+
+| `dim0` | `dim1` | `inferredDim` | Notes |
+| -------- | -------- | ------------- | ----- |
+| ? | ? | ? | If `RuntimeSize(dim0)` is 1, dimension `dim0` is broadcast to `RuntimeSize(dim1)`. If `RuntimeSize(dim1)` is 1, dimension `dim1` is broadcast to `RuntimeSize(dim0)`. The operation produces undefined behavior if both runtime sizes are greater than 1 and not equal. |
+| ? | 1 | ? | Dimension `dim1` is broadcast to `RuntimeSize(dim0)`. |
+| ? | >1 | `dim1` | If `RuntimeSize(dim0)` is 1, `dim0` is broadcast to `dim1`. The operation produces undefined behavior if `RuntimeSize(dim0)` is greater than 1 and not equal to `dim1`. |
+| 1 | 1 | 1 | |
+| 1 | >1 | `dim1` | Dimension `dim0` is broadcast to `dim1`. |
+| >1 | >1 | `dim0` | The operation verifier produces a compile-time error if `dim0` != `dim1`. |
+
+
+The following pseudo-function is a formal representation of the dimension inference process:
+
+```python
+InferDim(dim0, dim1):
+	switch (dim0, dim1):
+		case (?, ?):
+		case (?, 1):
+		case (1, 1):
+		case (>1, ?):
+		case (>1, 1):
+			return dim0
+		case (?, >1):
+		case (1, ?):
+		case (1, >1):
+			return dim1
+		case (>1, >1):
+			ERROR_IF(dim0 != dim1)
+			return dim0
+```
+
+## Shape inference
+
+The shape inference process begins by correcting rank differences in input operands. A shape is expanded by adding additional dimensions of size 1 on its left until the desired rank is reached, as shown here:
+
+```python
+ExpandRank(shape, rank):
+	while len(shape) < rank:
+		shape.prepend(1)
+```
+		
+Given the shapes of two ranked input operands, the result's shape is inferred by equalizing input ranks and inferring individual dimensions, as shown here:
+
+```python
+InferShape(shape0, shape1):
+
+	# Equalize ranks
+	rank = max(GetRank(shape0), GetRank(shape1))
+	ExpandRank(shape0, rank)
+	ExpandRank(shape1, rank)
+	
+	# Infer shape
+	inferredShape = []
+	for (dim0, dim1) in zip(shape0, shape1):
+		inferredDim = InferDim(dim0, dim1)
+		inferredShape.append(inferredDim)
+	return inferredShape
+```
+	
+The result shape for an operation with an arbitrary number of input operands is then inferred by discarding unranked operands, applying shape inference on the first ranked operand pair, and updating the inferred shape with each additional ranked operand. If the operation has no ranked operands, the result shape cannot be inferred. If the operation has exactly one ranked operand, its shape is directly provided as the inferred result shape. Formally:
+
+```python
+InferResultShape(op):
+
+	# Filter ranked operands
+	rankedOperands = filter(op.operands, IsRanked)
+	if len(rankedOperands) == 0:
+		return None
+	
+	# Infer result shape
+	inferredShape = GetShape(rankedOperands[0])
+	for operand in rankedOperands[1:]:
+		inferredShape = InferShape(inferredShape, GetShape(operand))
+	return inferredShape
+```
+
+## Verification
+
+The legality of an operation with the `Broadcastable` trait is verified by first running the shape inference process. If a failure occurs during shape inference, it is concluded that input operands are not broadcast-compatible, and verification fails. If shape inference succeeds, verification continues.
+
+If either the result is unranked or all input operands are unranked, no further verification steps are needed, and the process ends here successfully. If, on the contrary, both the result and at least one input operand are ranked, verification continues by checking for a matching rank between the previously inferred shape and the result.
+
+Once a rank match is guaranteed, each dimension of the inferred shape is compared with the corresponding dimension of the actual result shape according to the following table:
+
+
+| `inferredDim` | `actualDim` | Verification outcome |
+| ------------- | ----------- | -------------------- |
+| ? | ? | **OK** |
+| ? | static | **Error** <br> An inferred dimension being dynamic indicates that its size cannot be inferred at compile time from its input operands. The presence of a static dimension in the actual result is counterintuitive and is therefore not allowed. |
+| static | ? | **OK** <br> The actual result dimension may be dynamic even when a static size can be inferred at compile time. The programmer may choose to relax the specificity of the result dimension for forward compatibility of the result type. |
+| static | static | **OK if equal** <br> When both the inferred and actual dimensions are static, they must be set to the same size. |
+
+
+The full verification process can be formally specified as follows:
+
+```python
+Verify(op):
+
+	# Run shape inference
+	inferredShape = InferResultShape(op)
+
+	# Done if result is unranked or all operands are unranked
+	if not IsRanked(op.result) or inferredShape is None:
+		return
+	
+	# Rank must match
+	actualShape = GetShape(op.result)
+	ERROR_IF(len(inferredShape) != len(actualShape))
+	
+	# Verify
+	for (inferredDim, actualDim) in zip(inferredShape, actualShape):
+		ERROR_IF(IsDynamic(inferredDim) and IsStatic(actualDim))
+		ERROR_IF(IsStatic(actualDim) and inferredDim != actualDim)
+```
+		
+## Examples
+
+The following are correct uses of broadcastable ops:
+
+```mlir
+// Exact match of static sizes.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+
+// Dynamic sizes match. The programmer must guarantee that the sizes of
+// %arg0 and %arg1 are equal at runtime.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+
+// The shape of %arg0 is broadcast from tensor<1xi32> to tensor<4xi32>.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<4xi32>) -> tensor<4xi32>
+
+// The shape of %result is inferred as tensor<4xi32>, while the actual result
+// type is tensor<?xi32>. The inferred shape is compatible with the actual shape.
+%result = "test.broadcastable"(%arg0) : (tensor<4xi32>) -> tensor<?xi32>
+
+// The shape of %arg0 is first expanded to tensor<1x1x4xi32> and then broadcast
+// to tensor<2x3x4xi32>.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
+
+// Input and result tensors have different element types (i1, i32, i64). The
+// 'Broadcastable' trait has no restrictions on element types.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi32>) -> tensor<2xi64>
+
+// No result shape verification is needed when the result is unranked.
+%result = "test.broadcastable"(%arg0) : (tensor<2xi32>) -> tensor<*xi32>
+
+// No result shape verification is needed when all inputs are unranked.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
+```
+
+
+The following are incorrect uses of broadcastable ops:
+
+```mlir
+// Dimension 0 of input operands is static but not equal.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<?xi32>
+
+// The inferred result shape is tensor<3xi32>, but the actual result shape is
+// tensor<1x3xi32>. Inferred and actual shapes differ in rank.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<1x3xi32>
+
+// The inferred result shape is tensor<?xi32>, but the actual shape is
+// tensor<4xi32>. The inferred shape is not compatible with the actual shape.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<4xi32>
+
+// The inferred result shape is tensor<2xi32>, but the actual result shape is
+// tensor<4xi32>, which is not compatible.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
+
+// The inferred result shape is tensor<1xi32>, but the actual result shape is
+// tensor<4xi32>. Broadcast semantics are not applicable for results.
+%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+```
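
The inference pseudo-functions in `Broadcastable.md` above can be approximated by a runnable Python sketch. This is an illustration only, not code from the commit: the snake_case function names, the use of `None` for a dynamic ("?") dimension, and the use of `None` for the shape of an unranked operand are all assumptions made here.

```python
from typing import List, Optional, Sequence

# Assumptions for this sketch: None is a dynamic ("?") dimension, and an
# operand shape of None means the operand is unranked.
Dim = Optional[int]
Shape = Sequence[Dim]


def infer_dim(dim0: Dim, dim1: Dim) -> Dim:
    """Infer one result dimension from a pair of operand dimensions."""
    if dim0 is None and dim1 is None:
        return None
    if dim0 is None:
        # (?, 1) stays dynamic; (?, >1) takes the static size.
        return None if dim1 == 1 else dim1
    if dim1 is None:
        # (1, ?) stays dynamic; (>1, ?) takes the static size.
        return None if dim0 == 1 else dim0
    if dim0 == 1:
        return dim1
    if dim1 == 1:
        return dim0
    if dim0 != dim1:
        raise ValueError(f"incompatible static dimensions: {dim0} vs {dim1}")
    return dim0


def expand_rank(shape: Shape, rank: int) -> List[Dim]:
    """Prepend dimensions of size 1 until the shape has the given rank."""
    return [1] * (rank - len(shape)) + list(shape)


def infer_shape(shape0: Shape, shape1: Shape) -> List[Dim]:
    """Equalize ranks, then infer each dimension pairwise."""
    rank = max(len(shape0), len(shape1))
    pairs = zip(expand_rank(shape0, rank), expand_rank(shape1, rank))
    return [infer_dim(d0, d1) for d0, d1 in pairs]


def infer_result_shape(operand_shapes: Sequence[Optional[Shape]]) -> Optional[List[Dim]]:
    """Fold infer_shape over all ranked operands; None if there are none."""
    ranked = [shape for shape in operand_shapes if shape is not None]
    if not ranked:
        return None
    inferred = list(ranked[0])
    for shape in ranked[1:]:
        inferred = infer_shape(inferred, shape)
    return inferred
```

For instance, `infer_result_shape([[4], [2, 3, 4]])` yields `[2, 3, 4]`, mirroring the rank-expansion example in the document, while `infer_shape([3], [2])` raises `ValueError`, mirroring the first incorrect example.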

diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits/_index.md
similarity index 95%
rename from mlir/docs/Traits.md
rename to mlir/docs/Traits/_index.md
index 74ab7784c9ab6d..6a9c650aca96b7 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits/_index.md
@@ -241,16 +241,7 @@ that has the trait AutomaticAllocationScope.
 
 This trait adds the property that the operation is known to have
 [broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-operands and its result types' shape is the broadcast compatible with the shape
-of the broadcasted operands. Specifically, starting from the most varying
-dimension, each dimension pair of the two operands' shapes should either be the
-same or one of them is one. Also, the result shape should have the corresponding
-dimension equal to the larger one, if known. Shapes are checked partially if
-ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>` and
-`tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
-broadcast-compatible.
-
-This trait requires that the operands are either vector or tensor types.
+operands and that its result type is compatible with the inferred broadcast shape. See [The `Broadcastable` Trait](Traits/Broadcastable.md) for details.
 
 ### Commutative
 

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index c38580b0378e42..a6a085468ac19a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -215,18 +215,12 @@ class Tosa_Op<string mnemonic, list<Trait> traits = []> :
     Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
 }
 
-class Tosa_ElemWiseUnaryOp<string mnemonic, list<Trait> traits = []> :
+class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
     Tosa_Op<mnemonic, !listconcat(traits, [
               DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                         ["inferReturnTypeComponents"]>,
-              Pure, SameOperandsAndResultElementType])> {
-}
-
-class Tosa_ElemWiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    Tosa_Op<mnemonic, !listconcat(traits, [
-              DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                                        ["inferReturnTypeComponents"]>,
-              ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> {
+              ResultsBroadcastableShape,
+              Pure])> {
 }
 
 #endif // TOSA_OP_BASE

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 421d6b09424a59..812db606128a25 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -345,7 +345,7 @@ def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d",
 //===----------------------------------------------------------------------===//
 // Operator: clamp
 //===----------------------------------------------------------------------===//
-def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> {
+def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
   let summary = "Computes clamp(features, min, max).";
 
   let description = [{
@@ -374,7 +374,7 @@ def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> {
 //===----------------------------------------------------------------------===//
 // Operator: sigmoid
 //===----------------------------------------------------------------------===//
-def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
+def Tosa_SigmoidOp : Tosa_ElementwiseOp<"sigmoid"> {
   let summary = "Computes elementwise sigmoid of input.";
 
   let description = [{
@@ -397,7 +397,7 @@ def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
 //===----------------------------------------------------------------------===//
 // Operator: tanh
 //===----------------------------------------------------------------------===//
-def Tosa_TanhOp : Tosa_ElemWiseUnaryOp<"tanh"> {
+def Tosa_TanhOp : Tosa_ElementwiseOp<"tanh", [SameOperandsAndResultElementType]> {
   let summary = "Computes elementwise hyperbolic tangent of input";
 
   let description = [{
@@ -451,7 +451,9 @@ def Tosa_ErfOp : Tosa_Op<"erf", [
 //===----------------------------------------------------------------------===//
 // Operator: add
 //===----------------------------------------------------------------------===//
-def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
+def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Elementwise addition operator";
 
   let description = [{
@@ -474,7 +476,8 @@ def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: arithmetic_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise Arithmetic Right Shift";
 
   let description = [{
@@ -496,7 +499,9 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_and
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> {
+def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Bitwise AND operator";
 
   let description = [{
@@ -517,7 +522,9 @@ def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_or
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> {
+def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Bitwise OR operator";
 
   let description = [{
@@ -538,7 +545,9 @@ def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_xor
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> {
+def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Bitwise XOR operator";
 
   let description = [{
@@ -559,7 +568,7 @@ def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: div
 //===----------------------------------------------------------------------===//
-def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> {
+def Tosa_DivOp : Tosa_ElementwiseOp<"div", [SameOperandsAndResultElementType]> {
   let summary = "Integer divide operator";
 
   let description = [{
@@ -582,7 +591,9 @@ def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_and
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> {
+def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of x AND y element-wise.";
 
   let description = [{
@@ -603,7 +614,8 @@ def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_left_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise Logical Left Shift";
 
   let description = [{
@@ -624,7 +636,8 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> {
+def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise Logical Right Shift";
 
   let description = [{
@@ -645,7 +658,9 @@ def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_or
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> {
+def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of x OR y element-wise.";
 
   let description = [{
@@ -666,7 +681,9 @@ def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_xor
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> {
+def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of x XOR y element-wise.";
 
   let description = [{
@@ -687,7 +704,9 @@ def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: maximum
 //===----------------------------------------------------------------------===//
-def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
+def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Elementwise Maximum";
 
   let description = [{
@@ -708,7 +727,9 @@ def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: minimum
 //===----------------------------------------------------------------------===//
-def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
+def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Elementwise Minimum";
 
   let description = [{
@@ -729,7 +750,9 @@ def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
+def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
+    Commutative,
+    SameOperandsAndResultElementType]> {
   let summary = "Multiplication operator";
 
   let description = [{
@@ -754,7 +777,7 @@ def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
 //===----------------------------------------------------------------------===//
 // Operator: pow
 //===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
+def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
   let summary = "Computes the power of one value to another.";
 
   let description = [{
@@ -775,7 +798,7 @@ def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
 //===----------------------------------------------------------------------===//
 // Operator: sub
 //===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
+def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise subtraction operator";
 
   let description = [{
@@ -838,7 +861,7 @@ def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
 //===----------------------------------------------------------------------===//
 // Operator: abs
 //===----------------------------------------------------------------------===//
-def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> {
+def Tosa_AbsOp : Tosa_ElementwiseOp<"abs", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise abs op";
 
   let description = [{
@@ -859,7 +882,8 @@ def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> {
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_not
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> {
+def Tosa_BitwiseNotOp : Tosa_ElementwiseOp<"bitwise_not",
+    [SameOperandsAndResultElementType]> {
   let summary = "Bitwise NOT operator";
 
   let description = [{
@@ -878,7 +902,7 @@ def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> {
 //===----------------------------------------------------------------------===//
 // Operator: ceil
 //===----------------------------------------------------------------------===//
-def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> {
+def Tosa_CeilOp : Tosa_ElementwiseOp<"ceil", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise ceil op";
 
   let description = [{
@@ -897,7 +921,7 @@ def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> {
 //===----------------------------------------------------------------------===//
 // Operator: clz
 //===----------------------------------------------------------------------===//
-def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> {
+def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise count leading zero op";
 
   let description = [{
@@ -916,7 +940,7 @@ def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> {
 //===----------------------------------------------------------------------===//
 // Operator: exp
 //===----------------------------------------------------------------------===//
-def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> {
+def Tosa_ExpOp : Tosa_ElementwiseOp<"exp", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise exp op";
 
   let description = [{
@@ -937,7 +961,7 @@ def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> {
 //===----------------------------------------------------------------------===//
 // Operator: floor
 //===----------------------------------------------------------------------===//
-def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> {
+def Tosa_FloorOp : Tosa_ElementwiseOp<"floor", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise floor op";
 
   let description = [{
@@ -956,7 +980,7 @@ def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> {
 //===----------------------------------------------------------------------===//
 // Operator: log
 //===----------------------------------------------------------------------===//
-def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> {
+def Tosa_LogOp : Tosa_ElementwiseOp<"log", [SameOperandsAndResultElementType]> {
   let summary = "Elementwise log op";
 
   let description = [{
@@ -977,7 +1001,8 @@ def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> {
 //===----------------------------------------------------------------------===//
 // Operator: logical_not
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> {
+def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
+    [SameOperandsAndResultElementType]> {
   let summary = "Returns the truth value of NOT x element-wise.";
 
   let description = [{
@@ -996,7 +1021,8 @@ def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> {
 //===----------------------------------------------------------------------===//
 // Operator: negate
 //===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> {
+def Tosa_NegateOp : Tosa_ElementwiseOp<"negate",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise negate op";
 
   let description = [{
@@ -1020,7 +1046,8 @@ def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> {
 //===----------------------------------------------------------------------===//
 // Operator: reciprocal
 //===----------------------------------------------------------------------===//
-def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> {
+def Tosa_ReciprocalOp : Tosa_ElementwiseOp<"reciprocal",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise reciprocal op";
 
   let description = [{
@@ -1040,7 +1067,8 @@ def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> {
 //===----------------------------------------------------------------------===//
 // Operator: rsqrt
 //===----------------------------------------------------------------------===//
-def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> {
+def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt",
+    [SameOperandsAndResultElementType]> {
   let summary = "Elementwise 1/sqrt op";
 
   let description = [{
@@ -1066,9 +1094,7 @@ def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> {
 //===----------------------------------------------------------------------===//
 // Operator: select
 //===----------------------------------------------------------------------===//
-def Tosa_SelectOp : Tosa_Op<"select", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   let summary = "Elementwise select operator";
 
   let description = [{
@@ -1096,8 +1122,10 @@ def Tosa_SelectOp : Tosa_Op<"select", [
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
-    Commutative, Pure, SameOperandsElementType]> {
+def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
+    InferTensorType,
+    Commutative,
+    SameOperandsElementType]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1125,10 +1153,7 @@ def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
 //===----------------------------------------------------------------------===//
 // Operator: greater
 //===----------------------------------------------------------------------===//
-def Tosa_GreaterOp : Tosa_Op<"greater", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
+def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
   let summary = "Returns the truth value of (x > y) element-wise.";
 
   let description = [{
@@ -1150,10 +1175,8 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [
 //===----------------------------------------------------------------------===//
 // Operator: greater_equal
 //===----------------------------------------------------------------------===//
-def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
+def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
+    [SameOperandsElementType]> {
   let summary = "Returns the truth value of (x >= y) element-wise.";
 
   let description = [{

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 49c8cf4e2ddd54..bfd08ad389610a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -24,9 +24,12 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
 
 #include <numeric>
 
@@ -517,115 +520,339 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   return nullptr;
 }
 
-static LogicalResult
-elementwiseMatchAndRewriteHelper(Operation *operation,
-                                 PatternRewriter &rewriter) {
-  auto loc = operation->getLoc();
-
-  assert(operation->getNumResults() == 1 &&
-         "All TOSA elementwise ops should only return a single result.");
-
-  auto result = operation->getResult(0);
-  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
-
-  if (!resultTy)
-    return rewriter.notifyMatchFailure(
-        operation, "All results must be a ranked tensor type");
+static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
+                        int64_t rank) {
+  // No need to expand if we are already at the desired rank
+  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+  int64_t numExtraDims = rank - shapedType.getRank();
+  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
+  if (!numExtraDims)
+    return tensor;
+
+  // Compute reassociation indices
+  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
+      shapedType.getRank());
+  int64_t index = 0;
+  for (index = 0; index <= numExtraDims; index++)
+    reassociationIndices[0].push_back(index);
+  for (size_t position = 1; position < reassociationIndices.size(); position++)
+    reassociationIndices[position].push_back(index++);
+
+  // Compute result type
+  SmallVector<int64_t> resultShape;
+  for (index = 0; index < numExtraDims; index++)
+    resultShape.push_back(1);
+  for (auto size : shapedType.getShape())
+    resultShape.push_back(size);
+  auto resultType =
+      RankedTensorType::get(resultShape, shapedType.getElementType());
+
+  // Emit 'tensor.expand_shape' op
+  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
+                                                reassociationIndices);
+}
 
-  unsigned rank = resultTy.getRank();
+static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
+                                           Location loc, Operation *operation) {
+  auto rank =
+      operation->getResultTypes().front().cast<RankedTensorType>().getRank();
+  return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
+    return expandRank(rewriter, loc, operand, rank);
+  });
+}
 
-  // Construct the indexing maps needed for linalg.generic ops.
-  SmallVector<Type> bodyArgTypes;
+using IndexPool = DenseMap<int64_t, Value>;
+
+// Emit an 'arith.constant' op for the given index if it has not been created
+// yet, or return an existing constant. This will prevent an excessive creation
+// of redundant constants, easing readability of emitted code for unit tests.
+static Value createIndex(PatternRewriter &rewriter, Location loc,
+                         IndexPool &indexPool, int64_t index) {
+  auto [it, inserted] = indexPool.try_emplace(index);
+  if (inserted)
+    it->second =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
+  return it->second;
+}
 
-  for (Value in : operation->getOperands())
-    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
+static Value getTensorDim(PatternRewriter &rewriter, Location loc,
+                          IndexPool &indexPool, Value tensor, int64_t index) {
+  auto indexValue = createIndex(rewriter, loc, indexPool, index);
+  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
+}
 
-  SmallVector<Type> opResultTypes;
-  SmallVector<Value> emptyTensors;
+static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
+                                       IndexPool &indexPool, Value tensor,
+                                       int64_t index) {
+  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
+  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
+  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
+  if (shapedType.isDynamicDim(index))
+    return getTensorDim(rewriter, loc, indexPool, tensor, index);
+  return rewriter.getIndexAttr(shapedType.getDimSize(index));
+}
 
-  SmallVector<Value> dynDims;
-  dynDims.resize(rank);
+static bool operandsAndResultsRanked(Operation *operation) {
+  auto isRanked = [](Value value) {
+    return isa<RankedTensorType>(value.getType());
+  };
+  return llvm::all_of(operation->getOperands(), isRanked) &&
+         llvm::all_of(operation->getResults(), isRanked);
+}
 
-  for (auto arg : operation->getOperands()) {
-    auto operandTy = cast<ShapedType>(arg.getType());
-    for (int i = 0; i < operandTy.getRank(); i++) {
-      if (operandTy.isDynamicDim(i) && !dynDims[i])
-        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
-    }
+// Compute the runtime dimension size for dimension 'dim' of the output by
+// inspecting input 'operands', all of which are expected to have the same rank.
+// This function returns a pair {targetSize, masterOperand}.
+//
+// The runtime size of the output dimension is returned either as a statically
+// computed attribute or as a runtime SSA value.
+//
+// If the target size was inferred directly from one dominating operand, that
+// operand is returned in 'masterOperand'. If the target size is inferred from
+// multiple operands, 'masterOperand' is set to nullptr.
+static std::pair<OpFoldResult, Value>
+computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
+                  ValueRange operands, int64_t dim) {
+  // If any input operand contains a static size greater than 1 for this
+  // dimension, that is the target size. An occurrence of an additional static
+  // dimension greater than 1 with a different value is undefined behavior.
+  for (auto operand : operands) {
+    auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
+    if (!ShapedType::isDynamic(size) && size > 1)
+      return {rewriter.getIndexAttr(size), operand};
   }
 
-  SmallVector<Value> filteredDims = condenseValues(dynDims);
-
-  emptyTensors.push_back(
-      rewriter.create<tensor::EmptyOp>(loc, resultTy, filteredDims));
-  opResultTypes.push_back(result.getType());
-
-  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
-      emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));
-
-  SmallVector<Value, 2> operands;
-  SmallVector<AffineMap, 2> indexingMaps;
-  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
-
-  // Input indexing maps may be broadcasted.
-  for (Value operand : operation->getOperands()) {
-    ShapedType type = cast<ShapedType>(operand.getType());
+  // Filter operands with dynamic dimension
+  auto operandsWithDynamicDim =
+      llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
+        return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
+      }));
+
+  // If no operand has a dynamic dimension, it means all sizes were 1
+  if (operandsWithDynamicDim.empty())
+    return {rewriter.getIndexAttr(1), operands.front()};
+
+  // Emit code that computes the runtime size for this dimension. If there is
+  // only one operand with a dynamic dimension, it is considered the master
+  // operand that determines the runtime size of the output dimension.
+  auto targetSize =
+      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
+  if (operandsWithDynamicDim.size() == 1)
+    return {targetSize, operandsWithDynamicDim[0]};
+
+  // Calculate maximum size among all dynamic dimensions
+  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
+    auto nextSize =
+        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
+    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
+  }
+  return {targetSize, nullptr};
+}
 
-    if (type.getShape() == resultTy.getShape()) {
-      operands.push_back(operand);
-      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
-      continue;
-    }
+// Compute the runtime output size for all dimensions. This function returns
+// a pair {targetShape, masterOperands}.
+static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
+computeTargetShape(PatternRewriter &rewriter, Location loc,
+                   IndexPool &indexPool, ValueRange operands) {
+  assert(!operands.empty());
+  auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
+  SmallVector<OpFoldResult> targetShape;
+  SmallVector<Value> masterOperands;
+  for (auto dim : llvm::seq<int64_t>(0, rank)) {
+    auto [targetSize, masterOperand] =
+        computeTargetSize(rewriter, loc, indexPool, operands, dim);
+    targetShape.push_back(targetSize);
+    masterOperands.push_back(masterOperand);
+  }
+  return {targetShape, masterOperands};
+}
 
-    SmallVector<int64_t, 5> newShape;
-    SmallVector<AffineExpr, 4> affineExprs;
-    newShape.reserve(type.getRank());
-    for (const auto &it : llvm::enumerate(type.getShape())) {
-      if (it.value() == resultTy.getDimSize(it.index())) {
-        newShape.push_back(it.value());
-        affineExprs.push_back(
-            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
-      }
+static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
+                                       IndexPool &indexPool, Value operand,
+                                       int64_t dim, OpFoldResult targetSize,
+                                       Value masterOperand) {
+  // Nothing to do if this is a static dimension
+  auto rankedTensorType = operand.getType().cast<RankedTensorType>();
+  if (!rankedTensorType.isDynamicDim(dim))
+    return operand;
+
+  // If the target size for this dimension was directly inferred by only taking
+  // this operand into account, there is no need to broadcast. This is an
+  // optimization that will prevent redundant control flow, and constitutes the
+  // main motivation for tracking "master operands".
+  if (operand == masterOperand)
+    return operand;
+
+  // Affine maps for 'linalg.generic' op
+  auto rank = rankedTensorType.getRank();
+  SmallVector<AffineExpr> affineExprs;
+  for (auto index : llvm::seq<int64_t>(0, rank)) {
+    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
+                                   : rewriter.getAffineDimExpr(index);
+    affineExprs.push_back(affineExpr);
+  }
+  auto broadcastAffineMap =
+      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
+  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
+  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
+
+  // Check if broadcast is necessary
+  auto one = createIndex(rewriter, loc, indexPool, 1);
+  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
+  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, runtimeSize, one);
+
+  // Emit 'then' region of 'scf.if'
+  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+    // Emit 'tensor.empty' op
+    SmallVector<OpFoldResult> outputTensorShape;
+    for (auto index : llvm::seq<int64_t>(0, rank)) {
+      auto size = index == dim ? targetSize
+                               : getOrFoldTensorDim(rewriter, loc, indexPool,
+                                                    operand, index);
+      outputTensorShape.push_back(size);
     }
+    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
+        loc, outputTensorShape, rankedTensorType.getElementType());
+
+    // Emit 'linalg.generic' op
+    auto resultTensor =
+        opBuilder
+            .create<linalg::GenericOp>(
+                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
+                getNParallelLoopsAttrs(rank),
+                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+                  // Emit 'linalg.yield' op
+                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
+                })
+            .getResult(0);
+
+    // Cast to original operand type if necessary
+    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
+        loc, operand.getType(), resultTensor);
+
+    // Emit 'scf.yield' op
+    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
+  };
+
+  // Emit 'else' region of 'scf.if'
+  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
+    opBuilder.create<scf::YieldOp>(loc, operand);
+  };
+
+  // Emit 'scf.if' op
+  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
+                                         emitThenRegion, emitElseRegion);
+  return ifOp.getResult(0);
+}
 
-    if (newShape.size() != rank) {
-      operand = rewriter.create<tosa::ReshapeOp>(
-          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
-          rewriter.getDenseI64ArrayAttr(newShape));
-    }
+static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
+                                        IndexPool &indexPool, Value operand,
+                                        ArrayRef<OpFoldResult> targetShape,
+                                        ArrayRef<Value> masterOperands) {
+  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  assert(targetShape.size() == rank);
+  assert(masterOperands.size() == rank);
+  for (auto index : llvm::seq<int64_t>(0, rank))
+    operand =
+        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
+                                  targetShape[index], masterOperands[index]);
+  return operand;
+}
 
-    operands.push_back(operand);
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/rank, /*symbolCount=*/0, affineExprs,
-        rewriter.getContext()));
-  }
+static SmallVector<Value>
+broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
+                           IndexPool &indexPool, ValueRange operands,
+                           ArrayRef<OpFoldResult> targetShape,
+                           ArrayRef<Value> masterOperands) {
+  // No need to broadcast for unary operations
+  if (operands.size() == 1)
+    return operands;
+
+  // Broadcast dynamic dimensions operand by operand
+  return llvm::map_to_vector(operands, [&](Value operand) {
+    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
+                                      targetShape, masterOperands);
+  });
+}
 
-  indexingMaps.append(operation->getNumResults(),
-                      rewriter.getMultiDimIdentityMap(rank));
+static LogicalResult
+emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
+                           Operation *operation, ValueRange operands,
+                           ArrayRef<OpFoldResult> targetShape) {
+  // Generate output tensor
+  auto resultType =
+      operation->getResultTypes().front().cast<RankedTensorType>();
+  Value outputTensor = rewriter.create<tensor::EmptyOp>(
+      loc, targetShape, resultType.getElementType());
+
+  // Create affine maps. Input affine maps broadcast static dimensions of size
+  // 1. The output affine map is an identity map.
+  //
+  auto rank = resultType.getRank();
+  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
+    auto shape = cast<ShapedType>(operand.getType()).getShape();
+    SmallVector<AffineExpr> affineExprs;
+    for (auto it : llvm::enumerate(shape)) {
+      auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
+                                        : rewriter.getAffineDimExpr(it.index());
+      affineExprs.push_back(affineExpr);
+    }
+    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
+  });
+  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
 
-  bool didEncounterError = false;
+  // Emit 'linalg.generic' op
+  bool encounteredError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
-      loc, opResultTypes, operands, emptyTensors, indexingMaps,
+      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
       getNParallelLoopsAttrs(rank),
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
         Value opResult = createLinalgBodyCalculationForElementwiseOp(
             operation, blockArgs.take_front(operation->getNumOperands()),
-            bodyResultTypes, rewriter);
+            {resultType.getElementType()}, rewriter);
         if (!opResult) {
-          didEncounterError = true;
+          encounteredError = true;
           return;
         }
-        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
+        opBuilder.create<linalg::YieldOp>(loc, opResult);
       });
-
-  if (didEncounterError)
+  if (encounteredError)
     return rewriter.notifyMatchFailure(
         operation, "unable to create linalg.generic body for elementwise op");
 
-  rewriter.replaceOp(operation, linalgOp->getResults());
+  // Cast 'linalg.generic' result into original result type if needed
+  auto castResult = rewriter.createOrFold<tensor::CastOp>(
+      loc, resultType, linalgOp->getResult(0));
+  rewriter.replaceOp(operation, castResult);
   return success();
 }
 
+static LogicalResult
+elementwiseMatchAndRewriteHelper(Operation *operation,
+                                 PatternRewriter &rewriter) {
+
+  // Collect op properties
+  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
+  assert(operation->getNumOperands() >= 1 &&
+         "elementwise op expects at least 1 operand");
+  if (!operandsAndResultsRanked(operation))
+    return rewriter.notifyMatchFailure(operation,
+                                       "Unranked tensors not supported");
+
+  // Lower operation
+  IndexPool indexPool;
+  auto loc = operation->getLoc();
+  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
+  auto [targetShape, masterOperands] =
+      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
+  auto broadcastOperands = broadcastDynamicDimensions(
+      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
+                                    targetShape);
+}
+
 // Returns the constant initial value for a given reduction operation. The
 // attribute type varies depending on the element type required.
 static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
@@ -741,7 +968,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
-  llvm::SmallVector<int64_t> reduceShape;
+  SmallVector<int64_t> reduceShape;
   SmallVector<Value> dynDims;
   for (unsigned i = 0; i < inputTy.getRank(); i++) {
     if (axis != i) {
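
Note on the target-size rule used by computeTargetSize in the TosaToLinalg.cpp changes: the sketch below restates it in plain C++, without any MLIR types, as a rough illustration only. The names kDynamicDim, staticTargetSize and runtimeTargetSize are hypothetical and do not appear in the patch; the patch itself emits 'tensor.dim' and 'arith.maxui' ops rather than computing values directly.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical marker for a dynamic dimension (an assumption for this sketch,
// not the sentinel MLIR uses).
constexpr int64_t kDynamicDim = -1;

// Given the sizes of one dimension across all (equal-rank) operands, return
// the target size if it is known statically, or kDynamicDim if it can only be
// determined at run time.
int64_t staticTargetSize(const std::vector<int64_t> &sizes) {
  // Any static size greater than 1 fixes the target size.
  for (int64_t size : sizes)
    if (size != kDynamicDim && size > 1)
      return size;

  // With no dynamic sizes left, every operand has size 1 here.
  bool anyDynamic = std::any_of(sizes.begin(), sizes.end(), [](int64_t size) {
    return size == kDynamicDim;
  });
  return anyDynamic ? kDynamicDim : 1;
}

// Run-time counterpart: the target size is the maximum of the run-time sizes
// of the operands whose dimension is dynamic.
int64_t runtimeTargetSize(const std::vector<int64_t> &runtimeSizes) {
  assert(!runtimeSizes.empty() && "expected at least one dynamic operand");
  return *std::max_element(runtimeSizes.begin(), runtimeSizes.end());
}
```

For example, sizes {1, 5, dynamic} give a static target of 5, while {1, dynamic, dynamic} fall back to the run-time maximum, which corresponds to the 'arith.maxui' chain checked in test_add_1d_all_dynamic.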

diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 2ae67dcfd0bef5..36c6d0d75a083b 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -195,18 +195,22 @@ static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
 
 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                             ArrayRef<int64_t> existing) {
-  auto isCompatible = [](int64_t dim1, int64_t dim2) {
-    // If the inferred and existing dim is the same, or one of them is unknown
-    // then it is compatible, else if the inferred dim is 1 then it is also
-    // compatible. But if the existing dim is 1 and the inferred is greater than
-    // 1 then flag.
-    return dim1 == dim2 || ShapedType::isDynamic(dim1) ||
-           ShapedType::isDynamic(dim2) || dim1 == 1;
+  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
+    // The following criterion is used to determine the validity of an existing
+    // dimension:
+    //
+    // inferredDim  existingDim  Behavior
+    // -----------  -----------  --------
+    // dynamic      dynamic      OK
+    // dynamic      static       Error
+    // static       dynamic      OK
+    // static       static       OK if equal
+    return ShapedType::isDynamic(existingDim) || inferredDim == existingDim;
   };
   if (inferred.size() != existing.size())
     return false;
-  for (auto p : llvm::zip(inferred, existing))
-    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
+  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+    if (!isCompatible(inferredDim, existingDim))
       return false;
   return true;
 }
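
Note on the new criterion in isCompatibleInferredReturnShape: the table in the comment above can be exercised with a small standalone sketch. This is an illustration under assumptions, not the code from Traits.cpp: it uses std::vector instead of ArrayRef, and kDynamicDim is a hypothetical stand-in for the dynamic-size sentinel.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical marker for a dynamic dimension (an assumption for this sketch).
constexpr int64_t kDynamicDim = -1;

// An existing (actual) result dimension is accepted if it is dynamic, or if it
// equals the inferred dimension. A dynamic inferred dimension paired with a
// static existing dimension is therefore rejected.
bool isCompatibleDim(int64_t inferredDim, int64_t existingDim) {
  return existingDim == kDynamicDim || inferredDim == existingDim;
}

// Shapes are compatible when ranks match and every dimension pair is
// compatible.
bool isCompatibleShape(const std::vector<int64_t> &inferred,
                       const std::vector<int64_t> &existing) {
  if (inferred.size() != existing.size())
    return false;
  for (std::size_t i = 0; i < inferred.size(); ++i)
    if (!isCompatibleDim(inferred[i], existing[i]))
      return false;
  return true;
}
```

Under this rule an inferred size of 1 against an existing size of 5 is rejected, whereas the removed lambda accepted it through its 'dim1 == 1' clause. This matches the log message's statement that broadcast semantics no longer apply between inferred and actual result shapes.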

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 54c8e574125b7c..1055d6ff6fb784 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -2,162 +2,406 @@
 
 // CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
 
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<f32>) -> tensor<f32> {
+// CHECK-LABEL: @test_abs_scalar
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_scalar(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<f32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins(%[[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
-  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
-  // CHECK:   [[ELEMENT:%.+]] = math.absf %[[ARG1]]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = []} ins([[ARG0]] : tensor<f32>) outs([[INIT]] : tensor<f32>) {
+  // CHECK:   ^bb0([[ARG1:%.*]]: f32, [[ARG2:%.*]]: f32):
+  // CHECK:   [[ELEMENT:%.*]] = math.absf [[ARG1]] : f32
   // CHECK:   linalg.yield [[ELEMENT]] : f32
   // CHECK: } -> tensor<f32>
+	%0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
 
-  %0 = "tosa.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
+  // CHECK: return [[GENERIC]] : tensor<f32>
+	return %0 : tensor<f32>
+}
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<f32>
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_abs_1d_cast_result
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor<?xf32> {
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<5xf32>
+  // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<5xf32>) outs([[EMPTY]] : tensor<5xf32>) {
+  // CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
+  // CHECK:   [[ABS:%.+]] = math.absf [[IN0]] : f32
+  // CHECK:   linalg.yield [[ABS]] : f32
+  // CHECK: } -> tensor<5xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<5xf32>) -> tensor<?xf32>
+
+  // CHECK: [[CAST_RESULT:%.+]] = tensor.cast [[RESULT]] : tensor<5xf32> to tensor<?xf32>
+  // CHECK: return [[CAST_RESULT]] : tensor<?xf32>
+  return %0 : tensor<?xf32>
 }
 
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_abs_1d_dynamic
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @test_abs_1d_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+
+  // CHECK: [[ZERO:%.+]] = arith.constant 0 : index
+  // CHECK: [[DIM:%.+]] = tensor.dim [[ARG0]], [[ZERO]] : tensor<?xf32>
+  // CHECK: [[EMPTY:%.+]] = tensor.empty([[DIM]]) : tensor<?xf32>
+  // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xf32>) outs([[EMPTY]] : tensor<?xf32>) {
+  // CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
+  // CHECK:   [[ABSF:%.+]] = math.absf [[IN0]] : f32
+  // CHECK:   linalg.yield [[ABSF]] : f32
+  // CHECK: } -> tensor<?xf32>
+  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
 
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
-  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
-  // CHECK:   [[ELEMENT:%.+]] = math.absf %[[ARG1]]
-  // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  // CHECK: return [[RESULT]] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<2xf32>
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: @test_add_0d
+// CHECK-SAME: [[ARG0:%[0-9a-zA-Z_]*]]:
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]:
+func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+
+  // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<f32>
+  // CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins([[ARG0]], [[ARG1]] : tensor<f32>, tensor<f32>) outs([[EMPTY]] : tensor<f32>) {
+  // CHECK: ^bb0([[IN0:%.+]]: f32, [[IN1:%.+]]: f32, [[OUT0:%.+]]: f32):
+  // CHECK:   [[ADDF:%.+]] = arith.addf [[IN0]], [[IN1]] : f32
+  // CHECK:   linalg.yield [[ADDF]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  
+  // CHECK: return [[RESULT]] : tensor<f32>
+  return %0 : tensor<f32>
 }
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_all_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_all_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+
+  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?xf32>
+  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
+  // CHECK: %[[ARG0_MAX_DIM:.*]] = arith.maxui %[[ARG0_DIM0]], %[[ARG1_DIM0]] : index
+  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?xf32>
+  // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
+  // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?xf32>) {
+  // CHECK:   %[[VAL_2:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor<?xf32>
+  // CHECK:   %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<?xf32>) outs(%[[VAL_2]] : tensor<?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_4]] : f32
+  // CHECK:   } -> tensor<?xf32>
+  // CHECK:   scf.yield %[[VAL_3]] : tensor<?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG0]] : tensor<?xf32>
+  // CHECK: }
+  // CHECK: %[[VAL_6:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
+  // CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_6]], %[[CONST1]] : index
+  // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_7]] -> (tensor<?xf32>) {
+  // CHECK:   %[[VAL_8:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor<?xf32>
+  // CHECK:   %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor<?xf32>) outs(%[[VAL_8]] : tensor<?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_10]] : f32
+  // CHECK:   } -> tensor<?xf32>
+  // CHECK:   scf.yield %[[VAL_9]] : tensor<?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG1]] : tensor<?xf32>
+  // CHECK: }
+  // CHECK: %[[VAL_12:.*]] = tensor.empty(%[[ARG0_MAX_DIM]]) : tensor<?xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0_DIM0_BROADCAST]], %[[ARG0_DIM1_BROADCAST]] : tensor<?xf32>, tensor<?xf32>) outs(%[[VAL_12]] : tensor<?xf32>) {
+  // CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32):
+  // CHECK:   %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32
+  // CHECK:   linalg.yield %[[VAL_16]] : f32
+  // CHECK: } -> tensor<?xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
 
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xf32>) outs([[INIT]] : tensor<2x3xf32>) {
-  // CHECK: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
-  // CHECK:   [[ELEMENT:%.+]] = math.absf %[[ARG1]]
-  // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2x3xf32>
-  %0 = "tosa.abs"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+// -----
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<2x3xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_broadcast_dynamic_to_static
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_broadcast_dynamic_to_static(%arg0: tensor<5xf32>, %arg1: tensor<?xf32>) -> tensor<5xf32> {
+
+  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
+  // CHECK: %[[VAL_0:.*]] = arith.cmpi eq, %[[ARG1_DIM0]], %[[CONST1]] : index
+  // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_0]] -> (tensor<?xf32>) {
+  // CHECK:   %[[VAL_1:.*]] = tensor.empty() : tensor<5xf32>
+  // CHECK:   %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor<?xf32>) outs(%[[VAL_1]] : tensor<5xf32>) {
+  // CHECK:   ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_3]] : f32
+  // CHECK:   } -> tensor<5xf32>
+  // CHECK:   %[[VAL_5:.*]] = tensor.cast %[[VAL_2]] : tensor<5xf32> to tensor<?xf32>
+  // CHECK:   scf.yield %[[VAL_5]] : tensor<?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG1]] : tensor<?xf32>
+  // CHECK: }
+  // CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1_DIM0_BROADCAST]] : tensor<5xf32>, tensor<?xf32>) outs(%[[VAL_6]] : tensor<5xf32>) {
+  // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32):
+  // CHECK:   %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
+  // CHECK:   linalg.yield %[[VAL_10]] : f32
+  // CHECK: } -> tensor<5xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<5xf32>, tensor<?xf32>) -> tensor<5xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<5xf32>
+  return %0 : tensor<5xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @test_abs
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: linalg.generic
-  // CHECK: math.absf
-  %0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_broadcast_static_to_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_broadcast_static_to_dynamic(%arg0: tensor<1xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+
+  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?xf32>
+  // CHECK: %[[VAL_0:.*]] = tensor.empty(%[[ARG1_DIM0]]) : tensor<?xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<?xf32>) outs(%[[VAL_0]] : tensor<?xf32>) {
+  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK:   linalg.yield %[[VAL_4]] : f32
+  // CHECK: } -> tensor<?xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @test_abs_dyn
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
-  // CHECK: %[[C1:.+]] = arith.constant 1
-  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: linalg.generic
-  // CHECK: math.absf
-  %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
-  return %0 : tensor<2x?xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_broadcast_static_to_static
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+
+  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<3xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<3xf32>) {
+  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK:   linalg.yield %[[VAL_4]] : f32
+  // CHECK: } -> tensor<3xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>
+  
+  // CHECK: return %[[RESULT]] : tensor<3xf32>
+  return %0 : tensor<3xf32>
 }
 
 // -----
 
-#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
-
-// CHECK-LABEL: @test_encoding_passthrough
-func.func @test_encoding_passthrough(%arg0: tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector> {
-  // CHECK: linalg.generic
-  // CHECK: sparse_tensor
-  %0 = "tosa.abs"(%arg0) : (tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector>
-  return %0 : tensor<2xi8, #SparseVector>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @test_add_1d_matching_static
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_1d_matching_static(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+
+  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<3xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<3xf32>, tensor<3xf32>) outs(%[[VAL_0]] : tensor<3xf32>) {
+  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK:   linalg.yield %[[VAL_4]] : f32
+  // CHECK: } -> tensor<3xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<3xf32>
+  return %0 : tensor<3xf32>
 }
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
-
-// CHECK-LABEL: @test_broadcast
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<2xf32>
-func.func @test_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG0]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins([[RESHAPE]], %[[ARG1]] : tensor<f32>, tensor<2xf32>) outs([[INIT]] : tensor<2xf32>) {
-  // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
-  // CHECK:   [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
-  // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2xf32>) -> tensor<2xf32>
-  return %0 : tensor<2xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-LABEL: @test_add_2d_all_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  // CHECK: %[[CONST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK: %[[ARG1_DIM0:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK: %[[MAX_DIM0:.*]] = arith.maxui %[[ARG0_DIM0]], %[[ARG1_DIM0]] : index
+  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK: %[[ARG1_DIM1:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK: %[[MAX_DIM1:.*]] = arith.maxui %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
+
+  // CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
+  // CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
+  // CHECK:   %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK:   %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
+  // CHECK:   %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_5]] : f32
+  // CHECK:   } -> tensor<?x?xf32>
+  // CHECK:   scf.yield %[[VAL_4]] : tensor<?x?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG0]] : tensor<?x?xf32>
+  // CHECK: }
+
+  // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
+  // CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
+  // CHECK:   %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK:   %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
+  // CHECK:   %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_12]] : f32
+  // CHECK:   } -> tensor<?x?xf32>
+  // CHECK:   scf.yield %[[VAL_11]] : tensor<?x?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>
+  // CHECK: }
+
+  // CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
+  // CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
+  // CHECK:   %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK:   %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
+  // CHECK:   %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_19]] : f32
+  // CHECK:   } -> tensor<?x?xf32>
+  // CHECK:   scf.yield %[[VAL_18]] : tensor<?x?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG1]] : tensor<?x?xf32>
+  // CHECK: }
+
+  // CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
+  // CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
+  // CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
+  // CHECK:   %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+  // CHECK:   %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
+  // CHECK:   %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_24]] : tensor<?x?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_26]] : f32
+  // CHECK:   } -> tensor<?x?xf32>
+  // CHECK:   scf.yield %[[VAL_25]] : tensor<?x?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>
+  // CHECK: }
+
+  // CHECK: %[[VAL_28:.*]] = tensor.empty(%[[MAX_DIM0]], %[[MAX_DIM1]]) : tensor<?x?xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM1_BROADCAST]], %[[ARG1_DIM1_BROADCAST]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[VAL_28]] : tensor<?x?xf32>) {
+  // CHECK: ^bb0(%[[VAL_29:.*]]: f32, %[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32):
+  // CHECK:   %[[VAL_32:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32
+  // CHECK:   linalg.yield %[[VAL_32]] : f32
+  // CHECK: } -> tensor<?x?xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
-
-// CHECK-LABEL: @test_broadcast_swapped_args
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<2xf32
-// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xf32>
-func.func @test_broadcast_swapped_args(%arg0: tensor<2xf32>, %arg1: tensor<1xf32>) -> tensor<2xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xf32>
-  // CHECK: [[RESHAPE:%.+]] = "tosa.reshape"(%[[ARG1]])
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[RESHAPE]] : tensor<2xf32>, tensor<f32>) outs([[INIT]] : tensor<2xf32>) {
-  // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
-  // CHECK:   [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
-  // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
-  return %0 : tensor<2xf32>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (0, d1, d2)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: @test_add_2d_different_ranks
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+
+  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
+  // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
+  // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
+  // CHECK:   %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK:   linalg.yield %[[VAL_4]] : f32
+  // CHECK: } -> tensor<2x3x4xf32>
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+  
+  // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
 }
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_select_2d_one_dynamic
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]:
+func.func @test_select_2d_one_dynamic(%arg0: tensor<2x?xi1>, %arg1: tensor<2x?xf32>, %arg2: tensor<2x?xf32>) -> tensor<2x?xf32> {
+
+  // CHECK: %[[CONST1:.*]] = arith.constant 1 : index
+  // CHECK: %[[ARG0_DIM1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<2x?xi1>
+  // CHECK: %[[ARG1_DIM1:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<2x?xf32>
+  // CHECK: %[[VAL_0:.*]] = arith.maxui %[[ARG0_DIM1]], %[[ARG1_DIM1]] : index
+  // CHECK: %[[ARG2_DIM1:.*]] = tensor.dim %[[ARG2]], %[[CONST1]] : tensor<2x?xf32>
+  // CHECK: %[[MAX_DIM1:.*]] = arith.maxui %[[VAL_0]], %[[ARG2_DIM1]] : index
+
+  // CHECK: %[[VAL_1:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<2x?xi1>
+  // CHECK: %[[VAL_2:.*]] = arith.cmpi eq, %[[VAL_1]], %[[CONST1]] : index
+  // CHECK: %[[ARG0_BROADCAST:.*]] = scf.if %[[VAL_2]] -> (tensor<2x?xi1>) {
+  // CHECK:   %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xi1>
+  // CHECK:   %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x?xi1>) outs(%[[VAL_3]] : tensor<2x?xi1>) {
+  // CHECK:   ^bb0(%[[VAL_5:.*]]: i1, %[[VAL_6:.*]]: i1):
+  // CHECK:     linalg.yield %[[VAL_5]] : i1
+  // CHECK:   } -> tensor<2x?xi1>
+  // CHECK:   scf.yield %[[VAL_4]] : tensor<2x?xi1>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG0]] : tensor<2x?xi1>
+  // CHECK: }
 
-// CHECK-LABEL: @test_multibroadcast
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
-// CHECK-SAME:  %[[ARG1:[0-9a-zA-Z_]*]]
-func.func @test_multibroadcast(%arg0: tensor<1x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x3xf32> {
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2x3xf32>
-  // CHECK: [[RESHAPE1:%.+]] = "tosa.reshape"(%[[ARG0]]) <{new_shape = array<i64: 3>}
-  // CHECK: [[RESHAPE2:%.+]] = "tosa.reshape"(%[[ARG1]]) <{new_shape = array<i64: 2>}
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RESHAPE1]], [[RESHAPE2]] : tensor<3xf32>, tensor<2xf32>) outs([[INIT]] : tensor<2x3xf32>) {
-  // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32):
-  // CHECK:   [[ELEMENT:%.+]] = arith.addf %[[ARG2]], %[[ARG3]] : f32
-  // CHECK:   linalg.yield [[ELEMENT]] : f32
-  // CHECK: } -> tensor<2x3xf32>
-  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
-  return %0 : tensor<2x3xf32>
+  // CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<2x?xf32>
+  // CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
+  // CHECK: %[[ARG1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<2x?xf32>) {
+  // CHECK:   %[[VAL_9:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
+  // CHECK:   %[[VAL_10:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x?xf32>) outs(%[[VAL_9]] : tensor<2x?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_11]] : f32
+  // CHECK:   } -> tensor<2x?xf32>
+  // CHECK:   scf.yield %[[VAL_10]] : tensor<2x?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG1]] : tensor<2x?xf32>
+  // CHECK: }
+
+  // CHECK: %[[VAL_13:.*]] = tensor.dim %[[ARG2]], %[[CONST1]] : tensor<2x?xf32>
+  // CHECK: %[[VAL_14:.*]] = arith.cmpi eq, %[[VAL_13]], %[[CONST1]] : index
+  // CHECK: %[[ARG2_BROADCAST:.*]] = scf.if %[[VAL_14]] -> (tensor<2x?xf32>) {
+  // CHECK:   %[[VAL_15:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
+  // CHECK:   %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG2]] : tensor<2x?xf32>) outs(%[[VAL_15]] : tensor<2x?xf32>) {
+  // CHECK:   ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+  // CHECK:     linalg.yield %[[VAL_17]] : f32
+  // CHECK:   } -> tensor<2x?xf32>
+  // CHECK:   scf.yield %[[VAL_16]] : tensor<2x?xf32>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[ARG2]] : tensor<2x?xf32>
+  // CHECK: }
+
+  // CHECK: %[[VAL_19:.*]] = tensor.empty(%[[MAX_DIM1]]) : tensor<2x?xf32>
+  // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_BROADCAST]], %[[ARG1_BROADCAST]], %[[ARG2_BROADCAST]] : tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) outs(%[[VAL_19]] : tensor<2x?xf32>) {
+  // CHECK: ^bb0(%[[VAL_20:.*]]: i1, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32):
+  // CHECK:   %[[VAL_24:.*]] = arith.select %[[VAL_20]], %[[VAL_21]], %[[VAL_22]] : f32
+  // CHECK:   linalg.yield %[[VAL_24]] : f32
+  // CHECK: } -> tensor<2x?xf32>
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
+
+  // CHECK: return %[[RESULT]] : tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
 }
 
 // -----
@@ -1412,20 +1656,6 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 
 // -----
 
-// Regression test for using the wrong rank.
-
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> ()>
-// CHECK-LABEL: @select_fp32
-func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %arg2: tensor<f32>) -> tensor<1x12x5x5xf32> {
-  // CHECK: linalg.generic
-  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<f32>) -> tensor<1x12x5x5xf32>
-  return %0 : tensor<1x12x5x5xf32>
-}
-
-// -----
-
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 

diff  --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
index 83e3ea4c64051b..217ce84c83bcba 100644
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,9 +111,18 @@ func.func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> t
 
 // -----
 
-func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x6x1xi32>, %arg1: tensor<*xi32>) -> tensor<?x6x6xi32> {
-  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x6x1xi32>, tensor<*xi32>) -> tensor<?x6x6xi32>
-  return %0 : tensor<?x6x6xi32>
+// Error for inferred dynamic dimension but existing static dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
+  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x6x1xi32>, %arg1: tensor<*xi32>) -> tensor<?x6x?xi32> {
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x6x1xi32>, tensor<*xi32>) -> tensor<?x6x?xi32>
+  return %0 : tensor<?x6x?xi32>
 }
 
 // -----
@@ -145,10 +154,19 @@ func.func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> ten
 
 // -----
 
-func.func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
-^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
-  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
-  return %0 : tensor<8x7x6x5xi32>
+// Correct use of broadcast semantics for input dimensions
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
+  return %0 : tensor<?x7x6x5xi32>
+}
+
+// -----
+
+// Incorrect attempt to use broadcast semantics for result
+func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<5xi32> {
+  // expected-error @+1 {{op result type '5' not broadcast compatible with broadcasted operands's shapes '1'}}
+  %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<5xi32>
+  return %0 : tensor<5xi32>
 }
 
 // -----
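
Note on the legality cases exercised in traits.mlir above: the commit log says a dynamic dimension in the inferred result is not compatible with a static dimension in the actual result, and that broadcast semantics apply only between input operands. The following is a minimal Python sketch of those two checks for ranked shapes. It is an illustration under stated assumptions, not the code in Traits.cpp: `None` stands for a dynamic dimension (`?`), unranked tensors are not modeled, and the function names are hypothetical.

```python
from typing import List, Optional

Dim = Optional[int]  # None stands for a dynamic dimension ('?'); an assumption of this sketch


def infer_broadcast_shape(lhs: List[Dim], rhs: List[Dim]) -> Optional[List[Dim]]:
    """Return the inferred result shape, or None if the operands are incompatible."""
    rank = max(len(lhs), len(rhs))
    # Right-align the shapes by padding the lower-rank operand with leading 1s.
    lhs = [1] * (rank - len(lhs)) + list(lhs)
    rhs = [1] * (rank - len(rhs)) + list(rhs)
    result: List[Dim] = []
    for a, b in zip(lhs, rhs):
        if a is None or b is None:
            known = b if a is None else a
            # A static size greater than 1 fixes the result; otherwise it stays dynamic.
            result.append(known if known is not None and known > 1 else None)
        elif a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            return None  # two different static sizes, neither of them 1
    return result


def result_is_compatible(inferred: List[Dim], actual: List[Dim]) -> bool:
    """Check the actual result shape against the inferred one, without broadcasting."""
    if len(inferred) != len(actual):
        return False
    # A dynamic actual dimension accepts anything; a static one must match exactly,
    # so an inferred '?' never matches a static actual dimension.
    return all(act is None or act == inf for inf, act in zip(inferred, actual))
```

Under these assumptions, operands `?` and `?` infer `?`, which rejects the result `tensor<2xi32>`; operands `1` and `1` infer `1`, which rejects `tensor<5xi32>`; and `?x1x6x1` with `7x1x5` infers `?x7x6x5`, which matches the accepted test.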


        

