[Mlir-commits] [mlir] [mlir][TOSA] Add folder for multiply like reduce_prod operation (PR #128067)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 20 13:01:35 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This commit folds reduce_prod operations whose input is a constant 1-D tensor with two elements, reusing mulBinaryFolder to compute their product.


---
Full diff: https://github.com/llvm/llvm-project/pull/128067.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+25-1) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..45bcacff19caa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -963,10 +963,34 @@ REDUCE_FOLDER(ReduceAllOp)
 REDUCE_FOLDER(ReduceAnyOp)
 REDUCE_FOLDER(ReduceMaxOp)
 REDUCE_FOLDER(ReduceMinOp)
-REDUCE_FOLDER(ReduceProdOp)
 REDUCE_FOLDER(ReduceSumOp)
 #undef REDUCE_FOLDER
 
+// Folds reduce_prod: a reduction over a size-1 axis (or rank-0 input) with an
+// unchanged type is the identity; a constant 2-element 1-D input is one mul.
+OpFoldResult ReduceProdOp::fold(FoldAdaptor adaptor) {
+  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());
+  if (!inputTy.hasRank())
+    return {};
+  if (inputTy == getType() &&
+      (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1))
+    return getInput();
+
+  // Multiply-like case: turn the two constant elements into splats of the
+  // result type and reuse mulBinaryFolder. DenseElementsAttr::get requires a
+  // static shape, so dynamically shaped results are left unfolded.
+  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  auto elements =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
+  if (inputTy.getRank() != 1 || inputTy.getDimSize(0) != 2 || !resultTy ||
+      !resultTy.hasStaticShape() || !elements)
+    return {};
+  auto values = elements.getValues<Attribute>();
+  auto lhsAttr = DenseElementsAttr::get(resultTy, {values[0]});
+  auto rhsAttr = DenseElementsAttr::get(resultTy, {values[1]});
+  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, /*shift=*/0);
+}
+
 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 0e177a076ee7a..316f22f88fc69 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1012,3 +1012,14 @@ func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
   %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
   return %2 : tensor<3x600x1200xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_reduce_prod_is_mul
+func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> {
+  // CHECK-DAG: %[[PROD:.*]] = "tosa.const"() <{value = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK: return %[[PROD]] : tensor<1xi32>
+  %operands = "tosa.const"() <{value = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32>  // 1 * 77 == 77
+  %prod = "tosa.reduce_prod"(%operands) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
+  return %prod : tensor<1xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/128067


More information about the Mlir-commits mailing list