[Mlir-commits] [mlir] [mlir][tosa] Add aggressiveReduceConstant argument for the constant reduce optimization (PR #68765)

Amir Bishara llvmlistbot at llvm.org
Tue Oct 10 23:18:52 PDT 2023


https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/68765

>From fd3fb470ed1aa96cde2a34bf2dc22d2fd6617372 Mon Sep 17 00:00:00 2001
From: amirBish <amir.bishara at mobileye.com>
Date: Sat, 7 Oct 2023 19:14:50 +0300
Subject: [PATCH] [mlir][tosa] Add aggressiveReduceConstant argument for the
 constant reduce optimization

Add an aggressiveReduceConstant option to
TosaLayerwiseConstantFoldPass that allows the
constant reduce optimization to be applied to
reduce ops unconditionally (i.e. regardless of the
number of users of the reduce operation's input).
The option is also exposed through the TosaToLinalg
pass options used by addTosaToLinalgPasses.
---
 mlir/include/mlir/Conversion/Passes.td        |  8 +++
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  2 +-
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |  3 +-
 .../mlir/Dialect/Tosa/Transforms/Passes.td    |  7 +++
 .../TosaToLinalg/TosaToLinalgPass.cpp         |  6 +-
 .../Dialect/Tosa/Transforms/TosaFolders.cpp   | 21 ++++---
 .../TosaLayerwiseConstantFoldPass.cpp         | 12 +++-
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  | 57 +++++++++++++++++++
 8 files changed, 101 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..bb56c8d203d3c15 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1082,6 +1082,14 @@ def TosaToLinalg
   }];
 
   let constructor = "tosa::createTosaToLinalg()";
+  let options = [
+    Option<"disableTosaDecompositions", "disable-tosa-decompositions",
+           "bool", /*default=*/"false",
+           "Disable tosa decompositions pass">,
+    Option<"aggressiveReduceConstant", "aggressive-reduce-constant",
+           "bool", /*default=*/"false",
+           "Always perform the reduce constant optimization">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 818d43ffe4e572e..8ffbd1238e5c6b3 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -33,7 +33,7 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
 /// pipeline succeeds.  The option to disable decompositions is available for
 /// benchmarking performance improvements from the canonicalizations.
 void addTosaToLinalgPasses(
-    OpPassManager &pm, bool disableTosaDecompositions = false,
+    OpPassManager &pm, const TosaToLinalgOptions &options = {},
     // Note: Default to 'none' level unless otherwise specified.
     tosa::ValidationOptions const &validationOptions =
         tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 6b5dd9c970703ee..8f3255ddaad6844 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -35,9 +35,10 @@ void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
                                                RewritePatternSet &patterns);
 void populateTosaConstantReduction(MLIRContext *ctx,
-                                   RewritePatternSet &patterns);
+                                   RewritePatternSet &patterns,bool aggressiveReduceConstant);
 
 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
+std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options);
 std::unique_ptr<Pass> createTosaInferShapesPass();
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 18402b3e70647a9..ac100a6d75c7c08 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -23,6 +23,13 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::
   }];
 
   let constructor = "createTosaLayerwiseConstantFoldPass()";
+
+  let options = [
+      Option<"aggressiveReduceConstant", "aggressive-reduce-constant", "bool",
+             /*default=*/"false",
+             "Always perform the reduce constant optimization. "
+             "May add more tosa.const but would reduce runtime calculations">,
+  ];
 }
 
 def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index d7e867d92282395..e934d21fe065959 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -75,10 +75,10 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
 }
 
 void mlir::tosa::addTosaToLinalgPasses(
-    OpPassManager &pm, bool disableTosaDecompositions,
+    OpPassManager &pm, const TosaToLinalgOptions& options,
     tosa::ValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
-  if (!disableTosaDecompositions)
+  if (!options.disableTosaDecompositions)
     pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
 
@@ -87,7 +87,7 @@ void mlir::tosa::addTosaToLinalgPasses(
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalgNamed());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   // TODO: Remove pass that operates on const tensor and enable optionality
-  pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass({options.aggressiveReduceConstant}));
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
   pm.addNestedPass<func::FuncOp>(
       tosa::createTosaValidationPass(validationOptions));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0988759b82201df..3b417eda9e20dd4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -350,6 +350,9 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
 template <typename OperationType>
 struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
 
+  ReduceConstantOptimization(MLIRContext *context, bool aggressiveReduceConstant):
+  OpRewritePattern<OperationType>(context), aggressiveReduceConstant(aggressiveReduceConstant){}
+
   using OpRewritePattern<OperationType>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(OperationType op,
@@ -361,7 +364,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
       return rewriter.notifyMatchFailure(
           op, "reduce input must be const operation");
 
-    if (!inputOp.hasOneUse())
+    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
       return rewriter.notifyMatchFailure(
           op, "input operation has more than one user");
 
@@ -400,18 +403,20 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
     return success();
   }
+  const bool aggressiveReduceConstant = false;
 };
 
 } // namespace
 
 void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
-                                               RewritePatternSet &patterns) {
-  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx);
-  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx);
+                                               RewritePatternSet &patterns,
+                                               bool aggressiveReduceConstant) {
+  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx, aggressiveReduceConstant);
+  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx, aggressiveReduceConstant);
 }
 
 void mlir::tosa::populateTosaFoldConstantTransposePatterns(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 90f15faf0108103..56bc53a4746da0d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -45,6 +45,10 @@ void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
 struct TosaLayerwiseConstantFoldPass
     : public tosa::impl::TosaLayerwiseConstantFoldPassBase<
           TosaLayerwiseConstantFoldPass> {
+  TosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions& options):TosaLayerwiseConstantFoldPassBase(options)
+  {
+  }
+
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -52,7 +56,7 @@ struct TosaLayerwiseConstantFoldPass
 
     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
-    mlir::tosa::populateTosaConstantReduction(ctx, patterns);
+    mlir::tosa::populateTosaConstantReduction(ctx, patterns, aggressiveReduceConstant);
     populateTosaOpsCanonicalizationPatterns(ctx, patterns);
 
     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
@@ -63,5 +67,9 @@ struct TosaLayerwiseConstantFoldPass
 } // namespace
 
 std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass() {
-  return std::make_unique<TosaLayerwiseConstantFoldPass>();
+  return createTosaLayerwiseConstantFoldPass({});
+}
+
+std::unique_ptr<Pass> mlir::tosa::createTosaLayerwiseConstantFoldPass(const TosaLayerwiseConstantFoldPassOptions &options) {
+  return std::make_unique<TosaLayerwiseConstantFoldPass>(options);
 }
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 56619fbc560e5fa..612e99f198515ae 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s
 
+// Second run: aggressive mode folds reduce ops even if their const input has other users.
+// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESSIVE
+
 // CHECK-LABEL: @transpose_fold
 func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -1051,3 +1054,57 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
   %0 = tosa.reduce_sum %arg2 {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
   return %0 : tensor<1x3xi32>
 }
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+  // AGGRESSIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+  // AGGRESSIVE:       %[[VAL_0:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+  // AGGRESSIVE:       return %[[VAL_0]] : tensor<1x3xi32>
+
+  // CHECK-LABEL:     func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
+  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  // CHECK:           %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  // CHECK:           %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+  // CHECK:           return %[[VAL_3]] : tensor<1x3xi32>
+
+  %const = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %0 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  %1 = tosa.reduce_sum %const {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
+  %res = tosa.add %0, %1 : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32>
+  return %res : tensor<1x3xi32>
+}
+
+// -----
+
+func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+  // AGGRESSIVE-LABEL:    func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+  // AGGRESSIVE:          %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
+  // AGGRESSIVE:          %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+  // AGGRESSIVE:          %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // AGGRESSIVE:          %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESSIVE:          %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESSIVE:          %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESSIVE:          %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESSIVE:          return %[[VAL_6]] : tensor<2x3xi32>
+
+  // CHECK-LABEL:     func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
+  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
+  // CHECK:           %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+  // CHECK:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           return %[[VAL_6]] : tensor<2x3xi32>
+
+  %const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
+  %const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
+  %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  %argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
+  %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %res1 : tensor<2x3xi32>
+}



More information about the Mlir-commits mailing list