[Mlir-commits] [mlir] 9dd15f7 - [mlir][tosa] Add aggressiveReduceConstant argument for the constant reduce optimization (#68765)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 11 22:48:58 PDT 2023
Author: Amir Bishara
Date: 2023-10-12T08:48:54+03:00
New Revision: 9dd15f7486a30c4269b183f72c13006eb8c929f4
URL: https://github.com/llvm/llvm-project/commit/9dd15f7486a30c4269b183f72c13006eb8c929f4
DIFF: https://github.com/llvm/llvm-project/commit/9dd15f7486a30c4269b183f72c13006eb8c929f4.diff
LOG: [mlir][tosa] Add aggressiveReduceConstant argument for the constant reduce optimization (#68765)
Add an aggressiveReduceConstant option to the
TosaLayerwiseConstantFoldPass. When enabled, the constant-folding
optimization is always applied to the reduce ops
(i.e. without considering the
number of users of the input of the reduce operation).
Added:
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..bb56c8d203d3c15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1082,6 +1082,14 @@ def TosaToLinalg
}];
let constructor = "tosa::createTosaToLinalg()";
+ let options = [
+ Option<"disableTosaDecompositions", "disable-tosa-decompositions",
+ "bool", /*default=*/"false",
+ "Disable tosa decompositions pass">,
+ Option<"aggressiveReduceConstant", "aggressive-reduce-constant",
+ "bool", /*default=*/"false",
+ "Always perform the reduce constant optimization">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 818d43ffe4e572e..d8d4027500f99c6 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -33,7 +33,7 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
/// pipeline succeeds. The option to disable decompositions is available for
/// benchmarking performance improvements from the canonicalizations.
void addTosaToLinalgPasses(
- OpPassManager &pm, bool disableTosaDecompositions = false,
+ OpPassManager &pm, const TosaToLinalgOptions &options,
// Note: Default to 'none' level unless otherwise specified.
tosa::ValidationOptions const &validationOptions =
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 6b5dd9c970703ee..940aed107e2f916 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -35,9 +35,12 @@ void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaConstantReduction(MLIRContext *ctx,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,
+ bool aggressiveReduceConstant);
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
+std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(
+ const TosaLayerwiseConstantFoldPassOptions &options);
std::unique_ptr<Pass> createTosaInferShapesPass();
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 18402b3e70647a9..ac100a6d75c7c08 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -23,6 +23,13 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::
}];
let constructor = "createTosaLayerwiseConstantFoldPass()";
+
+ let options = [
+ Option<"aggressiveReduceConstant", "aggressive-reduce-constant", "bool",
+ /*default=*/"false",
+ "Always perform the reduce constant optimization"
+ "May add more tosa.const but would reduce runtime calculations">,
+ ];
}
def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index d7e867d92282395..718e34ced8d7e70 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -75,10 +75,10 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
}
void mlir::tosa::addTosaToLinalgPasses(
- OpPassManager &pm, bool disableTosaDecompositions,
+ OpPassManager &pm, const TosaToLinalgOptions &options,
tosa::ValidationOptions const &validationOptions) {
// Optional decompositions are designed to benefit linalg.
- if (!disableTosaDecompositions)
+ if (!options.disableTosaDecompositions)
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
@@ -87,7 +87,8 @@ void mlir::tosa::addTosaToLinalgPasses(
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// TODO: Remove pass that operates on const tensor and enable optionality
- pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
+ pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
+ {options.aggressiveReduceConstant}));
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(
tosa::createTosaValidationPass(validationOptions));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0988759b82201df..d35e911ebe63c42 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -350,6 +350,11 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
+ ReduceConstantOptimization(MLIRContext *context,
+ bool aggressiveReduceConstant)
+ : OpRewritePattern<OperationType>(context),
+ aggressiveReduceConstant(aggressiveReduceConstant) {}
+
using OpRewritePattern<OperationType>::OpRewritePattern;
LogicalResult matchAndRewrite(OperationType op,
@@ -361,7 +366,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
return rewriter.notifyMatchFailure(
op, "reduce input must be const operation");
- if (!inputOp.hasOneUse())
+ if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
return rewriter.notifyMatchFailure(
op, "input operation has more than one user");
@@ -400,18 +405,26 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
return success();
}
+ const bool aggressiveReduceConstant;
};
} // namespace
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
- RewritePatternSet &patterns) {
- patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx);
- patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx);
+ RewritePatternSet &patterns,
+ bool aggressiveReduceConstant) {
+ patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
+ ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
+ ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
+ ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
+ ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
+ ctx, aggressiveReduceConstant);
+ patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
+ ctx, aggressiveReduceConstant);
}
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 90f15faf0108103..e1400f0c907b2cd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -45,6 +45,10 @@ void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
struct TosaLayerwiseConstantFoldPass
: public tosa::impl::TosaLayerwiseConstantFoldPassBase<
TosaLayerwiseConstantFoldPass> {
+ TosaLayerwiseConstantFoldPass(
+ const TosaLayerwiseConstantFoldPassOptions &options)
+ : TosaLayerwiseConstantFoldPassBase(options) {}
+
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
@@ -52,7 +56,8 @@ struct TosaLayerwiseConstantFoldPass
mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
- mlir::tosa::populateTosaConstantReduction(ctx, patterns);
+ mlir::tosa::populateTosaConstantReduction(ctx, patterns,
+ aggressiveReduceConstant);
populateTosaOpsCanonicalizationPatterns(ctx, patterns);
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
@@ -63,5 +68,11 @@ struct TosaLayerwiseConstantFoldPass
} // namespace
std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
- return std::make_unique<TosaLayerwiseConstantFoldPass>();
+ return std::make_unique<TosaLayerwiseConstantFoldPass>(
+ TosaLayerwiseConstantFoldPassOptions{false});
+}
+
+std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(
+ const TosaLayerwiseConstantFoldPassOptions &options) {
+ return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 56619fbc560e5fa..612e99f198515ae 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
+
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
@@ -1051,3 +1054,57 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
%0 = tosa.reduce_sum %arg2 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
return %0 : tensor<1x3xi32>
}
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+ // AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+ // AGGRESIVE: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+ // AGGRESIVE: return %[[VAL_0:.*]] : tensor<1x3xi32>
+
+ // CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ // CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+ // CHECK: return %[[VAL_3]] : tensor<1x3xi32>
+
+ %const = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+ %0 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ %1 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+ %res = tosa.add %0, %1 : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+ return %res : tensor<1x3xi32>
+}
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+ // AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+ // AGGRESIVE: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
+ // AGGRESIVE: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+ // AGGRESIVE: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // AGGRESIVE: return %[[VAL_6]] : tensor<2x3xi32>
+
+ // CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+ // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+ // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+ // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+ // CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ // CHECK: return %[[VAL_6]] : tensor<2x3xi32>
+
+ %const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
+ %const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+ %reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+ %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+ %argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+ %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+ return %res1 : tensor<2x3xi32>
+}
More information about the Mlir-commits
mailing list