[Mlir-commits] [mlir] [mlir][TOSA] Add folder for multiply like reduce_prod operation (PR #128067)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 20 13:01:35 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This commit uses mulBinaryFolder to fold reduce_prod operations whose input is a constant 1-D tensor of exactly two elements (i.e. a reduction that is a single multiplication).
---
Full diff: https://github.com/llvm/llvm-project/pull/128067.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+25-1)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..45bcacff19caa 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -963,10 +963,34 @@ REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
-REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
+/// Folds tosa.reduce_prod.
+///
+/// Two folds are attempted:
+///  * Identity: if the result type equals the input type and the input is
+///    rank 0 or the reduced axis has size 1, the input value is returned.
+///  * Multiply-like: if the input is a constant rank-1 tensor with exactly
+///    two elements, the reduction is a single multiplication; the two
+///    elements are wrapped as single-value attributes of the result type and
+///    handed to mulBinaryFolder with a shift of 0.
+/// Returns an empty OpFoldResult when neither fold applies.
+OpFoldResult ReduceProdOp::fold(FoldAdaptor adaptor) {
+  ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());
+  if (!inputTy.hasRank())
+    return {};
+  if (inputTy == getType() &&
+      (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1))
+    return getInput();
+
+  // Fold multiply-like reduce_prod operators using mulBinaryFolder.
+  if (inputTy.getRank() == 1 && inputTy.getDimSize(0) == 2) {
+    const auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+    // DenseElementsAttr::get below requires a statically shaped type; a
+    // ranked result with a dynamic dimension (e.g. tensor<?xi32>) must not
+    // be folded here.
+    if (!resultTy || !resultTy.hasStaticShape())
+      return {};
+
+    const auto elements =
+        llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
+    if (!elements)
+      return {};
+
+    // Build one attribute of the result type per input element so the
+    // product can be computed by the shared multiply folder.
+    auto values = elements.getValues<Attribute>();
+    const auto lhsAttr = DenseElementsAttr::get(resultTy, {values[0]});
+    const auto rhsAttr = DenseElementsAttr::get(resultTy, {values[1]});
+    return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, /*shift=*/0);
+  }
+
+  return {};
+}
+
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 0e177a076ee7a..316f22f88fc69 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1012,3 +1012,14 @@ func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
return %2 : tensor<3x600x1200xi32>
}
+
+// -----
+
+// A reduce_prod over a constant two-element 1-D tensor is a single multiply;
+// canonicalization should replace it with a constant holding 1 * 77 = 77.
+// CHECK-LABEL: @fold_reduce_prod_is_mul
+func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32>
+ // CHECK: return %[[VAL_0]] : tensor<1xi32>
+ %0 = "tosa.const"() <{value = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32>
+ %1 = "tosa.reduce_prod"(%0) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
+ return %1 : tensor<1xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/128067
More information about the Mlir-commits
mailing list