[Mlir-commits] [mlir] [mlir][tosa] Improve broadcasting behaviour in elementwise folders (PR #181114)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 12 02:19:26 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit aims to improve elementwise folder behaviour when broadcasting is involved. In particular, it ensures the folders remain correct when input operands have dynamic shapes and it cannot be determined statically whether broadcasting is involved.

For example, the tosa.add folder could previously cause shape information to be lost:
```
func.func @<!-- -->test(%arg0: tensor<?x4xi32>) -> tensor<?x4xi32> {
  %one = "tosa.const"() {values = dense<0> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
  %div = tosa.add %one, %arg0 : (tensor<2x1xi32>, tensor<?x4xi32>) -> tensor<?x4xi32>
  return %div : tensor<?x4xi32>
}

$ mlir-opt --canonicalize test.mlir

func.func @<!-- -->test(%arg0: tensor<?x4xi32>) -> tensor<?x4xi32> {
  return %arg0 : tensor<?x4xi32>
}
```
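In this example, the constraint on the first dimension imposed by the lhs operand has been lost. The lhs has shape 2x1, so the first dimension of the result must be 2. If %arg0 has a first dimension of 1 at runtime, the folded result would therefore have the wrong shape. This commit ensures that such cases are not folded.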

Similar logic has also been applied to tosa.sub, tosa.intdiv and tosa.mul.

tosa.select previously handled the above example correctly, but only by refusing to fold whenever dynamic shapes were involved (a conservative approach). This commit improves the folder so that it behaves in the same way as the other elementwise folders.

Note: some tests have been moved from canonicalize.mlir to constant_folding.mlir to help keep folding tests together.

---

Patch is 33.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/181114.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+61-39) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (-166) 
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+384) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 42033ce8a3b02..230c7d8bdb15f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Traits.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -1213,9 +1214,11 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
+  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
+      lhsTy.getShape(), rhsTy.getShape());
+  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
-  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
+  if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
     return getInput2();
 
   if (!lhsAttr || !rhsAttr)
@@ -1247,7 +1250,7 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
-  if (lhsTy != rhsTy)
+  if (lhsTy.getElementType() != rhsTy.getElementType())
     return {};
 
   // IntDivOp inputs must be integer type, no need to check for quantized type
@@ -1257,13 +1260,16 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   if (lhsAttr && lhsAttr.isSplat()) {
-    if (llvm::isa<IntegerType>(resultETy) &&
+    if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
         lhsAttr.getSplatValue<APInt>().isZero())
-      return lhsAttr;
+      return lhsAttr.resizeSplat(resultTy);
   }
 
   if (rhsAttr && rhsAttr.isSplat()) {
-    if (llvm::isa<IntegerType>(resultETy) &&
+    const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
+        lhsTy.getShape(), rhsTy.getShape());
+    if (isBroadcastable && lhsTy == resultTy &&
+        llvm::isa<IntegerType>(resultETy) &&
         rhsAttr.getSplatValue<APInt>().isOne())
       return getInput1();
   }
@@ -1369,19 +1375,22 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
     }
   }
 
-  if (rhsTy == resultTy) {
-    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
-      // constant values can only be resized if resulting type is static
-      return lhsAttr.resizeSplat(resultTy);
-    if (isSplatOne(resultETy, lhsAttr, shift))
-      return rhs;
-  }
-  if (lhsTy == resultTy) {
-    if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
-      return rhsAttr.resizeSplat(resultTy);
-    if (isSplatOne(resultETy, rhsAttr, shift))
-      return lhs;
-  }
+  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
+      resultTy.hasStaticShape())
+    // constant values can only be resized if resulting type is static
+    return lhsAttr.resizeSplat(resultTy);
+  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
+      resultTy.hasStaticShape())
+    return rhsAttr.resizeSplat(resultTy);
+
+  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
+      lhsTy.getShape(), rhsTy.getShape());
+  if (isBroadcastable && rhsTy == resultTy &&
+      isSplatOne(resultETy, lhsAttr, shift))
+    return rhs;
+  if (isBroadcastable && lhsTy == resultTy &&
+      isSplatOne(resultETy, rhsAttr, shift))
+    return lhs;
 
   return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
 }
@@ -1404,7 +1413,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
+  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
+      lhsTy.getShape(), rhsTy.getShape());
+  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
 
   if (!lhsAttr || !rhsAttr)
@@ -1717,36 +1728,47 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-static bool
-mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) {
-  const auto isDynamic = [](Type ty) {
-    const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
-    return !shapedTy || !shapedTy.hasStaticShape();
-  };
-
-  return llvm::any_of(operandTypes, isDynamic) ||
-         failed(verifyCompatibleShapes(operandTypes));
-}
-
 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
-  // Select allows operand shapes to be broadcast to the output shape. For
-  // now, don't support folding when we cannot prove no broadcasting is
-  // involved.
-  if (mayRequireBroadcast(getOperandTypes()))
+  const Value pred = getPred();
+  const Value onTrue = getOnTrue();
+  const Value onFalse = getOnFalse();
+
+  const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
+  const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
+  const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
+  if (!predTy || !onTrueTy || !onFalseTy)
     return {};
 
-  if (getOnTrue() == getOnFalse())
+  const Type resultTy = getType();
+
+  const ArrayRef<int64_t> predShape = predTy.getShape();
+  const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
+
+  if (onTrue == onFalse && onTrueTy == resultTy &&
+      OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
     return getOnTrue();
 
   auto predicate =
       llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
   if (!predicate)
     return {};
-
   if (!predicate.isSplat())
     return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
-                                                         : getOnFalse();
+
+  const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
+
+  SmallVector<SmallVector<int64_t>, 3> shapes;
+  shapes.emplace_back(predShape);
+  shapes.emplace_back(onTrueShape);
+  shapes.emplace_back(onFalseTy.getShape());
+  const bool isBroadcastable =
+      OpTrait::util::staticallyKnownBroadcastable(shapes);
+
+  if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
+    return getOnTrue();
+  if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
+    return getOnFalse();
+  return {};
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 81e537babf9ab..e85d349459aaa 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -539,130 +539,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
 
 // -----
 
-// CHECK-LABEL: @mul_one_float
-func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: return %arg0
-  // CHECK-NOT: tosa.mul
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %ones = "tosa.const"() {values = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
-  return %1 : tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @mul_bcast_one_float
-func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
-  // CHECK: return %arg0
-  // CHECK-NOT: tosa.mul
-  %ones = "tosa.const"() {values = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
-  return %1 : tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @mul_one_int
-func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // CHECK: return %arg0
-  // CHECK-NOT: tosa.mul
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %ones = "tosa.const"() {values = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
-  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
-  return %1 : tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @mul_one_int_and_shift
-func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<1> : tensor<2x3xi32>}>
-  // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}>
-  // CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>)
-  %ones = "tosa.const"() {values = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
-  %shift = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
-  return %1 : tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @mul_zero_broadcast
-func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
-  // CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<2x3xf32>}
-  // CHECK-NOT: tosa.mul
-  %zeros = "tosa.const"() {values = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32>
-
-  // CHECK-NOT: tosa.mul
-  // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32>
-  return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @mul_zero_dynamic_nofold
-// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
-// CHECK:           %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
-// CHECK:           %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-// CHECK:           %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
-// CHECK:           return %[[MUL]]
-func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
-  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
-  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
-  return %2 : tensor<?x17xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @mul_one_dynamic_fold
-// CHECK-SAME:                    %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
-// CHECK:           return %[[ARG0]]
-func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
-  %0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
-  %1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
-  return %2 : tensor<?x17xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @select_same_value
-func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  // CHECK: return %arg1
-  // CHECK-NOT: tosa.select
-  return %0 : tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @select_true_value
-func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  %c1 = "tosa.const"() {values = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
-  %0 = tosa.select %c1, %arg0, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  // CHECK: return %arg0
-  // CHECK-NOT: tosa.select
-  return %0 : tensor<2x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @select_false_value
-func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
-  %c0 = "tosa.const"() {values = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
-  %0 = tosa.select %c0, %arg0, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  // CHECK: return %arg1
-  // CHECK-NOT: tosa.select
-  return %0 : tensor<2x3xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @select_not_pred
 func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = tosa.logical_not %arg0 : (tensor<2x3xi1>) -> tensor<2x3xi1>
@@ -673,48 +549,6 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
 
 // -----
 
-// CHECK-LABEL: @select_broadcast_same_value_no_fold
-func.func @select_broadcast_same_value_no_fold(%arg0: tensor<2x2xi1>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> {
-  // CHECK: tosa.select %arg0, %arg1, %arg1
-  %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x2xf32>
-  return %0 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @select_broadcast_true_value_no_fold
-func.func @select_broadcast_true_value_no_fold(%arg0: tensor<1x1xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
-  // CHECK: %[[CONST:.*]] = "tosa.const"
-  %0 = "tosa.const"() {values = dense<1> : tensor<2x2xi1>} : () -> tensor<2x2xi1>
-  // CHECK: tosa.select %[[CONST]], %arg0, %arg1
-  %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @select_broadcast_false_value_no_fold
-func.func @select_broadcast_false_value_no_fold(%arg0: tensor<2x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> {
-  // CHECK: %[[CONST:.*]] = "tosa.const"
-  %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1>
-  // CHECK: tosa.select %[[CONST]], %arg0, %arg1
-  %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<1x1xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @select_broadcast_false_value_dynamic_operand_no_fold
-func.func @select_broadcast_false_value_dynamic_operand_no_fold(%arg0: tensor<2x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // CHECK: %[[CONST:.*]] = "tosa.const"
-  %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1>
-  // CHECK: tosa.select %[[CONST]], %arg0, %arg1
-  %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
-  return %1 : tensor<2x2xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @reduce_all_fold
 func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index c3186279a30ae..5fa5f1d5143f3 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -136,6 +136,56 @@ func.func @fold_add_splat_ui8_overflow() -> tensor<10xui8> {
 
 // -----
 
+// CHECK-LABEL: @no_fold_add_unknown_broadcast_zero_lhs
+func.func @no_fold_add_unknown_broadcast_zero_lhs(%arg0: tensor<1x4xi32>) -> tensor<2x4xi32> {
+  %one = "tosa.const"() {values = dense<0> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  %div = tosa.add %one, %arg0 : (tensor<2x1xi32>, tensor<1x4xi32>) -> tensor<2x4xi32>
+  // CHECK: tosa.add
+  return %div : tensor<2x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_dynamic_add_unknown_broadcast_zero_lhs
+func.func @no_fold_dynamic_add_unknown_broadcast_zero_lhs(%arg0: tensor<?x4xi32>) -> tensor<?x4xi32> {
+  %one = "tosa.const"() {values = dense<0> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  %div = tosa.add %one, %arg0 : (tensor<2x1xi32>, tensor<?x4xi32>) -> tensor<?x4xi32>
+  // CHECK: tosa.add
+  return %div : tensor<?x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_dynamic_add_broadcast_zero_lhs
+func.func @fold_dynamic_add_broadcast_zero_lhs(%arg0: tensor<?x17xi32>) -> tensor<?x17xi32> {
+  %one = "tosa.const"() {values = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
+  %div = tosa.add %one, %arg0 : (tensor<1x1xi32>, tensor<?x17xi32>) -> tensor<?x17xi32>
+  // CHECK: return %arg0
+  return %div : tensor<?x17xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_dynamic_add_unknown_broadcast_zero_rhs
+func.func @no_fold_dynamic_add_unknown_broadcast_zero_rhs(%arg0: tensor<?x4xi32>) -> tensor<?x4xi32> {
+  %one = "tosa.const"() {values = dense<0> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  %div = tosa.add %arg0, %one : (tensor<?x4xi32>, tensor<2x1xi32>) -> tensor<?x4xi32>
+  // CHECK: tosa.add
+  return %div : tensor<?x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_dynamic_add_broadcast_zero_rhs
+func.func @fold_dynamic_add_broadcast_zero_rhs(%arg0: tensor<?x17xi32>) -> tensor<?x17xi32> {
+  %one = "tosa.const"() {values = dense<0> : tensor<1x17xi32>} : () -> tensor<1x17xi32>
+  %div = tosa.add %arg0, %one : (tensor<?x17xi32>, tensor<1x17xi32>) -> tensor<?x17xi32>
+  // CHECK: return %arg0
+  return %div : tensor<?x17xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_div_zero_lhs_i32
 func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -147,6 +197,47 @@ func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 
 // -----
 
+// CHECK-LABEL: @no_fold_dynamic_div_zero_lhs
+func.func @no_fold_dynamic_div_zero_lhs(%arg0: tensor<?x4xi32>) -> tensor<?x4xi32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
+  // CHECK: tosa.intdiv
+  %div = tosa.intdiv %zero, %arg0 : (tensor<1x1xi32>, tensor<?x4xi32>) -> tensor<?x4xi32>
+  return %div : tensor<?x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_zero_lhs_broadcast
+func.func @fold_div_zero_lhs_broadcast(%arg0: tensor<2x4xi32>) -> tensor<2x4xi32> {
+  %zero = "tosa.const"() {values = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> : tensor<2x4xi32>
+  %div = tosa.intdiv %zero, %arg0 : (tensor<1x1xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+  // CHECK: return %[[ZERO]]
+  return %div : tensor<2x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_div_unknown_broadcast_one_rhs
+func.func @no_fold_div_unknown_broadcast_one_rhs(%arg0: tensor<1x4xi32>) -> tensor<2x4xi32> {
+  %one = "tosa.const"() {values = dense<1> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  // CHECK: tosa.intdiv
+  %div = tosa.intdiv %arg0, %one : (tensor<1x4xi32>, tensor<2x1xi32>) -> tensor<2x4xi32>
+  return %div : tensor<2x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_broadcast_one_rhs
+func.func @fold_div_broadcast_one_rhs(%arg0: tensor<?x17xi32>) -> tensor<?x17xi32> {
+  %one = "tosa.const"() {values = dense<1> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
+  %div = tosa.intdiv %arg0, %one : (tensor<?x17xi32>, tensor<1x1xi32>) -> tensor<?x17xi32>
+  // CHECK: return %arg0
+  return %div : tensor<?x17xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_div_one_rhs_i32
 func.func @fold_div_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %one = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
@@ -157,6 +248,16 @@ func.func @fold_div_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 
 // -----
 
+// CHECK-LABEL: @no_fold_dynamic_div_unknown_broadcast_one_rhs
+func.func @no_fold_dynamic_div_unknown_broadcast_one_rhs(%arg0: tensor<?x4xi32>) -> tensor<?x4xi32> {
+  %one = "tosa.const"() {values = dense<1> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  // CHECK: tosa.intdiv
+  %div = tosa.intdiv %arg0, %one : (tensor<?x4xi32>, tensor<2x1xi32>) -> tensor<?x4xi32>
+  return %div : tensor<?x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_div_splat_i32
 func.func @fold_div_splat_i32() -> tensor<i32> {
   %lhs = "tosa.const"() {values = dense<10> : tensor<i32>} : () -> tensor<i32>
@@ -229,6 +330,28 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 
 // -----
 
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/181114


More information about the Mlir-commits mailing list