[Mlir-commits] [mlir] b5aff11 - [mlir][tosa] Add folding for TOSA ArgMax operator (#88871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 23:33:54 PDT 2024


Author: Dmitrii Agibov
Date: 2024-04-18T07:33:51+01:00
New Revision: b5aff11aa118dabf134a1377dfd94b34e4dedbf7

URL: https://github.com/llvm/llvm-project/commit/b5aff11aa118dabf134a1377dfd94b34e4dedbf7
DIFF: https://github.com/llvm/llvm-project/commit/b5aff11aa118dabf134a1377dfd94b34e4dedbf7.diff

LOG: [mlir][tosa] Add folding for TOSA ArgMax operator (#88871)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 306e4a43952088..dde17e2dc8924d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -49,6 +49,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
     Tosa_Tensor: $output
   );
 
+  let hasFolder = 1;
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index d23c9fe824c94a..c8bf4c526b239f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -507,6 +507,31 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
                                                             resultTy);
 }
 
+// Folds tosa.argmax when the reduced axis has static size 1: each slice along
+// that axis holds a single element, so the selected index is always 0 and the
+// result is a zero-filled constant of the output type.
+OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
+  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
+      !outputTy.hasStaticShape())
+    return {};
+
+  // Build the zero at the output's own bit width. The scalar `int` overload
+  // of DenseElementsAttr::get requires a 32-bit element type and asserts on
+  // e.g. an i64 result; non-IntegerType elements (such as index) are left
+  // unfolded.
+  auto outputElemTy = llvm::dyn_cast<IntegerType>(outputTy.getElementType());
+  if (!outputElemTy)
+    return {};
+
+  if (inputTy.getDimSize(getAxis()) == 1)
+    return DenseElementsAttr::get(outputTy,
+                                  APInt::getZero(outputElemTy.getWidth()));
+
+  return {};
+}
+
 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index de752f31fcbaa1..c9c60a94bf9ed0 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -3,6 +3,25 @@
 
 // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
 
+// Argmax over a static axis of size 1 always selects index 0, so it folds to
+// a zero-filled constant.
+// CHECK-LABEL: @argmax_fold_dim_size_1
+func.func @argmax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
+  // CHECK: "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
+  // CHECK-NOT: tosa.argmax
+  %0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<2x1x3xf32>) -> tensor<2x3xi32>
+  return %0 : tensor<2x3xi32>
+}
+
+// The fold requires fully static shapes, so a dynamic dimension blocks it even
+// though the reduced axis has size 1.
+// CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1
+func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> {
+  // CHECK: tosa.argmax
+  %0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<?x1x3xf32>) -> tensor<?x3xi32>
+  return %0 : tensor<?x3xi32>
+}
+
 // CHECK-LABEL: @transpose_fold
 func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
 // CHECK-LABEL: @transpose_fold
 func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
@@ -1100,9 +1114,9 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
   // AGGRESIVE-DAG:       %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
   // AGGRESIVE-DAG:       %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
   // AGGRESIVE-DAG:       %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
-  // AGGRESIVE:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESIVE:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
   // AGGRESIVE:           %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
-  // AGGRESIVE:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // AGGRESIVE:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   // AGGRESIVE:           %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   // AGGRESIVE:           return %[[VAL_6]] : tensor<2x3xi32>
 
@@ -1110,18 +1124,18 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
   // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
   // CHECK:           %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
   // CHECK:           %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
-  // CHECK:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
   // CHECK:           %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK:           %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   // CHECK:           %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   // CHECK:           return %[[VAL_6]] : tensor<2x3xi32>
 
   %const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
   %const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
   %reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
-  %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
+  %argmax0 = tosa.argmax %reduce0 {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
   %argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
-  %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %res0 = tosa.add %argmax0, %const1 : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %res1 : tensor<2x3xi32>
 }


        


More information about the Mlir-commits mailing list